| |
| import os |
| import torch |
| import torch.nn as nn |
| import onnx |
| from sam_audio.model.vision_encoder import PerceptionEncoder |
| from onnx_export.standalone_config import PerceptionEncoderConfig |
|
|
class VisionEncoderWrapper(nn.Module):
    """Minimal ``nn.Module`` adapter around the vision encoder for ONNX export.

    ``forward`` simply delegates to ``encode_image`` on the wrapped encoder's
    inner ``model``. It passes along the encoder's ``normalize_feature``
    setting, which is captured once at construction time.
    """

    def __init__(self, vision_encoder):
        super().__init__()
        # Keep only the inner model. Assigning an nn.Module to an attribute
        # registers it as a submodule, so .eval() and the exporter see its
        # parameters.
        backbone = vision_encoder.model
        self.model = backbone
        # Plain (non-tensor) flag forwarded to encode_image on every call.
        self.normalize = vision_encoder.normalize_feature

    def forward(self, x):
        """Encode the image batch ``x`` with the wrapped model.

        NOTE(review): the export script traces this with a
        (1, 3, image_size, image_size) tensor; other shapes are not
        exercised here — confirm against encode_image.
        """
        encode_image = self.model.encode_image
        features = encode_image(x, normalize=self.normalize)
        return features
|
|
def export_vision_encoder(model_id="facebook/sam-audio-small", output_dir="onnx_models"):
    """Export the SAM Audio vision encoder to ONNX.

    Steps:
      1. Build a ``PerceptionEncoder`` from the ``vision_encoder`` section of
         the Hugging Face config for ``model_id``.
      2. Copy the ``vision_encoder.*`` tensors from that repo's
         ``checkpoint.pt`` into it.
      3. Write ``<output_dir>/vision_encoder.onnx`` (opset 17). The input is
         ``video_frames``, the output is ``vision_features``, and both have a
         dynamic leading ``num_frames`` axis.

    Args:
        model_id: Hugging Face repo id that supplies both the config and
            ``checkpoint.pt``.
        output_dir: Directory for the exported model; created if missing.
    """
    # Dependencies only this function needs are imported lazily. torch and
    # the PerceptionEncoder classes are already imported at module level.
    from huggingface_hub import hf_hub_download
    from transformers import AutoConfig

    print(f"Loading Vision Encoder from {model_id}...")

    print("Fetching config...")
    cfg_dict = AutoConfig.from_pretrained(model_id).to_dict()
    # If the config has no "vision_encoder" section, PerceptionEncoderConfig
    # is built from its own defaults.
    v_cfg = PerceptionEncoderConfig(**cfg_dict.get("vision_encoder", {}))

    print(f"Initializing PerceptionEncoder with name: {v_cfg.name}...")
    vision_encoder = PerceptionEncoder(v_cfg)

    print("Loading weights from SAM Audio checkpoint...")
    checkpoint_path = hf_hub_download(repo_id=model_id, filename="checkpoint.pt")
    # NOTE(review): torch.load unpickles the downloaded file, which can
    # execute arbitrary code. Only point model_id at trusted repos, and
    # consider weights_only=True if the checkpoint holds only tensors.
    state_dict = torch.load(checkpoint_path, map_location="cpu", mmap=True)

    # Keep only the vision-encoder tensors, with the prefix stripped so the
    # keys line up with vision_encoder's own state_dict.
    prefix = "vision_encoder."
    vision_state = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    # The full checkpoint is no longer needed. Drop it so it is not kept
    # alive (mapped/resident) during the export below.
    del state_dict

    if vision_state:
        print(f" Loading {len(vision_state)} tensors into vision encoder...")
        vision_encoder.load_state_dict(vision_state)
        print(" ✓ Vision encoder weights loaded.")
    else:
        print(" WARNING: No 'vision_encoder' weights found in checkpoint. Using base weights.")
    # load_state_dict copied the values into the module's parameters.
    del vision_state

    image_size = vision_encoder.image_size
    print(f" Image size: {image_size}")

    # eval() propagates to the wrapped model (a registered submodule).
    wrapper = VisionEncoderWrapper(vision_encoder).eval()
    # Square image input. Batch size 1 is only the tracing size; axis 0 is
    # declared dynamic ("num_frames") below.
    dummy_input = torch.randn(1, 3, image_size, image_size)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "vision_encoder.onnx")

    print(f"Exporting to {output_path}...")
    torch.onnx.export(
        wrapper,
        dummy_input,
        output_path,
        input_names=["video_frames"],
        output_names=["vision_features"],
        dynamic_axes={
            "video_frames": {0: "num_frames"},
            "vision_features": {0: "num_frames"},
        },
        opset_version=17,
        do_constant_folding=True,
        dynamo=False,
        external_data=True,
    )

    # NOTE(review): this only detects the single-file "<model>.onnx.data"
    # external-weights layout. Other exporter paths/torch versions may lay
    # out external tensors differently — confirm for the torch in use.
    data_path = output_path + ".data"
    if os.path.exists(data_path):
        print(f" Large model detected, weights saved to {data_path}")

    print("✓ Vision encoder export complete!")
|
|
if __name__ == "__main__":
    # Command-line entry point: read the model repo id and output directory,
    # then run the export with them.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, default="facebook/sam-audio-small")
    cli.add_argument("--output", type=str, default="onnx_models")
    options = cli.parse_args()

    export_vision_encoder(model_id=options.model, output_dir=options.output)
|
|