# Music VAE: CNN pretrained on the Lakh MIDI Dataset
## Overview
This is a Convolutional Variational Autoencoder (CNN VAE) pretrained on the Lakh MIDI Dataset (lmd_full, ~175k MIDI files).
It was trained as part of a machine learning course assignment at Purdue University to give students a meaningful starting point for music-generation tasks.
## Input / Output Format
| Property | Value |
|---|---|
| Input shape | [batch, 1, 88, 32], float32 |
| Output shape | [batch, 1, 88, 32], float32 |
| Pitch range | MIDI 21–108 (A0 to C8, 88 keys) |
| Time resolution | 16th notes at 120 BPM |
| Segment length | 2 bars = 32 timesteps |
| Value range | [0, 1] (sigmoid output) |
A tensor value of 1 at position [pitch_idx, time_step] means that pitch
21 + pitch_idx is active at that 16th-note time step.
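To make the tensor layout concrete, here is a minimal sketch (using a made-up note, not data from the repository) of how a note maps into the expected format:

```python
import torch

# Hypothetical example: place middle C (MIDI 60) on the first beat
# and hold it for four 16th-note steps.
roll = torch.zeros(1, 1, 88, 32)      # [batch, channel, pitch, time]
pitch_idx = 60 - 21                   # MIDI pitch minus the lowest key (A0 = 21)
roll[0, 0, pitch_idx, 0:4] = 1.0      # active for timesteps 0..3

# Recover the MIDI pitch of every active cell
active = roll[0, 0].nonzero()         # rows of [pitch_idx, time_step]
midi_pitches = active[:, 0] + 21
```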
## Architecture Summary
```
ENCODER
  Conv2d(1 → 32,    k=4, s=2, p=1) + ReLU + BN → [B, 32, 44, 16]
  Conv2d(32 → 64,   k=4, s=2, p=1) + ReLU + BN → [B, 64, 22, 8]
  Conv2d(64 → 128,  k=4, s=2, p=1) + ReLU + BN → [B, 128, 11, 4]
  Conv2d(128 → 256, k=4, s=2, p=1) + ReLU + BN → [B, 256, 5, 2]
  Flatten → 2560
  Linear → mu      [B, 256]
  Linear → log_var [B, 256]

REPARAMETERISATION
  z = mu + eps * exp(0.5 * log_var),  eps ~ N(0, I)

DECODER
  Linear(256 → 2560) → Reshape [B, 256, 5, 2]
  ConvTranspose2d(256 → 128, k=4, s=2, p=1, output_padding=(1,0)) → [B, 128, 11, 4]
  ConvTranspose2d(128 → 64,  k=4, s=2, p=1) → [B, 64, 22, 8]
  ConvTranspose2d(64 → 32,   k=4, s=2, p=1) → [B, 32, 44, 16]
  ConvTranspose2d(32 → 1,    k=4, s=2, p=1) + Sigmoid → [B, 1, 88, 32]
```
- Latent dimension: 256
- Trainable parameters: ~4.2M
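The reparameterisation step in the diagram above is the standard trick for sampling z in a differentiable way. Written out in PyTorch, it looks roughly like this (a generic sketch, not necessarily the exact code in src/model.py):

```python
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Draw z ~ N(mu, sigma^2) while keeping gradients w.r.t. mu and log_var."""
    std = torch.exp(0.5 * log_var)   # sigma = exp(log_var / 2)
    eps = torch.randn_like(std)      # eps ~ N(0, I)
    return mu + eps * std
```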
## Loading the Model (Course Assignment)
```python
import torch

from model import MusicVAE  # copy src/model.py into your project

# Load the checkpoint and rebuild the model
ckpt = torch.load("best_model.pt", map_location="cpu")
config = ckpt["config"]
model = MusicVAE(latent_dim=config["latent_dim"])
model.load_state_dict(ckpt["model_state"])
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Generate new piano rolls
samples = model.sample(n=4, device=device)  # [4, 1, 88, 32]

# Encode segments and reconstruct them
x = ...  # your [2, 1, 88, 32] piano-roll batch (two segments)
x_recon, mu, log_var = model(x.to(device))

# Interpolate between the two segments in latent space
z1 = mu[0:1]  # latent mean of the first segment
z2 = mu[1:2]  # latent mean of the second segment
interp = model.interpolate(z1, z2, steps=8)
```
Also see `src/utils.py` for `pianoroll_to_midi()` and `visualize_pianoroll()`.
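If you don't have `src/utils.py` at hand, a rough equivalent using `pretty_midi` looks like the sketch below. It assumes 16th-note steps at 120 BPM (0.125 s per step) and turns each contiguous run of active cells into one note; the course utility may handle edge cases differently.

```python
import pretty_midi
import torch

def rough_pianoroll_to_midi(roll: torch.Tensor, path: str, threshold: float = 0.5) -> None:
    """Convert one [1, 88, 32] (or [88, 32]) probability roll to a MIDI file."""
    roll = roll.squeeze().cpu() > threshold   # binarise to [88, 32]
    step = 0.125                               # seconds per 16th note at 120 BPM
    pm = pretty_midi.PrettyMIDI()
    piano = pretty_midi.Instrument(program=0)
    for pitch_idx in range(roll.shape[0]):
        t = 0
        while t < roll.shape[1]:
            if roll[pitch_idx, t]:
                start = t
                while t < roll.shape[1] and roll[pitch_idx, t]:
                    t += 1                     # extend the note over the active run
                piano.notes.append(pretty_midi.Note(
                    velocity=100, pitch=21 + pitch_idx,
                    start=start * step, end=t * step))
            else:
                t += 1
    pm.instruments.append(piano)
    pm.write(path)

# e.g. rough_pianoroll_to_midi(samples[0], "sample0.mid")
```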
## Training Details
- Dataset: Lakh MIDI Dataset (lmd_full)
- Piano roll: 88-pitch binary, 16th-note resolution, tempo-normalised to 120 BPM
- Segments: 2 bars (32 frames), stride 1 bar (16 frames)
- Loss: BCE reconstruction + β-annealed KL (β: 0 → 1 over 50 epochs) + free bits (λ = 0.5); see the sketch after this list
- Optimizer: Adam, lr=1e-3, ReduceLROnPlateau (patience=10, factor=0.5, min_lr=1e-5)
- Batch size: 256 | Epochs: 100 | Gradient clip: 1.0
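For reference, the loss described above can be sketched as follows. This is a generic formulation in which `beta` is ramped by the training loop and free bits clamp the per-dimension KL at λ nats; the exact variant used during training may differ.

```python
import torch
import torch.nn.functional as F

def vae_loss(x_recon, x, mu, log_var, beta, free_bits=0.5):
    # Reconstruction: binary cross-entropy over the 88x32 piano-roll cells
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum") / x.size(0)

    # KL divergence per latent dimension, averaged over the batch
    kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean(dim=0)

    # Free bits: don't penalise dimensions whose KL is already below the floor
    kl = torch.clamp(kl_per_dim, min=free_bits).sum()

    return recon + beta * kl
```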
## Citation
If you use this model in your work, please cite the Lakh MIDI Dataset:
```bibtex
@phdthesis{Raffel2016,
  author = {Colin Raffel},
  title  = {Learning-Based Methods for Comparing Sequences, with Applications
            to Audio-to-{MIDI} Alignment and Matching},
  school = {Columbia University},
  year   = {2016}
}
```