Commit "testing" — diff view of hsmm/main.py (hunk @@ -1,80 +1,151 @@: 80 lines removed, 151 lines added; the UI repeats the file name and the line count "185" twice).
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from hsmm_model import GaussianHSMM
|
||||
import numpy as np
|
||||
import os
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Imports
|
||||
from hsmm_model import BatchedGaussianHSMM
|
||||
from hsmm_inference import viterbi_decode
|
||||
from toy_data import generate_toy_data
|
||||
from real_data import get_real_dataloaders # <--- NEW IMPORT
|
||||
|
||||
# --- Settings (legacy toy-data experiment) ---
# NOTE(review): these module-level constants belong to the earlier toy-data
# setup (generate_toy_data / GaussianHSMM); the real-data pipeline reads
# everything from CONFIG below. Kept for backward compatibility.
N_STATES = 10
INPUT_DIM = 5   # Matches the 'dim' in generate_toy_data
MAX_DUR = 50
LR = 0.05
EPOCHS = 20

# --- CONFIGURATION (real LibriSpeech wav2vec features) ---
CONFIG = {
    # Path to the pre-extracted feature file.
    "DATA_PATH": "/u/schmitt/experiments/2025_10_02_unsup_asr_shared_enc/work/i6_experiments/users/schmitt/experiments/exp2025_10_02_shared_enc/librispeech/data/audio_preprocessing/Wav2VecUFeaturizeAudioJob.mkGrrp0YWy8y/output/audio_features/train.npy",

    "N_STATES": 50,    # Increased for real speech (approx 40-50 phonemes)
    "PCA_DIM": 30,     # Reduce 512 -> 30 dimensions
    "MAX_DUR": 30,     # Max duration in frames (30 * 20ms = 600ms)
    "LR": 0.01,        # Slightly lower LR for real data
    "EPOCHS": 20,
    "BATCH_SIZE": 64,  # 64 chunks of 3 seconds each
    "CROP_LEN": 150,   # Training window (150 frames = 3 seconds)

    "CHECKPOINT_PATH": "hsmm_librispeech.pth",
    "RESUME": False,
}

# Select GPU when available; all tensors/models are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Serialize the full training state to *path*.

    The checkpoint records the epoch index, the model and optimizer
    state dicts, and the loss value, so training can later be resumed
    exactly where it left off.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(state, path)
def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state from *path* if the file exists.

    Returns the epoch index to resume from (saved epoch + 1), or 0 when
    no checkpoint file is present.
    """
    if not os.path.exists(path):
        return 0
    print(f"Loading checkpoint from {path}...")
    # map_location="cpu" so a checkpoint saved on a GPU host also loads on a
    # CPU-only machine; load_state_dict then copies the tensors onto the
    # model's (and the optimizer's params') actual device.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'] + 1
def train():
    """End-to-end training of a batched Gaussian HSMM on real audio features.

    Reconstructed from the post-change side of the diff (the source span
    interleaves the old toy-data loop with this new real-data loop).
    Pipeline: load PCA-reduced LibriSpeech features, optionally smart-init
    the state means from real frames, train with mini-batches, checkpoint
    every epoch, then visualize a Viterbi decoding of one validation file.
    """
    # 1. Load Real Data
    print("--- 1. Loading Real Data ---")
    train_ds, val_ds = get_real_dataloaders(
        CONFIG["DATA_PATH"],
        batch_size=CONFIG["BATCH_SIZE"],
        crop_len=CONFIG["CROP_LEN"],
        pca_dim=CONFIG["PCA_DIM"],
    )

    # Loader for Training (Batched, Cropped)
    train_loader = DataLoader(
        train_ds,
        batch_size=CONFIG["BATCH_SIZE"],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,  # Avoid partial batch issues
    )

    # 2. Init Model
    print("--- 2. Initializing Model ---")
    model = BatchedGaussianHSMM(CONFIG["N_STATES"], CONFIG["PCA_DIM"], CONFIG["MAX_DUR"])
    model.to(device)

    # Smart Init (using PCA-reduced data from the dataset): seed the state
    # means with randomly chosen real frames instead of random noise. Skipped
    # when resuming or when a checkpoint already exists.
    if not CONFIG["RESUME"] and not os.path.exists(CONFIG["CHECKPOINT_PATH"]):
        print("--- 2b. Running Smart Initialization ---")
        # Grab a batch to init means
        init_batch = next(iter(train_loader)).to(device)  # (B, T, D)
        flat_data = init_batch.view(-1, CONFIG["PCA_DIM"])
        # Pick random frames
        indices = torch.randperm(flat_data.size(0))[:CONFIG["N_STATES"]]
        model.means.data.copy_(flat_data[indices])
        print("Means initialized.")

    optimizer = optim.Adam(model.parameters(), lr=CONFIG["LR"])
    start_epoch = 0
    if CONFIG["RESUME"]:
        start_epoch = load_checkpoint(model, optimizer, CONFIG["CHECKPOINT_PATH"])

    # 3. Training Loop
    print(f"--- 3. Training Loop ---")
    for epoch in range(start_epoch, CONFIG["EPOCHS"]):
        total_loss = 0
        model.train()

        for batch_idx, batch_data in enumerate(train_loader):
            batch_data = batch_data.to(device)

            optimizer.zero_grad()
            loss = model(batch_data)  # Forward pass: batch NLL under the HSMM
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 50 == 0:
                print(f"Epoch {epoch} | Batch {batch_idx} | Loss {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch:02d} DONE | Avg NLL: {avg_loss:.4f}")
        save_checkpoint(model, optimizer, epoch, avg_loss, CONFIG["CHECKPOINT_PATH"])

    # 4. Visualization (Using Validation Set - Full Length)
    print("\n--- 4. Visualizing Inference on Real Audio ---")
    # Grab the first file from validation set (Index 0).
    # val_ds[0] returns (Time, Dim) -> add batch dim -> (1, T, D)
    test_seq = val_ds[0].unsqueeze(0).to(device)

    # Run Inference
    path = viterbi_decode(model, test_seq)

    # Move to CPU for plotting
    raw_data = test_seq.squeeze(0).cpu().numpy()

    # Plotting
    fig, ax = plt.subplots(2, 1, figsize=(15, 6), sharex=True)

    # Plot PCA Features (transposed so time is the x-axis)
    ax[0].imshow(raw_data.T, aspect='auto', cmap='viridis', interpolation='nearest')
    ax[0].set_title(f"PCA Reduced Features ({CONFIG['PCA_DIM']} Dim)")
    ax[0].set_ylabel("PCA Dim")

    # Plot States: reshape path to (1, T) for imshow
    path_img = np.array(path)[np.newaxis, :]
    ax[1].imshow(path_img, aspect='auto', cmap='tab20', interpolation='nearest')
    ax[1].set_title("Inferred Phoneme States")
    ax[1].set_ylabel("State ID")
    ax[1].set_xlabel("Time (Frames)")

    plt.tight_layout()
    # NOTE(review): with the Agg backend plt.show() is a no-op, so only
    # savefig is kept; saving must happen before the figure is discarded.
    plt.savefig("librispeech_result.png")
    print("Saved librispeech_result.png")


if __name__ == "__main__":
    train()
(Diff-viewer footer links: "Reference in New Issue" · "Block a user".)