| | import pdb |
| | from typing import Tuple |
| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | import argparse |
| | import importlib |
| | import json |
| | import math |
| | import multiprocessing as mp |
| | import os |
| | import time |
| | from argparse import Namespace |
| | from pathlib import Path |
| |
|
| | |
| | import scipy |
| | import numpy as np |
| |
|
# Compatibility shim: `scipy.inf` was removed from recent SciPy releases;
# restore it as an alias of `np.inf` for downstream dependencies that still
# reference it (NOTE(review): presumably one of the imports below — confirm).
scipy.inf = np.inf
| |
|
| | import librosa |
| | import torch |
| | from ema_pytorch import EMA |
| | from loguru import logger |
| | from muq import MuQ |
| | from musicfm.model.musicfm_25hz import MusicFM25Hz |
| | from omegaconf import OmegaConf |
| | from tqdm import tqdm |
| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from configuration_songformer import SongFormerConfig |
| | from model_config import ModelConfig |
| |
|
| | from model import Model |
| | from omegaconf import OmegaConf |
| |
|
| | |
# Local MusicFM checkpoint directory.
# NOTE(review): not referenced anywhere in this file — presumably consumed by
# another module; confirm before relying on it.
MUSICFM_HOME_PATH = "/home/node59_tmpdata3/cbhao/SongFormer_kaiyuan_test/github_test/SongFormer/src/SongFormer/ckpts/MusicFM"

# Feature frame rates (frames per second) before and after the model's
# temporal downsampling (25 Hz -> ~8.333 Hz, a 3x reduction).
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333

# Dataset label/ids used to pick the allowed label vocabulary at inference.
DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]

# Analysis window length in seconds, and the expected input sample rate (Hz).
TIME_DUR = 420
INPUT_SAMPLING_RATE = 24000
| |
|
| | from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID |
| | from postprocessing.functional import postprocess_functional_structure |
| |
|
| |
|
def rule_post_processing(msa_list):
    """Clean up a music-structure segment list with heuristic merge rules.

    ``msa_list`` is a sequence of ``(start_time, label)`` pairs whose final
    entry acts as a terminator, so segment ``k`` spans
    ``msa_list[k][0] .. msa_list[k + 1][0]``.  Four rules run in order, each
    repeating while it still fires and at least three entries remain:

    1. an opening segment shorter than 1 s inherits the next segment's label;
    2. a final segment shorter than 1 s is dropped (terminator kept);
    3. a leading pair with identical labels whose boundary lies at <= 10 s
       is collapsed into one segment;
    4. a trailing pair with identical labels whose last span is <= 10 s
       is collapsed into one segment.

    Lists with two or fewer entries are returned unchanged (same object);
    otherwise a new list is returned and the input is left untouched.
    """
    if len(msa_list) <= 2:
        return msa_list

    segs = list(msa_list)

    # Rule 1: absorb an ultra-short opening segment into its successor.
    while len(segs) > 2:
        if segs[1][0] - segs[0][0] >= 1.0:
            break
        segs = [(segs[0][0], segs[1][1])] + segs[2:]

    # Rule 2: drop an ultra-short final segment, keeping the terminator.
    while len(segs) > 2:
        if segs[-1][0] - segs[-2][0] >= 1.0:
            break
        segs = segs[:-2] + [segs[-1]]

    # Rule 3: merge a duplicated label at the head when the first boundary
    # falls within the opening 10 seconds (absolute time, not duration).
    while len(segs) > 2:
        if segs[0][1] != segs[1][1] or segs[1][0] > 10.0:
            break
        segs = [(segs[0][0], segs[0][1])] + segs[2:]

    # Rule 4: merge a duplicated label at the tail when the last span lasts
    # at most 10 seconds.
    while len(segs) > 2:
        if segs[-2][1] != segs[-3][1] or segs[-1][0] - segs[-2][0] > 10.0:
            break
        segs = segs[:-2] + [segs[-1]]

    return segs
| |
|
| |
|
class SongFormerModel(PreTrainedModel):
    """Music structure analysis: MuQ + MusicFM embeddings -> SongFormer head.

    ``forward`` accepts either a raw waveform (tensor/ndarray, assumed mono at
    24 kHz — TODO confirm with callers) or a path to an audio file, and returns
    the song's structure as a list of ``{"label", "start", "end"}`` dicts with
    times in seconds.

    Construction reads checkpoint assets from the directory named by the
    ``SONGFORMER_LOCAL_DIR`` environment variable (raises ``KeyError`` if
    unset).
    """

    config_class = SongFormerConfig

    def __init__(self, config: SongFormerConfig):
        super().__init__(config)
        # Checkpoint/config assets live under this local directory.
        root_dir = os.environ["SONGFORMER_LOCAL_DIR"]
        with open(os.path.join(root_dir, "muq_config2.json"), "r") as f:
            muq_config_file = OmegaConf.load(f)

        # SSL feature extractors (MuQ, MusicFM) plus the SongFormer head.
        self.muq = MuQ(muq_config_file)

        self.musicfm = MusicFM25Hz(
            is_flash=False,
            stat_path=os.path.join(root_dir, "msd_stats.json"),
        )
        self.songformer = Model(ModelConfig())

        # Per-dataset boolean masks over the label vocabulary: True marks
        # label ids that are NOT allowed for that dataset (mask-out style).
        num_classes = config.num_classes
        dataset_id2label_mask = {}
        for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
            dataset_id2label_mask[key] = np.ones(config.num_classes, dtype=bool)
            dataset_id2label_mask[key][allowed_ids] = False

        self.num_classes = num_classes
        self.dataset_id2label_mask = dataset_id2label_mask
        self.config = config

    def forward(self, input):
        """Run sliding-window structure analysis over a whole song.

        Args:
            input: 1-D waveform (``torch.Tensor`` or ``np.ndarray``) or a path
                to an audio file; files are decoded at 24 kHz via librosa.

        Returns:
            list[dict]: ``{"label": str, "start": float, "end": float}``
            segments in seconds.

        Raises:
            ValueError: if ``input`` is neither an array nor an existing path,
                or if the embedding streams disagree in length by > 4 frames.
        """
        with torch.no_grad():
            INPUT_SAMPLING_RATE = 24000  # Hz; shadows the module-level constant

            device = next(self.parameters()).device

            if isinstance(input, (torch.Tensor, np.ndarray)):
                audio = torch.tensor(input).to(device)
            elif os.path.exists(input):
                wav, sr = librosa.load(input, sr=INPUT_SAMPLING_RATE)
                audio = torch.tensor(wav).to(device)
            else:
                raise ValueError("input should be a tensor/numpy or a valid file path")

            win_size = self.config.win_size
            hop_size = self.config.hop_size
            num_classes = self.config.num_classes
            # Audio length in seconds rounded up to the next TIME_DUR multiple,
            # then converted to post-downsampling frames.
            total_len = (
                (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
            ) * TIME_DUR + TIME_DUR
            total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

            # Accumulated window logits and per-frame hit counts; dividing the
            # former by the latter averages predictions where windows overlap.
            logits = {
                "function_logits": np.zeros([total_frames, num_classes]),
                "boundary_logits": np.zeros([total_frames]),
            }
            logits_num = {
                "function_logits": np.zeros([total_frames, num_classes]),
                "boundary_logits": np.zeros([total_frames]),
            }

            lens = 0
            i = 0  # current window start, in seconds
            while True:
                start_idx = i * INPUT_SAMPLING_RATE
                end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
                if start_idx >= audio.shape[-1]:
                    break
                if end_idx - start_idx <= 1024:
                    # FIX: was `continue`, which never advanced `i` and spun
                    # forever whenever a tiny (<= 1024-sample) remainder was
                    # left. Such a remainder can only occur at the very end of
                    # the audio, so stopping is the correct behavior.
                    break
                audio_seg = audio[start_idx:end_idx]

                # Long-context (whole-window) embeddings, hidden layer 10.
                muq_output = self.muq(audio_seg.unsqueeze(0), output_hidden_states=True)
                muq_embd_420s = muq_output["hidden_states"][10]
                del muq_output
                torch.cuda.empty_cache()

                _, musicfm_hidden_states = self.musicfm.get_predictions(
                    audio_seg.unsqueeze(0)
                )
                musicfm_embd_420s = musicfm_hidden_states[10]
                del musicfm_hidden_states
                torch.cuda.empty_cache()

                # Short-context embeddings over 30 s chunks, concatenated in
                # time to cover the same span as the long-context features.
                wraped_muq_embd_30s = []
                wraped_musicfm_embd_30s = []

                for idx_30s in range(i, i + hop_size, 30):
                    start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
                    end_idx_30s = min(
                        (idx_30s + 30) * INPUT_SAMPLING_RATE,
                        audio.shape[-1],
                        (i + hop_size) * INPUT_SAMPLING_RATE,
                    )
                    if start_idx_30s >= audio.shape[-1]:
                        break
                    if end_idx_30s - start_idx_30s <= 1024:
                        # Degenerate trailing chunk: skip it. `continue` is
                        # safe here because the for-loop advances idx_30s.
                        continue
                    wraped_muq_embd_30s.append(
                        self.muq(
                            audio[start_idx_30s:end_idx_30s].unsqueeze(0),
                            output_hidden_states=True,
                        )["hidden_states"][10]
                    )
                    torch.cuda.empty_cache()
                    wraped_musicfm_embd_30s.append(
                        self.musicfm.get_predictions(
                            audio[start_idx_30s:end_idx_30s].unsqueeze(0)
                        )[1][10]
                    )
                    torch.cuda.empty_cache()

                wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
                wraped_musicfm_embd_30s = torch.concatenate(
                    wraped_musicfm_embd_30s, dim=1
                )
                all_embds = [
                    wraped_musicfm_embd_30s,
                    wraped_muq_embd_30s,
                    musicfm_embd_420s,
                    muq_embd_420s,
                ]

                # The streams may differ by a few frames due to model padding;
                # clip all of them to the shortest (tolerance: 4 frames).
                if len(all_embds) > 1:
                    embd_lens = [x.shape[1] for x in all_embds]
                    max_embd_len = max(embd_lens)
                    min_embd_len = min(embd_lens)
                    if abs(max_embd_len - min_embd_len) > 4:
                        raise ValueError(
                            f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
                        )

                    for idx in range(len(all_embds)):
                        all_embds[idx] = all_embds[idx][:, :min_embd_len, :]

                embd = torch.concatenate(all_embds, axis=-1)

                dataset_label = DATASET_LABEL
                dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
                msa_info, chunk_logits = self.songformer.infer(
                    input_embeddings=embd,
                    dataset_ids=dataset_ids,
                    label_id_masks=torch.Tensor(
                        self.dataset_id2label_mask[
                            DATASET_LABEL_TO_DATASET_ID[dataset_label]
                        ]
                    )
                    .to(device, dtype=bool)
                    .unsqueeze(0)
                    .unsqueeze(0),
                    with_logits=True,
                )

                start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
                end_frame = start_frame + min(
                    math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
                    chunk_logits["boundary_logits"][0].shape[0],
                )

                logits["function_logits"][start_frame:end_frame, :] += (
                    chunk_logits["function_logits"][0].detach().cpu().numpy()
                )
                # FIX: accumulate (+=) instead of overwrite (=), matching
                # function_logits; the ceil() above makes consecutive windows
                # overlap by a frame, and the division by logits_num below
                # would otherwise halve an overwritten boundary logit there.
                logits["boundary_logits"][start_frame:end_frame] += (
                    chunk_logits["boundary_logits"][0].detach().cpu().numpy()
                )
                logits_num["function_logits"][start_frame:end_frame, :] += 1
                logits_num["boundary_logits"][start_frame:end_frame] += 1
                lens += end_frame - start_frame

                i += hop_size

            # Average overlapping windows. NOTE(review): frames never visited
            # keep a zero count, so this divides 0/0 -> NaN outside [:lens];
            # downstream postprocessing is presumed to only read [:lens].
            logits["function_logits"] /= logits_num["function_logits"]
            logits["boundary_logits"] /= logits_num["boundary_logits"]

            logits["function_logits"] = torch.from_numpy(
                logits["function_logits"][:lens]
            ).unsqueeze(0)
            logits["boundary_logits"] = torch.from_numpy(
                logits["boundary_logits"][:lens]
            ).unsqueeze(0)

            msa_infer_output = postprocess_functional_structure(logits, self.config)

            # The decoded boundary list must terminate with an "end" marker.
            assert msa_infer_output[-1][-1] == "end"
            if not self.config.no_rule_post_processing:
                msa_infer_output = rule_post_processing(msa_infer_output)
            # Convert the (start, label) boundary list into JSON-style
            # segments; the final "end" entry only supplies the last end time.
            msa_json = []
            for idx in range(len(msa_infer_output) - 1):
                msa_json.append(
                    {
                        "label": msa_infer_output[idx][1],
                        "start": msa_infer_output[idx][0],
                        "end": msa_infer_output[idx + 1][0],
                    }
                )
            return msa_json

    @staticmethod
    def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> gamma, gamma -> weight.

        Returns ``(new_key, changed)`` where ``changed`` tells the loader
        whether a rename occurred.
        """
        # MuQ submodule keys are taken verbatim — never renamed.
        if key.startswith("muq."):
            return key, False

        # Legacy TF-style LayerNorm parameter names.
        if key.endswith("LayerNorm.beta"):
            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
        if key.endswith("LayerNorm.gamma"):
            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

        # Weight-norm parameters: translate between the legacy weight_g/v
        # names and the parametrization form, in whichever direction this
        # torch version requires.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            if key.endswith("weight_g"):
                return key.replace(
                    "weight_g", "parametrizations.weight.original0"
                ), True
            if key.endswith("weight_v"):
                return key.replace(
                    "weight_v", "parametrizations.weight.original1"
                ), True
        else:
            if key.endswith("parametrizations.weight.original0"):
                return key.replace(
                    "parametrizations.weight.original0", "weight_g"
                ), True
            if key.endswith("parametrizations.weight.original1"):
                return key.replace(
                    "parametrizations.weight.original1", "weight_v"
                ), True

        return key, False
| |
|