| | import torch |
| | from configuration_neuroclr import NeuroCLRConfig |
| | from modeling_neuroclr import NeuroCLRModel |
| |
|
| | |
# Model hyperparameters mirrored into NeuroCLRConfig (see main()).
CFG = {
    "TSlength": 128,        # input time-series length
    "nhead": 2,             # transformer attention heads
    "nlayer": 2,            # transformer encoder layers
    "projector_out1": 128,  # first projector hidden size
    "projector_out2": 64,   # projector output size
    "pooling": "flatten",
    "normalize_input": True,
}
# Path to the raw training checkpoint — fill in before running.
CKPT_PATH = ""
# Directory that receives the Hugging Face export.
OUT_DIR = "."
| | |
| |
|
def remap_state_dict(sd):
    """Rename raw-checkpoint keys to match the HF ``NeuroCLRModel`` layout.

    Two transformations are applied to each key:
      1. A leading ``"module."`` (added by ``nn.DataParallel``/DDP wrappers)
         is stripped.  Only the *prefix* is removed — a blanket
         ``str.replace`` would also corrupt keys that legitimately contain
         ``"module."`` elsewhere (e.g. ``"submodule.x"`` -> ``"subx"``).
      2. Keys belonging to the encoder or projector are nested under the
         ``"neuroclr."`` submodule, matching the HF wrapper model.

    Args:
        sd: mapping of parameter name -> tensor (a state dict).

    Returns:
        A new dict with remapped keys; values are passed through untouched.
    """
    prefix = "module."
    new_sd = {}
    for k, v in sd.items():
        # Strip the DataParallel wrapper prefix only (not mid-key matches).
        k2 = k[len(prefix):] if k.startswith(prefix) else k
        if k2.startswith("transformer_encoder.") or k2.startswith("projector."):
            new_sd["neuroclr." + k2] = v
        else:
            # Heads / auxiliary parameters keep their (unprefixed) name.
            new_sd[k2] = v
    return new_sd
| |
|
def main():
    """Convert a raw NeuroCLR training checkpoint at ``CKPT_PATH`` into a
    Hugging Face ``save_pretrained`` directory at ``OUT_DIR``.

    Steps: build the config, register custom-class auto_map entries,
    instantiate the model, load + remap the checkpoint weights, and save
    with safetensors serialization.
    """
    config = NeuroCLRConfig(**CFG)

    # Register the custom classes so AutoConfig/AutoModel can resolve this
    # model when users load it with trust_remote_code=True.
    config.auto_map = {
        "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
        "AutoModel": "modeling_neuroclr.NeuroCLRModel",
    }

    model = NeuroCLRModel(config)

    # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
    # load checkpoints from trusted sources. weights_only=False is stated
    # explicitly because torch >= 2.6 changed the default to True, which
    # raises on checkpoints containing non-tensor objects.
    ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)

    # Training scripts store weights under different keys; fall back to
    # treating the whole object as the state dict.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        sd = ckpt["model_state_dict"]
    elif isinstance(ckpt, dict) and "state_dict" in ckpt:
        sd = ckpt["state_dict"]
    else:
        sd = ckpt

    sd = remap_state_dict(sd)

    # strict=False: report key mismatches instead of failing, so partial
    # conversions are visible in the printed lists below.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("Missing:", missing)
    print("Unexpected:", unexpected)

    model.save_pretrained(OUT_DIR, safe_serialization=True)
    print("Saved HF pretraining model to:", OUT_DIR)
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|