"""Train a Gaussian HSMM on toy data and visualize the decoded state path."""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim

from hsmm_inference import viterbi_decode
from hsmm_model import GaussianHSMM
from toy_data import generate_toy_data

# --- Settings ---
N_STATES = 10
INPUT_DIM = 5  # Matches the 'dim' in generate_toy_data
MAX_DUR = 50
LR = 0.05
EPOCHS = 20


def _plot_inference(test_seq, predicted_path):
    """Show the raw sequence and its decoded state path as two stacked images.

    Args:
        test_seq: Tensor holding one sequence; it is transposed before
            plotting so that time runs along the x-axis.
        predicted_path: Per-frame state ids as returned by ``viterbi_decode``.
    """
    fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

    # Plot 1: the multi-dimensional data (transposed so time is the x-axis).
    ax[0].imshow(
        test_seq.numpy().T,
        aspect='auto',
        cmap='viridis',
        interpolation='nearest',
    )
    ax[0].set_title(f"Raw Data ({INPUT_DIM} Dimensions)")
    ax[0].set_ylabel("Feature Dim")

    # Plot 2: the inferred states, reshaped to (1, T) so imshow draws a strip.
    path_img = np.array(predicted_path)[np.newaxis, :]
    ax[1].imshow(path_img, aspect='auto', cmap='tab10', interpolation='nearest')
    ax[1].set_title("Inferred HSMM States")
    ax[1].set_ylabel("State ID")
    ax[1].set_xlabel("Time (Frames)")

    plt.tight_layout()
    plt.show()


def train():
    """Generate toy data, fit a ``GaussianHSMM``, and plot a decoded sequence.

    Each epoch performs one optimizer step: gradients are accumulated over
    every training sequence and then divided by the number of sequences.
    Progress is printed every 5 epochs, the sorted learned means are printed
    after training, and the Viterbi path of the first training sequence is
    plotted.
    """
    print("1. Generating Data...")
    train_data = generate_toy_data(
        n_samples=30, seq_len=300, n_clusters=N_STATES, dim=INPUT_DIM
    )

    print("2. Initializing Model...")
    model = GaussianHSMM(N_STATES, INPUT_DIM, MAX_DUR)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print("3. Training Loop...")
    loss_history = []

    for epoch in range(EPOCHS):
        epoch_loss = 0
        optimizer.zero_grad()

        # Batching via gradient accumulation: gradients from every sequence
        # add up because zero_grad() is only called once per epoch.
        for seq in train_data:
            loss = model(seq)  # Forward pass
            loss.backward()  # Backward pass
            epoch_loss += loss.item()

        # Normalize the accumulated gradients to a per-sequence average.
        for p in model.parameters():
            if p.grad is not None:
                p.grad /= len(train_data)

        optimizer.step()
        loss_history.append(epoch_loss)

        if epoch % 5 == 0:
            print(f"Epoch {epoch:02d} | NLL Loss: {epoch_loss:.2f}")

    # --- Verification ---
    print("\n4. Results:")
    # np.sort returns a sorted *copy*. The array produced by
    # detach().view(-1).numpy() shares memory with model.means, so sorting it
    # in place would scramble the model's parameters right before decoding.
    learned_means = np.sort(model.means.detach().view(-1).numpy())
    # NOTE(review): this reference list has three values while N_STATES is 10;
    # confirm against generate_toy_data what the true means actually are.
    print("True Means: [-5.0, 0.0, 5.0]")
    print(f"Learned Means: {learned_means}")

    # --- Visualization ---
    print("5. Visualizing Inference...")
    test_seq = train_data[0]
    predicted_path = viterbi_decode(model, test_seq)
    _plot_inference(test_seq, predicted_path)


if __name__ == "__main__":
    train()