Music VAE: CNN pretrained on the Lakh MIDI Dataset

Overview

This is a Convolutional Variational Autoencoder (CNN VAE) pretrained on the Lakh MIDI Dataset (lmd_full, ~175k MIDI files).

It was trained as part of a machine learning course assignment at Purdue University to give students a meaningful starting point for music-generation tasks.

Input / Output Format

Property          Value
Input shape       [batch, 1, 88, 32], float32
Output shape      [batch, 1, 88, 32], float32
Pitch range       MIDI 21–108 (A0–C8, 88 keys)
Time resolution   16th notes at 120 BPM
Segment length    2 bars = 32 timesteps
Value range       [0, 1] (Sigmoid output)

A value of 1 at position [pitch_idx, time_step] (the last two dimensions of the tensor) means that MIDI pitch 21 + pitch_idx is active at that 16th-note time step.
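To build a model input from note data, here is a minimal NumPy sketch. `notes_to_pianoroll` and `segment` are hypothetical helpers (they are not taken from src/utils.py). The (pitch, start_s, end_s) note format, the rounding of note boundaries to the nearest 16th note, and the dropping of out-of-range pitches are all assumptions. Windows use the 2-bar length and 1-bar stride listed under Training Details.

```python
import numpy as np

# Constants from the format table; the helper names below are hypothetical.
LOWEST_PITCH = 21                                 # MIDI A0 -> pitch_idx 0
NUM_PITCHES = 88                                  # A0..C8
STEPS_PER_BEAT = 4                                # 16th-note resolution
BPM = 120
SECONDS_PER_STEP = 60.0 / BPM / STEPS_PER_BEAT    # 0.125 s per step at 120 BPM
SEGMENT_STEPS = 32                                # 2 bars of 4/4
STRIDE_STEPS = 16                                 # 1-bar stride (Training Details)

def notes_to_pianoroll(notes, num_steps):
    """Rasterize (midi_pitch, start_s, end_s) notes into an [88, num_steps] roll."""
    roll = np.zeros((NUM_PITCHES, num_steps), dtype=np.float32)
    for pitch, start_s, end_s in notes:
        if not LOWEST_PITCH <= pitch < LOWEST_PITCH + NUM_PITCHES:
            continue  # assumption: pitches outside the 88-key range are dropped
        start = int(round(start_s / SECONDS_PER_STEP))
        end = max(start + 1, int(round(end_s / SECONDS_PER_STEP)))  # at least 1 step
        roll[pitch - LOWEST_PITCH, start:min(end, num_steps)] = 1.0
    return roll

def segment(roll):
    """Slice an [88, T] roll into [N, 1, 88, 32] windows with a 1-bar stride."""
    starts = range(0, roll.shape[1] - SEGMENT_STEPS + 1, STRIDE_STEPS)
    return np.stack([roll[None, :, s:s + SEGMENT_STEPS] for s in starts])

notes = [(60, 0.0, 0.5), (64, 0.5, 1.0), (67, 1.0, 2.0)]  # C4, E4, G4 (example input)
roll = notes_to_pianoroll(notes, num_steps=64)             # 4 bars
x = segment(roll)
print(x.shape)  # (3, 1, 88, 32): windows start at steps 0, 16 and 32
```

The resulting array can be handed to the model with `torch.from_numpy(x)`. A roll shorter than 32 steps produces no windows, and `np.stack` then raises an error, so pad short inputs first.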

Architecture Summary

ENCODER
  Conv2d(1 → 32, k=4, s=2, p=1) + ReLU + BN     →  [B, 32,  44, 16]
  Conv2d(32 → 64, k=4, s=2, p=1) + ReLU + BN    →  [B, 64,  22,  8]
  Conv2d(64 → 128, k=4, s=2, p=1) + ReLU + BN   →  [B, 128, 11,  4]
  Conv2d(128 → 256, k=4, s=2, p=1) + ReLU + BN  →  [B, 256,  5,  2]
  Flatten → 2560
  Linear → mu      [B, 256]
  Linear → log_var [B, 256]

REPARAMETERISATION
  z = mu + eps * exp(0.5 * log_var),  eps ~ N(0, I)

DECODER
  Linear(256 → 2560) → Reshape [B, 256, 5, 2]
  ConvTranspose2d(256 → 128, k=4, s=2, p=1, output_padding=(1,0)) → [B, 128, 11,  4]
  ConvTranspose2d(128 → 64, k=4, s=2, p=1)          → [B,  64, 22,  8]
  ConvTranspose2d(64 → 32, k=4, s=2, p=1)           → [B,  32, 44, 16]
  ConvTranspose2d(32 → 1, k=4, s=2, p=1) + Sigmoid  → [B,   1, 88, 32]
  • Latent dimension: 256
  • Trainable parameters: ~4.2M
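The shapes listed above follow from PyTorch's output-size formulas for Conv2d and ConvTranspose2d (with dilation 1). The pure-Python check below reproduces them. It also shows why the first decoder layer needs output_padding=(1,0): the encoder's 11 → 5 step floors away one row, and a plain stride-2 transpose turns 5 into 10, not 11. The function names are illustrative only.

```python
def conv_out(size, k=4, s=2, p=1):
    """Conv2d output length along one axis (dilation 1)."""
    return (size + 2 * p - k) // s + 1

def deconv_out(size, k=4, s=2, p=1, output_padding=0):
    """ConvTranspose2d output length along one axis (dilation 1)."""
    return (size - 1) * s - 2 * p + k + output_padding

# Encoder: 88 x 32 input, four stride-2 convolutions
h, w = 88, 32
enc = []
for _ in range(4):
    h, w = conv_out(h), conv_out(w)
    enc.append((h, w))
print(enc)          # [(44, 16), (22, 8), (11, 4), (5, 2)]
print(256 * h * w)  # 2560 features after Flatten

# Decoder: without output_padding the pitch axis would go 5 -> 10 -> 20 -> 40 -> 80
print(deconv_out(5))                    # 10
print(deconv_out(5, output_padding=1))  # 11
dec = [(deconv_out(5, output_padding=1), deconv_out(2))]
for _ in range(3):
    dh, dw = dec[-1]
    dec.append((deconv_out(dh), deconv_out(dw)))
print(dec)          # [(11, 4), (22, 8), (44, 16), (88, 32)]
```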

Loading the Model (Course Assignment)

import torch
from model import MusicVAE   # copy src/model.py into your project

# Load checkpoint on the CPU first; move the model to the target device afterwards
ckpt = torch.load("best_model.pt", map_location="cpu")
config = ckpt["config"]

model = MusicVAE(latent_dim=config["latent_dim"])
model.load_state_dict(ckpt["model_state"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()   # BatchNorm uses its running statistics at inference time

with torch.no_grad():   # inference only, no gradients needed
    # Generate new piano rolls
    samples = model.sample(n=4, device=device)   # [4, 1, 88, 32]

    # Encode segments and reconstruct them
    x = ...   # your [2, 1, 88, 32] batch of two piano-roll segments, values in [0, 1]
    x_recon, mu, log_var = model(x.to(device))

    # Interpolate between the latent means of the two segments
    z1 = mu[0:1]   # first segment
    z2 = mu[1:2]   # second segment (needs a batch of at least 2)
    interp = model.interpolate(z1, z2, steps=8)

Also see src/utils.py for pianoroll_to_midi() and visualize_pianoroll().
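The exact signature of pianoroll_to_midi() is not documented here. To illustrate the decoding direction only, here is a hypothetical NumPy helper (not the one in src/utils.py). It thresholds a sigmoid output at 0.5 (an assumed cutoff) and turns each run of active steps into a note event, with times in seconds at 120 BPM. For a model output, pass `samples[i, 0].cpu().numpy()`.

```python
import numpy as np

LOWEST_PITCH = 21                   # pitch_idx 0 is MIDI 21 (A0)
SECONDS_PER_STEP = 60.0 / 120 / 4   # one 16th note at 120 BPM = 0.125 s

def pianoroll_to_notes(roll, threshold=0.5):
    """Hypothetical helper: convert an [88, T] array of probabilities into
    (midi_pitch, start_s, end_s) tuples, merging consecutive active steps."""
    active = np.asarray(roll) >= threshold
    notes = []
    for idx in range(active.shape[0]):
        # Pad with False on both sides so every run has a rising and a falling edge.
        padded = np.concatenate(([False], active[idx], [False]))
        edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
        for start, end in zip(edges[::2], edges[1::2]):  # end is exclusive
            notes.append((LOWEST_PITCH + idx,
                          float(start) * SECONDS_PER_STEP,
                          float(end) * SECONDS_PER_STEP))
    return sorted(notes, key=lambda n: (n[1], n[0]))  # by onset, then pitch

roll = np.zeros((88, 32), dtype=np.float32)
roll[39, 0:4] = 0.9   # C4 for one beat
roll[43, 4:8] = 0.7   # E4 for the next beat
print(pianoroll_to_notes(roll))  # [(60, 0.0, 0.5), (64, 0.5, 1.0)]
```

A binary roll has no separate onset channel, so repeated notes of the same pitch on adjacent steps cannot be told apart from one held note. This helper merges them into a single event.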

Training Details

  • Dataset: Lakh MIDI Dataset (lmd_full)
  • Piano roll: 88-pitch binary, 16th-note resolution, tempo normalised to 120 BPM
  • Segments: 2 bars (32 frames), stride 1 bar (16 frames)
  • Loss: BCE reconstruction + β-annealed KL (β: 0 → 1 over 50 epochs) + free bits (λ = 0.5)
  • Optimizer: Adam, lr=1e-3, ReduceLROnPlateau (patience=10, factor=0.5, min_lr=1e-5)
  • Batch size: 256 | Epochs: 100 | Gradient clip: 1.0
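One way to read the loss line is the minimal NumPy sketch below. It rests on several assumptions: a linear β ramp, BCE summed over the 88 × 32 grid and averaged over the batch, and a free-bits floor of λ = 0.5 nats applied to each latent dimension's batch-averaged KL. That floor is one common variant, and the card does not say which one was used. `kl_beta` and `vae_loss` are hypothetical names.

```python
import numpy as np

def kl_beta(epoch, warmup_epochs=50):
    """Assumed linear KL annealing: beta rises 0 -> 1 over `warmup_epochs` epochs."""
    return min(1.0, epoch / warmup_epochs)

def vae_loss(x, x_recon, mu, log_var, beta, free_bits=0.5, eps=1e-7):
    """BCE reconstruction + beta * KL with a per-dimension free-bits floor.

    x, x_recon: [B, 1, 88, 32] arrays in [0, 1]; mu, log_var: [B, latent_dim].
    """
    p = np.clip(x_recon, eps, 1 - eps)  # avoid log(0)
    # Binary cross-entropy summed over pitches and time steps, averaged over the batch.
    bce = -(x * np.log(p) + (1 - x) * np.log(1 - p)).reshape(len(x), -1).sum(axis=1).mean()
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent dimension, batch-averaged.
    kl_per_dim = (-0.5 * (1 + log_var - mu ** 2 - np.exp(log_var))).mean(axis=0)
    # Free bits: every dimension counts as at least `free_bits` nats.
    kl = np.maximum(kl_per_dim, free_bits).sum()
    return bce + beta * kl, bce, kl
```

In PyTorch the same expressions carry over to tensors (`torch.clamp`, `torch.log`, `torch.exp`, `torch.maximum`), with `beta = kl_beta(epoch)` recomputed at the start of each epoch. Dimensions whose KL is already below λ receive no KL gradient. This floor is intended to keep the encoder from collapsing those dimensions onto the prior.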

Citation

If you use this model in your work, please cite the Lakh MIDI Dataset:

@phdthesis{Raffel2016,
  author = {Colin Raffel},
  title  = {Learning-Based Methods for Comparing Sequences, with Applications
            to Audio-to-{MIDI} Alignment and Matching},
  school = {Columbia University},
  year   = {2016}
}