testing
@@ -3,51 +3,52 @@ import torch.nn.functional as F

 def viterbi_decode(model, x):
     """
-    Returns the optimal sequence of states (path).
+    Returns the optimal sequence of states (path) using Viterbi algorithm.
+
+    x: (Time, Dim) or (1, Time, Dim)
     """
     with torch.no_grad():
-        T = x.shape[0]
+        # Handle Batch Dimension if missing
+        if x.dim() == 2:
+            x = x.unsqueeze(0)  # (1, T, D)
+
+        T = x.shape[1]
         N = model.n_states
         D_max = model.max_dur

-        # 1. Setup Probs
+        # 1. Get Probs (Using Batched Model Function)
         log_emit = model.compute_emission_log_probs(x)
-        log_trans = F.log_softmax(model.get_masked_transitions(), dim=1)
+        log_emit = log_emit.squeeze(0)  # (T, N)
+
+        # Get other probs
+        mask = torch.eye(N, device=x.device).bool()
+        log_trans = F.log_softmax(model.trans_logits.masked_fill(mask, -float('inf')), dim=1)
         log_dur = F.log_softmax(model.dur_logits, dim=1)
         log_pi = F.log_softmax(model.pi_logits, dim=0)

         # 2. Viterbi Tables
-        # max_prob[t, s] = Best log-prob ending at t in state s
         max_prob = torch.full((T, N), -float('inf'), device=x.device)
-        # backpointers[t, s] = (previous_state, duration_used)
         backpointers = {}

-        # 3. Dynamic Programming
+        # 3. Dynamic Programming Loop
         for t in range(T):
             for d in range(1, D_max + 1):
                 if t - d + 1 < 0: continue

-                # Emission sum for segment
                 seg_emit = log_emit[t-d+1 : t+1].sum(dim=0)
                 dur_prob = log_dur[:, d-1]

                 if t - d + 1 == 0:
-                    # Init
                     score = log_pi + dur_prob + seg_emit
                     for s in range(N):
                         if score[s] > max_prob[t, s]:
                             max_prob[t, s] = score[s]
-                            backpointers[(t, s)] = (-1, d)  # -1 is Start
+                            backpointers[(t, s)] = (-1, d)
                 else:
-                    # Transition
-                    prev_scores = max_prob[t-d]  # (N,)
-                    # Find best transition for each target state s
-                    # (N, 1) + (N, N) -> (N, N)
+                    prev_scores = max_prob[t-d]
                     trans_scores = prev_scores.unsqueeze(1) + log_trans
-                    best_prev_score, best_prev_idx = trans_scores.max(dim=0)  # (N,)
+                    best_prev_score, best_prev_idx = trans_scores.max(dim=0)
                     current_score = best_prev_score + dur_prob + seg_emit

                     for s in range(N):
                         if current_score[s] > max_prob[t, s]:
                             max_prob[t, s] = current_score[s]
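
For reference (illustration only, not part of this commit), a minimal sketch of the broadcasting in the transition step above: adding the (N,) previous scores as a column to the (N, N) log-transition matrix gives an (N, N) grid indexed by [previous, next]; max over dim 0 is the Viterbi choice per target state, while logsumexp over the same dim would give the forward marginal used during training. All names and sizes here are made up for the example.

import torch

N = 4
prev_scores = torch.randn(N)     # accumulated log-scores at time t - d
log_trans = torch.randn(N, N)    # log_trans[i, j] ~ log P(next state j | previous state i)

scores = prev_scores.unsqueeze(1) + log_trans      # (N, 1) + (N, N) -> (N, N)
best_score, best_prev = scores.max(dim=0)          # Viterbi: best previous state per target state
marginal = torch.logsumexp(scores, dim=0)          # forward recursion: sum over previous states
print(best_score.shape, best_prev.shape, marginal.shape)   # all torch.Size([4])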
@@ -62,10 +63,8 @@ def viterbi_decode(model, x):

         while curr_t >= 0:
             if (curr_t, curr_s) not in backpointers: break
             prev_s, d = backpointers[(curr_t, curr_s)]

-            # Append this state 'd' times
             path = [curr_s] * d + path
             curr_t -= d
             curr_s = prev_s

         return path
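
For reference (illustration only, not part of this commit), a tiny worked example of the backpointer walk above, with two hypothetical entries: a segment of state 2 covering frames 0-2, then a segment of state 0 covering frames 3-4.

backpointers = {(2, 2): (-1, 3),   # segment ending at t=2 in state 2 came from Start, duration 3
                (4, 0): (2, 2)}    # segment ending at t=4 in state 0 came from state 2, duration 2

path, curr_t, curr_s = [], 4, 0
while curr_t >= 0:
    if (curr_t, curr_s) not in backpointers: break
    prev_s, d = backpointers[(curr_t, curr_s)]
    path = [curr_s] * d + path
    curr_t -= d
    curr_s = prev_s
print(path)   # -> [2, 2, 2, 0, 0]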
@@ -1,107 +1,119 @@

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np

-# --- JIT Compiled Forward Loop for Speed ---
 from typing import List

+# --- Batched JIT Forward Loop ---
 @torch.jit.script
-def hsmm_forward_loop(T: int, N: int, D_max: int,
+def hsmm_forward_loop_batched(T: int, N: int, D_max: int,
                       log_emit: torch.Tensor,
                       log_trans: torch.Tensor,
                       log_dur: torch.Tensor,
                       log_pi: torch.Tensor) -> torch.Tensor:
+    """
+    Computes Marginal Log-Likelihood for a BATCH of sequences in parallel.
+    Uses a list for alpha history to avoid in-place modification errors.
+    """
+    BatchSize = log_emit.shape[0]

-    # We use a List to store alpha steps. This avoids "in-place" errors.
-    # We initialize with dummy tensors to handle negative indexing if needed,
-    # but logically we just append.
+    # alpha_list[t] will hold the alpha tensor for time t
     alpha_list: List[torch.Tensor] = []

     for t in range(T):
-        # Initialize current step with -inf
-        current_alpha = torch.full((N,), -float('inf'), device=log_emit.device)
+        # Init accumulator for this time step with -inf
+        current_alpha = torch.full((BatchSize, N), -float('inf'), device=log_emit.device)

         for d in range(1, D_max + 1):
             if t - d + 1 < 0: continue

-            # 1. Emission Score (Sum)
-            seg_emit = log_emit[t-d+1 : t+1].sum(dim=0)
+            # 1. Emission Sum for segment: Sum(t-d+1 ... t)
+            # Slice: (Batch, d, N) -> Sum dim 1 -> (Batch, N)
+            seg_emit = log_emit[:, t-d+1 : t+1, :].sum(dim=1)

-            # 2. Duration Score
-            dur_score = log_dur[:, d-1]
+            # 2. Duration: (N) -> (1, N)
+            dur_score = log_dur[:, d-1].unsqueeze(0)

-            # 3. Transition Score
+            # 3. Transition Logic
             if t - d + 1 == 0:
-                # Init (t=0 or start of a duration)
-                score = log_pi + dur_score + seg_emit
-                current_alpha = torch.logaddexp(current_alpha, score)
+                # Initialization (Segment starts at t=0)
+                path_score = log_pi.unsqueeze(0) + dur_score + seg_emit
+                current_alpha = torch.logaddexp(current_alpha, path_score)
             else:
-                # Recursion: look back at alpha[t-d]
-                # In a list, alpha[t-d] is just alpha_list[t-d]
+                # Transition from prev_alpha at t-d
                 prev_alpha = alpha_list[t-d]

-                trans_score = torch.logsumexp(prev_alpha.unsqueeze(1) + log_trans, dim=0)
-                score = trans_score + dur_score + seg_emit
-                current_alpha = torch.logaddexp(current_alpha, score)
+                # Broadcast for Transition Matrix:
+                # prev:  (Batch, N, 1)
+                # trans: (1, N, N)
+                # sum:   (Batch, N, N) -> LogSumExp over Prev State -> (Batch, N)
+                trans_score = torch.logsumexp(
+                    prev_alpha.unsqueeze(2) + log_trans.unsqueeze(0),
+                    dim=1
+                )

+                path_score = trans_score + dur_score + seg_emit
+                current_alpha = torch.logaddexp(current_alpha, path_score)

-        # Save this step to the history list
         alpha_list.append(current_alpha)

-    # The result is the sum of the very last step
-    return torch.logsumexp(alpha_list[-1], dim=0)
+    # Final sum over states for each batch element: (Batch,)
+    return torch.logsumexp(alpha_list[-1], dim=1)


-class GaussianHSMM(nn.Module):
+class BatchedGaussianHSMM(nn.Module):
     def __init__(self, n_states, input_dim, max_dur=20):
         super().__init__()
         self.n_states = n_states
         self.max_dur = max_dur

-        # --- Parameters ---
-        # 1. Start Probabilities
+        # --- Learnable Parameters ---
         self.pi_logits = nn.Parameter(torch.randn(n_states))

-        # 2. Transition Matrix (Bigram)
-        # We manually mask the diagonal later to forbid self-transitions
         self.trans_logits = nn.Parameter(torch.randn(n_states, n_states))

-        # 3. Duration Model (Categorical weights for 1..max_dur)
         self.dur_logits = nn.Parameter(torch.randn(n_states, max_dur))

-        # 4. Emissions (Gaussian Means & LogVars)
         self.means = nn.Parameter(torch.randn(n_states, input_dim))
         self.log_vars = nn.Parameter(torch.zeros(n_states, input_dim))

     def compute_emission_log_probs(self, x):
-        """ Calculates Gaussian Log-Likelihood: (T, N) """
-        # x: (T, Dim) -> (T, 1, Dim)
-        # means: (N, Dim) -> (1, N, Dim)
-        # log_prob = -0.5 * (log(2pi) + log_var + (x-mu)^2/var)
+        """
+        Calculates Gaussian Log-Likelihood for a Batch.
+        x: (Batch, T, Dim)
+        Returns: (Batch, T, N_States)
+        """
+        # x: (Batch, T, 1, Dim)
+        # means: (1, 1, N, Dim)
+        diff = x.unsqueeze(2) - self.means.reshape(1, 1, self.n_states, -1)

-        diff = x.unsqueeze(1) - self.means.unsqueeze(0)
-        vars = self.log_vars.exp().unsqueeze(0)
-        log_vars = self.log_vars.unsqueeze(0)
+        vars = self.log_vars.exp().reshape(1, 1, self.n_states, -1)
+        log_vars = self.log_vars.reshape(1, 1, self.n_states, -1)

-        log_prob = -0.5 * (np.log(2 * np.pi) + log_vars + (diff**2) / vars)
-        return log_prob.sum(dim=-1)  # Sum over feature dimensions
+        # Log Gaussian PDF
+        log_prob = -0.5 * (torch.log(torch.tensor(2 * 3.14159, device=x.device)) + log_vars + (diff**2) / vars)

-    def get_masked_transitions(self):
-        """ Enforces A_ii = -inf (No self-transitions allowed in Bigram) """
-        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
-        return self.trans_logits.masked_fill(mask, -float('inf'))
+        # Sum over Feature Dimension -> (Batch, T, N)
+        return log_prob.sum(dim=-1)

     def forward(self, x):
-        """ Returns Negative Log Likelihood (Scalar Loss) """
-        T = x.shape[0]
+        """
+        x: (Batch, T, Dim)
+        Returns: Scalar Loss (Mean NLL over batch)
+        """
+        B, T, D = x.shape

-        # 1. Precompute static probabilities
-        log_emit = self.compute_emission_log_probs(x)
-        log_trans = F.log_softmax(self.get_masked_transitions(), dim=1)
+        log_emit = self.compute_emission_log_probs(x)  # (B, T, N)

+        # Mask diagonal of transition matrix (No self-loops)
+        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
+        masked_trans = self.trans_logits.masked_fill(mask, -float('inf'))

+        log_trans = F.log_softmax(masked_trans, dim=1)
         log_dur = F.log_softmax(self.dur_logits, dim=1)
         log_pi = F.log_softmax(self.pi_logits, dim=0)

-        # 2. Run JIT Loop
-        total_ll = hsmm_forward_loop(T, self.n_states, self.max_dur,
-                                     log_emit, log_trans, log_dur, log_pi)
+        # Run Batched JIT Loop
+        batch_log_likelihoods = hsmm_forward_loop_batched(
+            T, self.n_states, self.max_dur,
+            log_emit, log_trans, log_dur, log_pi
+        )

-        return -total_ll
+        # Return Mean Negative Log Likelihood
+        return -batch_log_likelihoods.mean()
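
For reference (illustration only, not part of this commit), a minimal shape/smoke-test sketch for the batched model above, assuming the file is importable as hsmm_model. The random tensor and the small constructor values are stand-ins, not values from the project.

import torch
from hsmm_model import BatchedGaussianHSMM

model = BatchedGaussianHSMM(n_states=5, input_dim=8, max_dur=10)
x = torch.randn(2, 40, 8)                   # (Batch, Time, Dim)
emit = model.compute_emission_log_probs(x)
print(emit.shape)                           # torch.Size([2, 40, 5])

loss = model(x)                             # mean NLL over the batch
loss.backward()                             # gradients flow through the JIT-scripted loop
print(loss.item(), model.means.grad.shape)  # scalar, torch.Size([5, 8])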
hsmm/main.py
@@ -1,80 +1,151 @@

 import torch
 import torch.optim as optim
+import matplotlib
+matplotlib.use('Agg')
 import matplotlib.pyplot as plt
-from hsmm_model import GaussianHSMM
+import numpy as np
+import os
+from torch.utils.data import DataLoader
+
+# Imports
+from hsmm_model import BatchedGaussianHSMM
 from hsmm_inference import viterbi_decode
-from toy_data import generate_toy_data
+from real_data import get_real_dataloaders  # <--- NEW IMPORT

-# --- Settings ---
-N_STATES = 10
-INPUT_DIM = 5  # Matches the 'dim' in generate_toy_data
-MAX_DUR = 50
-LR = 0.05
-EPOCHS = 20
+# --- CONFIGURATION ---
+CONFIG = {
+    # Path to your file
+    "DATA_PATH": "/u/schmitt/experiments/2025_10_02_unsup_asr_shared_enc/work/i6_experiments/users/schmitt/experiments/exp2025_10_02_shared_enc/librispeech/data/audio_preprocessing/Wav2VecUFeaturizeAudioJob.mkGrrp0YWy8y/output/audio_features/train.npy",
+
+    "N_STATES": 50,    # Increased for real speech (approx 40-50 phonemes)
+    "PCA_DIM": 30,     # Reduce 512 -> 30 dimensions
+    "MAX_DUR": 30,     # Max duration in frames (30 * 20ms = 600ms)
+    "LR": 0.01,        # Slightly lower LR for real data
+    "EPOCHS": 20,
+    "BATCH_SIZE": 64,  # 64 chunks of 3 seconds each
+    "CROP_LEN": 150,   # Training window (150 frames = 3 seconds)
+
+    "CHECKPOINT_PATH": "hsmm_librispeech.pth",
+    "RESUME": False
+}
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"--- Using device: {device} ---")
+
+def save_checkpoint(model, optimizer, epoch, loss, path):
+    torch.save({
+        'epoch': epoch,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'loss': loss,
+    }, path)
+
+def load_checkpoint(model, optimizer, path):
+    if os.path.exists(path):
+        print(f"Loading checkpoint from {path}...")
+        checkpoint = torch.load(path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        return checkpoint['epoch'] + 1
+    return 0

-# In train():
 def train():
-    print("1. Generating Data...")
-    train_data = generate_toy_data(n_samples=30, seq_len=300, n_clusters=N_STATES, dim=INPUT_DIM)
+    # 1. Load Real Data
+    print("--- 1. Loading Real Data ---")
+    train_ds, val_ds = get_real_dataloaders(
+        CONFIG["DATA_PATH"],
+        batch_size=CONFIG["BATCH_SIZE"],
+        crop_len=CONFIG["CROP_LEN"],
+        pca_dim=CONFIG["PCA_DIM"]
+    )

-    print("2. Initializing Model...")
-    model = GaussianHSMM(N_STATES, INPUT_DIM, MAX_DUR)
-    optimizer = optim.Adam(model.parameters(), lr=LR)
+    # Loader for Training (Batched, Cropped)
+    train_loader = DataLoader(
+        train_ds,
+        batch_size=CONFIG["BATCH_SIZE"],
+        shuffle=True,
+        num_workers=4,
+        pin_memory=True,
+        drop_last=True  # Avoid partial batch issues
+    )

-    print("3. Training Loop...")
-    loss_history = []
+    # 2. Init Model
+    print("--- 2. Initializing Model ---")
+    model = BatchedGaussianHSMM(CONFIG["N_STATES"], CONFIG["PCA_DIM"], CONFIG["MAX_DUR"])
+    model.to(device)

-    for epoch in range(EPOCHS):
-        epoch_loss = 0
-        optimizer.zero_grad()
+    # Smart Init (using PCA-reduced data from the dataset)
+    if not CONFIG["RESUME"] and not os.path.exists(CONFIG["CHECKPOINT_PATH"]):
+        print("--- 2b. Running Smart Initialization ---")
+        # Grab a batch to init means
+        init_batch = next(iter(train_loader)).to(device)  # (B, T, D)
+        flat_data = init_batch.view(-1, CONFIG["PCA_DIM"])

-        # Batching: Gradient Accumulation
-        for seq in train_data:
-            loss = model(seq)  # Forward pass
-            loss.backward()    # Backward pass
-            epoch_loss += loss.item()
+        # Pick random frames
+        indices = torch.randperm(flat_data.size(0))[:CONFIG["N_STATES"]]
+        model.means.data.copy_(flat_data[indices])
+        print("Means initialized.")

-        # Normalize gradients
-        for p in model.parameters():
-            if p.grad is not None:
-                p.grad /= len(train_data)
-
-        optimizer.step()
-        loss_history.append(epoch_loss)
-
-        if epoch % 5 == 0:
-            print(f"Epoch {epoch:02d} | NLL Loss: {epoch_loss:.2f}")
-
-    # --- Verification ---
-    print("\n4. Results:")
-    learned_means = model.means.detach().view(-1).numpy()
-    learned_means.sort()
-    print(f"True Means: [-5.0, 0.0, 5.0]")
-    print(f"Learned Means: {learned_means}")
+    optimizer = optim.Adam(model.parameters(), lr=CONFIG["LR"])
+    start_epoch = 0
+    if CONFIG["RESUME"]:
+        start_epoch = load_checkpoint(model, optimizer, CONFIG["CHECKPOINT_PATH"])

-    # --- Visualization Block in main.py ---
-    print("5. Visualizing Inference...")
-    test_seq = train_data[0]
-    predicted_path = viterbi_decode(model, test_seq)
+    # 3. Training Loop
+    print(f"--- 3. Training Loop ---")

-    fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
+    for epoch in range(start_epoch, CONFIG["EPOCHS"]):
+        total_loss = 0
+        model.train()
+
+        for batch_idx, batch_data in enumerate(train_loader):
+            batch_data = batch_data.to(device)
+
+            optimizer.zero_grad()
+            loss = model(batch_data)
+            loss.backward()
+            optimizer.step()
+
+            total_loss += loss.item()
+
+            if batch_idx % 50 == 0:
+                print(f"Epoch {epoch} | Batch {batch_idx} | Loss {loss.item():.4f}")
+
+        avg_loss = total_loss / len(train_loader)
+        print(f"Epoch {epoch:02d} DONE | Avg NLL: {avg_loss:.4f}")
+        save_checkpoint(model, optimizer, epoch, avg_loss, CONFIG["CHECKPOINT_PATH"])
+
+    # 4. Visualization (Using Validation Set - Full Length)
+    print("\n--- 4. Visualizing Inference on Real Audio ---")

-    # Plot 1: The Multi-Dimensional Data (Transposed so Time is X-axis)
-    # This shows the "features" changing color as the state changes
-    ax[0].imshow(test_seq.numpy().T, aspect='auto', cmap='viridis', interpolation='nearest')
-    ax[0].set_title(f"Raw Data ({INPUT_DIM} Dimensions)")
-    ax[0].set_ylabel("Feature Dim")
+    # Grab the first file from validation set (Index 0)
+    # val_ds[0] returns (Time, Dim) -> add batch dim -> (1, T, D)
+    test_seq = val_ds[0].unsqueeze(0).to(device)

-    # Plot 2: The Inferred States
-    # Reshape path to (1, T) for imshow
-    path_img = np.array(predicted_path)[np.newaxis, :]
-    ax[1].imshow(path_img, aspect='auto', cmap='tab10', interpolation='nearest')
-    ax[1].set_title("Inferred HSMM States")
+    # Run Inference
+    path = viterbi_decode(model, test_seq)
+
+    # Move to CPU for plotting
+    raw_data = test_seq.squeeze(0).cpu().numpy()
+
+    # Plotting
+    fig, ax = plt.subplots(2, 1, figsize=(15, 6), sharex=True)
+
+    # Plot PCA Features
+    ax[0].imshow(raw_data.T, aspect='auto', cmap='viridis', interpolation='nearest')
+    ax[0].set_title(f"PCA Reduced Features ({CONFIG['PCA_DIM']} Dim)")
+    ax[0].set_ylabel("PCA Dim")
+
+    # Plot States
+    path_img = np.array(path)[np.newaxis, :]
+    ax[1].imshow(path_img, aspect='auto', cmap='tab20', interpolation='nearest')
+    ax[1].set_title("Inferred Phoneme States")
     ax[1].set_ylabel("State ID")
     ax[1].set_xlabel("Time (Frames)")

     plt.tight_layout()
-    plt.show()
+    plt.savefig("librispeech_result.png")
+    print("Saved librispeech_result.png")

 if __name__ == "__main__":
     train()
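
For reference (illustration only, not part of this commit), a sketch of reloading the checkpoint written by save_checkpoint() for inference, assuming the training run above has produced hsmm_librispeech.pth; the constructor arguments must match the CONFIG values used at training time, and the random tensor is a stand-in for one PCA-reduced utterance.

import torch
from hsmm_model import BatchedGaussianHSMM
from hsmm_inference import viterbi_decode

ckpt = torch.load("hsmm_librispeech.pth", map_location="cpu")
model = BatchedGaussianHSMM(n_states=50, input_dim=30, max_dur=30)  # N_STATES, PCA_DIM, MAX_DUR from CONFIG
model.load_state_dict(ckpt["model_state_dict"])
model.eval()
print("checkpoint from epoch", ckpt["epoch"], "| avg NLL", ckpt["loss"])

x = torch.randn(1, 150, 30)      # (1, T, D) stand-in for a PCA-reduced utterance
path = viterbi_decode(model, x)  # list of per-frame state ids
print(len(path))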
@@ -8,3 +8,9 @@ dependencies = [

     "matplotlib>=3.10.8",
     "torch>=2.9.1",
 ]
+
+[tool.pyright]
+# "venvPath" specifies the folder *containing* the venv directory
+venvPath = "."
+# "venv" specifies the *name* of the venv directory
+venv = ".venv"
hsmm/real_data.py (new file)
@@ -0,0 +1,91 @@

+import torch
+from torch.utils.data import Dataset
+import numpy as np
+import os
+from sklearn.decomposition import PCA
+
+class RealAudioDataset(Dataset):
+    def __init__(self, npy_path, len_path=None, crop_len=None, pca_dim=None, pca_model=None):
+        """
+        npy_path: Path to the huge .npy file
+        len_path: Path to the .lengths file (optional, tries to infer if None)
+        crop_len: If set (e.g., 200), we randomly crop sequences to this length for training.
+        pca_dim: If set (e.g., 30), we learn/apply PCA reduction.
+        """
+        # 1. Load Data (Memory Mapped to save RAM)
+        if not os.path.exists(npy_path):
+            raise FileNotFoundError(f"Could not find {npy_path}")
+
+        self.data = np.load(npy_path, mmap_mode='r')
+        self.input_dim = self.data.shape[1]
+
+        # 2. Load Lengths
+        if len_path is None:
+            # Assume .lengths is next to .npy
+            len_path = npy_path.replace('.npy', '.lengths')
+
+        if not os.path.exists(len_path):
+            raise FileNotFoundError(f"Could not find length file: {len_path}")
+
+        with open(len_path, 'r') as f:
+            self.lengths = [int(x) for x in f.read().strip().split()]
+
+        # Create Offsets (Where each sentence starts in the flat file)
+        self.offsets = np.cumsum([0] + self.lengths[:-1])
+        self.n_samples = len(self.lengths)
+        self.crop_len = crop_len
+
+        print(f"Loaded Dataset: {self.n_samples} files. Dim: {self.input_dim}")
+
+        # 3. Handle PCA
+        self.pca = pca_model
+        if pca_dim is not None and self.input_dim > pca_dim:
+            if self.pca is None:
+                print(f"Fitting PCA to reduce dim from {self.input_dim} -> {pca_dim}...")
+                # Fit on a subset (first 100k frames) to be fast
+                subset_size = min(len(self.data), 100000)
+                subset = self.data[:subset_size]
+                self.pca = PCA(n_components=pca_dim)
+                self.pca.fit(subset)
+                print("PCA Fit Complete.")
+            else:
+                print("Using provided PCA model.")
+
+    def __len__(self):
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        # 1. Locate the sentence
+        start = self.offsets[idx]
+        length = self.lengths[idx]
+
+        # 2. Extract Data
+        # If training (crop_len set), pick a random window
+        if self.crop_len and length > self.crop_len:
+            # Random Offset
+            max_start = length - self.crop_len
+            offset = np.random.randint(0, max_start + 1)
+
+            # Slice the mmap array
+            raw_seq = self.data[start+offset : start+offset+self.crop_len]
+        else:
+            # Validation/Inference (Return full sequence)
+            # Note: Batch size must be 1 for variable lengths!
+            raw_seq = self.data[start : start+length]
+
+        # 3. Apply PCA (On the fly)
+        if self.pca is not None:
+            raw_seq = self.pca.transform(raw_seq)
+
+        # 4. Convert to Tensor
+        return torch.tensor(raw_seq, dtype=torch.float32)
+
+def get_real_dataloaders(npy_path, batch_size, crop_len=200, pca_dim=30):
+    # 1. Training Set (Random Crops)
+    train_ds = RealAudioDataset(npy_path, crop_len=crop_len, pca_dim=pca_dim)
+
+    # 2. Validation Set (Full Sequences, Shared PCA)
+    # We use batch_size=1 because lengths vary!
+    val_ds = RealAudioDataset(npy_path, crop_len=None, pca_dim=pca_dim, pca_model=train_ds.pca)
+
+    return train_ds, val_ds
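
For reference (illustration only, not part of this commit), a tiny worked example of the offset bookkeeping above: three utterances of lengths 4, 2 and 3 stored back-to-back in one flat array, indexed exactly as in __getitem__.

import numpy as np

lengths = [4, 2, 3]
offsets = np.cumsum([0] + lengths[:-1])   # array([0, 4, 6])
flat = np.arange(sum(lengths))            # stand-in for the flat (n_frames, dim) feature array

idx = 1                                   # second utterance
start, length = offsets[idx], lengths[idx]
print(flat[start:start + length])         # [4 5]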
@@ -1,48 +1,33 @@

 import torch
 import numpy as np

-def generate_toy_data(n_samples=50, seq_len=300, n_clusters=10, dim=5):
+def generate_toy_data(n_samples=500, seq_len=300, n_clusters=10, dim=5):
     """
     Generates sequences where the hidden states are clusters in D-dimensional space.
-
-    Args:
-        n_samples: Number of audio 'files'
-        seq_len: Length of each file
-        n_clusters: Number of states (phonemes)
-        dim: Number of features (e.g. 30 for Wav2Vec PCA, 5 for testing)
     """
     data_list = []

-    # 1. Generate Random Cluster Centers
-    # We multiply by 10 to ensure they are far apart in space
-    # Shape: (10, 5)
+    # Cluster Centers (Spread out)
     centers = np.random.randn(n_clusters, dim) * 10.0

     print(f"Generated {n_clusters} cluster centers in {dim}D space.")
-    print(f"Example Center 0: {np.round(centers[0], 2)}")

     for _ in range(n_samples):
         seq = []
         state = np.random.randint(0, n_clusters)

         t = 0
         while t < seq_len:
-            # Random duration
             dur = np.random.randint(10, 30)

-            # 2. Generate Segment
-            # Shape: (Duration, Dim)
-            # Center[state] + Gaussian Noise
-            noise = np.random.randn(dur, dim)  # Standard normal noise
+            # Segment: Center + Noise
+            noise = np.random.randn(dur, dim)
             segment = noise + centers[state]
             seq.append(segment)

-            # 3. Transition (No self-loops)
+            # Transition (No self-loops)
             next_state = state
             while next_state == state:
                 next_state = np.random.randint(0, n_clusters)
             state = next_state

             t += dur

         full_seq = np.concatenate(seq)[:seq_len]
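
For reference (illustration only, not part of this commit), a hedged usage sketch of the toy generator, assuming it returns a list of equal-length (seq_len, dim) sequences, as the previous main.py usage suggests; torch.as_tensor is used so it works whether the items are NumPy arrays or tensors.

import torch
from toy_data import generate_toy_data

data = generate_toy_data(n_samples=4, seq_len=100, n_clusters=3, dim=5)
batch = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s in data])  # (4, 100, 5)
print(batch.shape)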