import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import wave
import numpy as np
import os
import glob
from pathlib import Path
import matplotlib.pyplot as plt
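# Framework for training small generative models on raw-waveform recordings of
# improvisations: an autoregressive LSTM (MusicRNN) or a sequence VAE (MusicVAE),
# wrapped by MusicGenerator for data loading, training, generation, and saving WAVs.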
class AudioDataset(Dataset):
"""Dataset for loading and processing audio improvisations"""
def __init__(self, audio_dir, sequence_length=1024, sample_rate=22050):
self.audio_dir = Path(audio_dir)
self.sequence_length = sequence_length
self.sample_rate = sample_rate
self.audio_files = list(self.audio_dir.glob("*.wav"))
# Load and preprocess all audio
self.audio_data = []
self.fragment_names = []
self.load_audio_files()
def load_audio_files(self):
"""Load all audio files and convert to sequences"""
print(f"Loading {len(self.audio_files)} audio files...")
for audio_file in self.audio_files:
try:
audio_data = self.load_wav_file(str(audio_file))
if audio_data is not None:
                    # Peak-normalize to [-1, 1]; skip silent files to avoid division by zero
                    peak = np.max(np.abs(audio_data))
                    if peak > 0:
                        audio_data = audio_data / peak
# Split into sequences
sequences = self.create_sequences(audio_data)
self.audio_data.extend(sequences)
# Store fragment name for each sequence
fragment_name = audio_file.stem
self.fragment_names.extend([fragment_name] * len(sequences))
print(f"Loaded {audio_file.name}: {len(sequences)} sequences")
except Exception as e:
print(f"Error loading {audio_file}: {e}")
def load_wav_file(self, file_path):
"""Load WAV file using Python's wave module"""
try:
with wave.open(file_path, 'rb') as wav_file:
                # Get audio parameters
                frames = wav_file.getnframes()
                sample_width = wav_file.getsampwidth()
                framerate = wav_file.getframerate()
                n_channels = wav_file.getnchannels()
# Read audio data
audio_bytes = wav_file.readframes(frames)
                # Convert to a float array in [-1, 1)
                if sample_width == 1:
                    # 8-bit WAV is unsigned; cast to float before centering to avoid uint8 wrap-around
                    audio_data = np.frombuffer(audio_bytes, dtype=np.uint8).astype(np.float32)
                    audio_data = (audio_data - 128) / 128.0
                elif sample_width == 2:
                    audio_data = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
                    audio_data = audio_data / 32768.0
                else:
                    print(f"Unsupported sample width: {sample_width}")
                    return None
                # Down-mix interleaved multi-channel audio to mono
                if n_channels > 1:
                    audio_data = audio_data.reshape(-1, n_channels).mean(axis=1)
# Resample if needed (basic resampling)
if framerate != self.sample_rate:
# Simple resampling - in practice, use librosa for better quality
step = framerate / self.sample_rate
indices = np.arange(0, len(audio_data), step).astype(int)
audio_data = audio_data[indices]
return audio_data
except Exception as e:
print(f"Error reading WAV file {file_path}: {e}")
return None
def create_sequences(self, audio_data):
"""Split audio into training sequences"""
sequences = []
# Create overlapping sequences
step_size = self.sequence_length // 2
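        # e.g. with the default sequence_length=1024 the windows start at
        # 0, 512, 1024, ..., giving 50% overlap between consecutive sequences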
        for i in range(0, len(audio_data) - self.sequence_length + 1, step_size):
sequence = audio_data[i:i + self.sequence_length]
sequences.append(sequence)
return sequences
def __len__(self):
return len(self.audio_data)
def __getitem__(self, idx):
audio_sequence = torch.FloatTensor(self.audio_data[idx])
fragment_name = self.fragment_names[idx]
# For autoregressive training: input is sequence[:-1], target is sequence[1:]
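        # e.g. [x0, x1, x2, x3] -> input [x0, x1, x2], target [x1, x2, x3]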
input_seq = audio_sequence[:-1]
target_seq = audio_sequence[1:]
return input_seq, target_seq, fragment_name
class MusicRNN(nn.Module):
"""RNN-based model for learning musical improvisation patterns"""
def __init__(self, input_size=1, hidden_size=256, num_layers=3, dropout=0.2):
super(MusicRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
# LSTM layers
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
batch_first=True
)
# Output layer
self.output = nn.Linear(hidden_size, input_size)
# Dropout
self.dropout = nn.Dropout(dropout)
def forward(self, x, hidden=None):
# x shape: (batch_size, sequence_length, input_size)
lstm_out, hidden = self.lstm(x, hidden)
# Apply dropout
lstm_out = self.dropout(lstm_out)
# Generate output
output = self.output(lstm_out)
return output, hidden
def init_hidden(self, batch_size, device):
"""Initialize hidden state"""
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
return (h0, c0)
class MusicVAE(nn.Module):
"""Variational Autoencoder for learning musical style space"""
def __init__(self, sequence_length=1023, latent_dim=128, hidden_dim=256):
super(MusicVAE, self).__init__()
self.sequence_length = sequence_length
self.latent_dim = latent_dim
# Encoder
self.encoder_lstm = nn.LSTM(1, hidden_dim, 2, batch_first=True)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder
self.decoder_fc = nn.Linear(latent_dim, hidden_dim)
self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, 2, batch_first=True)
self.decoder_output = nn.Linear(hidden_dim, 1)
def encode(self, x):
"""Encode input to latent space"""
_, (h_n, _) = self.encoder_lstm(x)
h_n = h_n[-1] # Take last layer
mu = self.fc_mu(h_n)
logvar = self.fc_logvar(h_n)
return mu, logvar
def reparameterize(self, mu, logvar):
"""Reparameterization trick"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
"""Decode from latent space"""
batch_size = z.size(0)
# Expand latent vector to sequence
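        # The latent code is fed as the decoder input at every timestep,
        # and the LSTM unrolls it back into a waveform of length sequence_length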
hidden = self.decoder_fc(z)
hidden = hidden.unsqueeze(1).repeat(1, self.sequence_length, 1)
# LSTM decoder
lstm_out, _ = self.decoder_lstm(hidden)
output = self.decoder_output(lstm_out)
return output
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
return recon, mu, logvar
class MusicGenerator:
"""Main class for training and generating music"""
def __init__(self, model_type='rnn', device=None):
self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model_type = model_type
self.model = None
print(f"Using device: {self.device}")
def load_data(self, audio_dir, batch_size=32, sequence_length=1024):
"""Load and prepare training data"""
dataset = AudioDataset(audio_dir, sequence_length=sequence_length)
if len(dataset) == 0:
raise ValueError(f"No audio data found in {audio_dir}")
self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
self.sequence_length = sequence_length
print(f"Loaded {len(dataset)} sequences from {len(dataset.audio_files)} files")
# Get unique fragment names
self.fragment_names = list(set(dataset.fragment_names))
print(f"Fragment names: {self.fragment_names}")
return dataset
def create_model(self, **kwargs):
"""Create the generative model"""
if self.model_type == 'rnn':
self.model = MusicRNN(**kwargs)
elif self.model_type == 'vae':
self.model = MusicVAE(sequence_length=self.sequence_length-1, **kwargs)
else:
raise ValueError(f"Unknown model type: {self.model_type}")
self.model.to(self.device)
print(f"Created {self.model_type.upper()} model with {sum(p.numel() for p in self.model.parameters())} parameters")
return self.model
def train(self, epochs=100, learning_rate=0.001):
"""Train the model"""
if self.model is None:
raise ValueError("Model not created. Call create_model() first.")
optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
self.model.train()
losses = []
for epoch in range(epochs):
epoch_loss = 0
num_batches = 0
for batch_idx, (input_seq, target_seq, _) in enumerate(self.dataloader):
input_seq = input_seq.unsqueeze(-1).to(self.device) # Add feature dimension
target_seq = target_seq.unsqueeze(-1).to(self.device)
optimizer.zero_grad()
if self.model_type == 'rnn':
output, _ = self.model(input_seq)
loss = criterion(output, target_seq)
elif self.model_type == 'vae':
recon, mu, logvar = self.model(input_seq)
                    # VAE loss: reconstruction of the input + KL divergence
                    recon_loss = criterion(recon, input_seq)
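                    # Closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the
                    # standard normal prior: KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))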
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + 0.001 * kl_loss # Beta-VAE with beta=0.001
loss.backward()
optimizer.step()
epoch_loss += loss.item()
num_batches += 1
avg_loss = epoch_loss / num_batches
losses.append(avg_loss)
if epoch % 10 == 0:
print(f"Epoch {epoch}/{epochs}, Loss: {avg_loss:.6f}")
self.losses = losses
print("Training completed!")
return losses
def generate(self, length=2048, temperature=1.0, seed_sequence=None):
"""Generate new audio sequence"""
if self.model is None:
raise ValueError("Model not trained. Call train() first.")
self.model.eval()
with torch.no_grad():
if self.model_type == 'rnn':
return self._generate_rnn(length, temperature, seed_sequence)
elif self.model_type == 'vae':
return self._generate_vae(length)
def _generate_rnn(self, length, temperature, seed_sequence):
"""Generate using RNN model"""
if seed_sequence is not None:
generated = list(seed_sequence)
else:
# Start with a small random seed
generated = [np.random.randn() * 0.1]
        for i in range(length):
            # Re-run the model on a sliding window of the most recent samples.
            # The window already carries the recent context, so the hidden state
            # is rebuilt from scratch each step rather than carried over (which
            # would process the same samples twice).
            input_seq = torch.FloatTensor(generated[-100:]).unsqueeze(0).unsqueeze(-1).to(self.device)
            output, _ = self.model(input_seq)
            # The model outputs a single predicted sample; scale it by the
            # temperature and squash with tanh to keep it in [-1, 1]
            next_sample = output[0, -1, 0].item()
            if temperature > 0:
                next_sample = next_sample / temperature
            next_sample = np.tanh(next_sample)
generated.append(next_sample)
return np.array(generated)
def _generate_vae(self, length):
"""Generate using VAE model"""
# Sample from latent space
z = torch.randn(1, self.model.latent_dim).to(self.device)
# Decode to audio
generated = self.model.decode(z)
generated = generated.squeeze().cpu().numpy()
# Extend if needed
if len(generated) < length:
# Repeat the pattern
repeats = (length // len(generated)) + 1
generated = np.tile(generated, repeats)[:length]
return generated[:length]
def save_audio(self, audio_data, filename, sample_rate=22050):
"""Save generated audio to WAV file"""
        # Clip to [-1, 1] and convert to 16-bit PCM
        audio_data = np.clip(audio_data, -1, 1)
        audio_data = (audio_data * 32767).astype(np.int16)
with wave.open(filename, 'wb') as wav_file:
wav_file.setnchannels(1) # Mono
wav_file.setsampwidth(2) # 16-bit
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_data.tobytes())
print(f"Audio saved to {filename}")
def plot_training_loss(self):
"""Plot training loss curve"""
if hasattr(self, 'losses'):
plt.figure(figsize=(10, 6))
plt.plot(self.losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
# Example usage
def main():
    # Print usage instructions and build a small synthetic demo dataset
print("Music Generation Framework")
print("=" * 50)
# Initialize generator
generator = MusicGenerator(model_type='rnn') # or 'vae'
# Note: You would replace this with your actual audio directory
audio_directory = "path/to/your/improvised/audio/files"
print(f"\nTo use this framework:")
print(f"1. Put your improvised audio files (.wav) in: {audio_directory}")
print(f"2. Name your files descriptively (e.g., 'jazz_piano_improv_1.wav')")
print(f"3. Run the training:")
print(f"")
print(f"# Load your audio data")
print(f"dataset = generator.load_data('{audio_directory}')")
print(f"")
print(f"# Create and train model")
print(f"model = generator.create_model(hidden_size=256, num_layers=3)")
print(f"losses = generator.train(epochs=100)")
print(f"")
print(f"# Generate new audio in your style")
print(f"generated_audio = generator.generate(length=4096)")
print(f"generator.save_audio(generated_audio, 'my_generated_improv.wav')")
# Example with synthetic data (for demonstration)
print(f"\n" + "="*50)
print("Demo with synthetic data:")
# Create synthetic audio data for demonstration
synthetic_dir = "synthetic_audio"
os.makedirs(synthetic_dir, exist_ok=True)
# Generate some synthetic "improvisations"
sample_rate = 22050
duration = 3 # seconds
for i, style in enumerate(['melodic', 'rhythmic', 'ambient']):
t = np.linspace(0, duration, sample_rate * duration)
if style == 'melodic':
# Create a melodic pattern
audio = np.sin(2 * np.pi * 440 * t) * np.exp(-t * 0.5)
audio += 0.3 * np.sin(2 * np.pi * 880 * t) * np.exp(-t * 0.3)
elif style == 'rhythmic':
# Create a rhythmic pattern
audio = np.zeros_like(t)
beats = np.arange(0, duration, 0.5)
for beat in beats:
                idx = int(beat * sample_rate)
                end = min(idx + 1000, len(audio))
                if end > idx:
                    pulse = np.sin(2 * np.pi * 200 * np.linspace(0, 0.1, 1000)) * np.exp(-np.linspace(0, 10, 1000))
                    audio[idx:end] += pulse[:end - idx]
else: # ambient
# Create ambient texture
audio = np.random.randn(len(t)) * 0.1
audio = np.convolve(audio, np.ones(100)/100, mode='same') # Smooth
# Add some variation
audio += 0.1 * np.random.randn(len(audio))
# Save synthetic audio
filename = f"{synthetic_dir}/{style}_improv_{i+1}.wav"
generator.save_audio(audio, filename, sample_rate)
print(f"Created synthetic audio files in {synthetic_dir}/")
print("You can now test the framework with these files!")
if __name__ == "__main__":
main()