initial
This commit is contained in:
80
hsmm/main.py
Normal file
80
hsmm/main.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim

from hsmm_inference import viterbi_decode
from hsmm_model import GaussianHSMM
from toy_data import generate_toy_data
|
||||
|
||||
# --- Settings ---
# Number of hidden states; also passed as n_clusters to generate_toy_data.
N_STATES = 10
# Matches the 'dim' in generate_toy_data.
INPUT_DIM = 5
# Third positional argument of GaussianHSMM.
# NOTE(review): presumably the maximum state duration in frames -- confirm in hsmm_model.
MAX_DUR = 50
# Learning rate handed to the Adam optimizer.
LR = 0.05
# Number of full passes over the training sequences.
EPOCHS = 20
|
||||
|
||||
# --- Training ---
|
||||
def _run_epoch(model, optimizer, train_data):
    """Perform one optimizer step using gradients averaged over all sequences.

    Gradients are accumulated by calling ``backward()`` once per sequence,
    then divided by the number of sequences before ``optimizer.step()``.

    Returns:
        The sum of the per-sequence loss values for this epoch.
    """
    optimizer.zero_grad()
    epoch_loss = 0.0

    # Batching: gradient accumulation, one sequence at a time.
    for seq in train_data:
        loss = model(seq)  # Forward pass returns the loss directly
        loss.backward()  # Adds this sequence's gradients to p.grad
        epoch_loss += loss.item()

    # Normalize: turn the summed gradients into a per-sequence average.
    n_seqs = len(train_data)
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= n_seqs

    optimizer.step()
    return epoch_loss


def _plot_inference(test_seq, predicted_path):
    """Show the raw sequence and the decoded state path in two stacked panels."""
    _, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

    # Plot 1: the multi-dimensional data, transposed so time is the x-axis.
    ax[0].imshow(test_seq.numpy().T, aspect='auto', cmap='viridis', interpolation='nearest')
    ax[0].set_title(f"Raw Data ({INPUT_DIM} Dimensions)")
    ax[0].set_ylabel("Feature Dim")

    # Plot 2: the inferred states, reshaped to (1, T) because imshow needs 2-D input.
    path_img = np.array(predicted_path)[np.newaxis, :]
    ax[1].imshow(path_img, aspect='auto', cmap='tab10', interpolation='nearest')
    ax[1].set_title("Inferred HSMM States")
    ax[1].set_ylabel("State ID")
    ax[1].set_xlabel("Time (Frames)")

    plt.tight_layout()
    plt.show()


def train():
    """Generate toy data, fit a GaussianHSMM, report its means and plot a decoding.

    Steps, each announced on stdout:
      1. Build the training sequences with ``generate_toy_data``.
      2. Create the model and an Adam optimizer.
      3. Run ``EPOCHS`` full-batch updates, printing the loss every 5th epoch.
      4. Print the learned means, flattened and sorted.
      5. Decode the first training sequence with ``viterbi_decode`` and plot it.

    Returns nothing; the results are the printed output and the matplotlib figure.
    """
    print("1. Generating Data...")
    train_data = generate_toy_data(n_samples=30, seq_len=300, n_clusters=N_STATES, dim=INPUT_DIM)

    print("2. Initializing Model...")
    model = GaussianHSMM(N_STATES, INPUT_DIM, MAX_DUR)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    print("3. Training Loop...")
    loss_history = []  # Recorded per epoch; not consumed further in this function
    for epoch in range(EPOCHS):
        epoch_loss = _run_epoch(model, optimizer, train_data)
        loss_history.append(epoch_loss)

        if epoch % 5 == 0:
            print(f"Epoch {epoch:02d} | NLL Loss: {epoch_loss:.2f}")

    # --- Verification ---
    print("\n4. Results:")
    # np.sort returns a sorted, flattened *copy*. An in-place ndarray.sort() must
    # not be used here: detach()/numpy() share storage with model.means, so sorting
    # in place would reorder the model's own parameters before decoding below.
    learned_means = np.sort(model.means.detach().numpy(), axis=None)
    # NOTE(review): this reference line lists three values, while the model is built
    # with N_STATES x INPUT_DIM settings from the module constants -- it looks stale;
    # confirm against generate_toy_data.
    print("True Means: [-5.0, 0.0, 5.0]")
    print(f"Learned Means: {learned_means}")

    # --- Visualization ---
    print("5. Visualizing Inference...")
    test_seq = train_data[0]
    # Decoding is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        predicted_path = viterbi_decode(model, test_seq)

    _plot_inference(test_seq, predicted_path)
|
||||
|
||||
# Run the full train / report / plot pipeline only when executed as a script.
if __name__ == "__main__":
    train()
|
||||
Reference in New Issue
Block a user