"""Train a batched Gaussian HSMM on LibriSpeech wav2vec-U features and
visualize Viterbi-decoded phoneme-state sequences for one validation
utterance.

Pipeline: load PCA-reduced frame features -> (optionally) smart-init the
Gaussian means from real frames -> gradient-train the HSMM NLL with Adam,
checkpointing each epoch -> Viterbi-decode a full-length validation
sequence and save a feature/state plot to librispeech_result.png.
"""

import os

import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

# Project-local imports
from hsmm_model import BatchedGaussianHSMM
from hsmm_inference import viterbi_decode
from real_data import get_real_dataloaders  # <--- NEW IMPORT

# --- CONFIGURATION ---
CONFIG = {
    # Path to your file
    "DATA_PATH": "/u/schmitt/experiments/2025_10_02_unsup_asr_shared_enc/work/i6_experiments/users/schmitt/experiments/exp2025_10_02_shared_enc/librispeech/data/audio_preprocessing/Wav2VecUFeaturizeAudioJob.mkGrrp0YWy8y/output/audio_features/train.npy",
    "N_STATES": 50,    # Increased for real speech (approx 40-50 phonemes)
    "PCA_DIM": 30,     # Reduce 512 -> 30 dimensions
    "MAX_DUR": 30,     # Max duration in frames (30 * 20ms = 600ms)
    "LR": 0.01,        # Slightly lower LR for real data
    "EPOCHS": 20,
    "BATCH_SIZE": 64,  # 64 chunks of 3 seconds each
    "CROP_LEN": 150,   # Training window (150 frames = 3 seconds)
    "CHECKPOINT_PATH": "hsmm_librispeech.pth",
    "RESUME": False
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")


def save_checkpoint(model, optimizer, epoch, loss, path):
    """Persist model + optimizer state, epoch index, and loss to `path`."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)


def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state in place from `path` if it exists.

    Returns the epoch to resume from (saved epoch + 1), or 0 when no
    checkpoint file is present.
    """
    if os.path.exists(path):
        print(f"Loading checkpoint from {path}...")
        # map_location lets a GPU-saved checkpoint load on a CPU-only host
        # (and the reverse) instead of raising a deserialization error.
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        return checkpoint['epoch'] + 1
    return 0


def train():
    """Run the full train-and-visualize pipeline driven by CONFIG."""
    # 1. Load Real Data
    print("--- 1. Loading Real Data ---")
    # NOTE(review): despite the name, get_real_dataloaders appears to return
    # *datasets* (indexable, wrapped in a DataLoader below) — confirm.
    train_ds, val_ds = get_real_dataloaders(
        CONFIG["DATA_PATH"],
        batch_size=CONFIG["BATCH_SIZE"],
        crop_len=CONFIG["CROP_LEN"],
        pca_dim=CONFIG["PCA_DIM"],
    )

    # Loader for Training (Batched, Cropped)
    train_loader = DataLoader(
        train_ds,
        batch_size=CONFIG["BATCH_SIZE"],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,  # Avoid partial batch issues
    )

    # 2. Init Model
    print("--- 2. Initializing Model ---")
    model = BatchedGaussianHSMM(CONFIG["N_STATES"], CONFIG["PCA_DIM"], CONFIG["MAX_DUR"])
    model.to(device)

    # Smart Init (using PCA-reduced data from the dataset): seed the Gaussian
    # means with random real frames so states start inside the data manifold.
    # Skipped when resuming or when a checkpoint already exists on disk.
    if not CONFIG["RESUME"] and not os.path.exists(CONFIG["CHECKPOINT_PATH"]):
        print("--- 2b. Running Smart Initialization ---")
        # Grab a batch to init means
        init_batch = next(iter(train_loader)).to(device)  # (B, T, D)
        flat_data = init_batch.view(-1, CONFIG["PCA_DIM"])
        # Pick random frames
        indices = torch.randperm(flat_data.size(0))[:CONFIG["N_STATES"]]
        model.means.data.copy_(flat_data[indices])
        print("Means initialized.")

    optimizer = optim.Adam(model.parameters(), lr=CONFIG["LR"])

    start_epoch = 0
    if CONFIG["RESUME"]:
        start_epoch = load_checkpoint(model, optimizer, CONFIG["CHECKPOINT_PATH"])

    # 3. Training Loop: minimize the model's negative log-likelihood.
    print(f"--- 3. Training Loop ---")
    for epoch in range(start_epoch, CONFIG["EPOCHS"]):
        total_loss = 0
        model.train()
        for batch_idx, batch_data in enumerate(train_loader):
            batch_data = batch_data.to(device)
            optimizer.zero_grad()
            loss = model(batch_data)  # forward returns the NLL for the batch
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 50 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx} | Loss {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch:02d} DONE | Avg NLL: {avg_loss:.4f}")
        save_checkpoint(model, optimizer, epoch, avg_loss, CONFIG["CHECKPOINT_PATH"])

    # 4. Visualization (Using Validation Set - Full Length)
    print("\n--- 4. Visualizing Inference on Real Audio ---")
    # Grab the first file from validation set (Index 0)
    # val_ds[0] returns (Time, Dim) -> add batch dim -> (1, T, D)
    test_seq = val_ds[0].unsqueeze(0).to(device)

    # Run Inference without autograd bookkeeping: full-length sequences are
    # long and we never backprop through the decode.
    model.eval()
    with torch.no_grad():
        path = viterbi_decode(model, test_seq)

    # Move to CPU for plotting
    raw_data = test_seq.squeeze(0).cpu().numpy()

    # Plotting: features on top, decoded state strip below, shared time axis.
    fig, ax = plt.subplots(2, 1, figsize=(15, 6), sharex=True)

    # Plot PCA Features
    ax[0].imshow(raw_data.T, aspect='auto', cmap='viridis', interpolation='nearest')
    ax[0].set_title(f"PCA Reduced Features ({CONFIG['PCA_DIM']} Dim)")
    ax[0].set_ylabel("PCA Dim")

    # Plot States
    path_img = np.array(path)[np.newaxis, :]
    ax[1].imshow(path_img, aspect='auto', cmap='tab20', interpolation='nearest')
    ax[1].set_title("Inferred Phoneme States")
    ax[1].set_ylabel("State ID")
    ax[1].set_xlabel("Time (Frames)")

    plt.tight_layout()
    plt.savefig("librispeech_result.png")
    print("Saved librispeech_result.png")


if __name__ == "__main__":
    train()