I propose modifying the KORMo modeling code to ensure compatibility with both Transformers 4.57.1 and 5.2.

#3
by jungsin3 - opened
Files changed (1) hide show
  1. _modeling_kormo.py +31 -3
_modeling_kormo.py CHANGED
@@ -94,7 +94,13 @@ def rotate_half(x):
94
  x1 = x[..., : x.shape[-1] // 2]
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
-
 
 
 
 
 
 
98
  class Attention(nn.Module):
99
  """Multi-headed attention from 'Attention Is All You Need' paper"""
100
 
@@ -237,11 +243,24 @@ class RotaryEmbedding(nn.Module):
237
  self.original_max_seq_len = config.max_position_embeddings
238
 
239
  self.config = config
240
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
241
 
242
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
243
  self.register_buffer("inv_freq", inv_freq, persistent=False)
244
- self.original_inv_freq = self.inv_freq
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  @torch.no_grad()
247
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -289,6 +308,15 @@ class KORMoPreTrainedModel(PreTrainedModel):
289
  module.weight.data[module.padding_idx].zero_()
290
  elif isinstance(module, RMSNorm):
291
  module.weight.data.fill_(1.0)
 
 
 
 
 
 
 
 
 
292
 
293
 
294
  class KORMoModel(KORMoPreTrainedModel):
 
94
  x1 = x[..., : x.shape[-1] // 2]
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
+
98
+ def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
99
+ if not getattr(tensor, "_is_hf_initialized", False):
100
+ with torch.no_grad():
101
+ return tensor.copy_(other)
102
+ return tensor
103
+
104
  class Attention(nn.Module):
105
  """Multi-headed attention from 'Attention Is All You Need' paper"""
106
 
 
243
  self.original_max_seq_len = config.max_position_embeddings
244
 
245
  self.config = config
246
+ rope_init_fn = self.compute_default_rope_parameters
247
+ rope_init_fn = ROPE_INIT_FUNCTIONS.get(self.rope_type, rope_init_fn)
248
+ self.rope_init_fn = rope_init_fn
249
 
250
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
251
  self.register_buffer("inv_freq", inv_freq, persistent=False)
252
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
253
+
254
+ @staticmethod
255
+ def compute_default_rope_parameters(config: KORMoConfig, device=None, seq_len=None):
256
+ base = config.rope_theta
257
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
258
+
259
+ attention_factor = 1.0
260
+ inv_freq = 1.0 / (
261
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
262
+ )
263
+ return inv_freq, attention_factor
264
 
265
  @torch.no_grad()
266
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
 
308
  module.weight.data[module.padding_idx].zero_()
309
  elif isinstance(module, RMSNorm):
310
  module.weight.data.fill_(1.0)
311
+ elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
312
+ rope_fn = (
313
+ ROPE_INIT_FUNCTIONS[module.rope_type]
314
+ if module.rope_type != "default"
315
+ else module.compute_default_rope_parameters
316
+ )
317
+ buffer_value, _ = rope_fn(module.config)
318
+ copy_(module.inv_freq, buffer_value)
319
+ copy_(module.original_inv_freq, buffer_value)
320
 
321
 
322
  class KORMoModel(KORMoPreTrainedModel):