initial
hsmm/hsmm_model.py (107 lines, normal file)
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# --- JIT Compiled Forward Loop for Speed ---
from typing import List


@torch.jit.script
def hsmm_forward_loop(T: int, N: int, D_max: int,
                      log_emit: torch.Tensor,
                      log_trans: torch.Tensor,
                      log_dur: torch.Tensor,
                      log_pi: torch.Tensor) -> torch.Tensor:

    # We store each step's alpha in a List instead of writing into a
    # preallocated tensor; appending avoids in-place modification errors
    # during autograd.
    alpha_list: List[torch.Tensor] = []

    for t in range(T):
        # Initialize the current step with -inf
        current_alpha = torch.full((N,), -float('inf'), device=log_emit.device)

        for d in range(1, D_max + 1):
            if t - d + 1 < 0:
                continue

            # 1. Emission Score (log-probs summed over the segment [t-d+1, t])
            seg_emit = log_emit[t - d + 1 : t + 1].sum(dim=0)

            # 2. Duration Score
            dur_score = log_dur[:, d - 1]

            # 3. Transition Score
            if t - d + 1 == 0:
                # Initial case: the segment starts at t=0, so use the start distribution
                score = log_pi + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, score)
            else:
                # Recursion: look back at alpha[t-d], which is alpha_list[t-d]
                prev_alpha = alpha_list[t - d]

                trans_score = torch.logsumexp(prev_alpha.unsqueeze(1) + log_trans, dim=0)
                score = trans_score + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, score)

        # Save this step to the history list
        alpha_list.append(current_alpha)

    # The total log-likelihood is the logsumexp over states of the last step
    return torch.logsumexp(alpha_list[-1], dim=0)
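
# Illustrative note (not part of the original file): given how forward() below
# builds its arguments, the expected input shapes for hsmm_forward_loop are
#   log_emit: (T, N), log_trans: (N, N), log_dur: (N, D_max), log_pi: (N,).
# A standalone shape check with made-up sizes (T=50, N=4, D_max=20) would be:
#   ll = hsmm_forward_loop(50, 4, 20,
#                          torch.randn(50, 4),
#                          torch.log_softmax(torch.randn(4, 4), dim=1),
#                          torch.log_softmax(torch.randn(4, 20), dim=1),
#                          torch.log_softmax(torch.randn(4), dim=0))
# which returns a scalar log-likelihood tensor.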


class GaussianHSMM(nn.Module):
    def __init__(self, n_states, input_dim, max_dur=20):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur

        # --- Parameters ---
        # 1. Start Probabilities
        self.pi_logits = nn.Parameter(torch.randn(n_states))

        # 2. Transition Matrix (Bigram)
        # The diagonal is masked later to forbid self-transitions
        self.trans_logits = nn.Parameter(torch.randn(n_states, n_states))

        # 3. Duration Model (Categorical weights for 1..max_dur)
        self.dur_logits = nn.Parameter(torch.randn(n_states, max_dur))

        # 4. Emissions (Gaussian Means & LogVars)
        self.means = nn.Parameter(torch.randn(n_states, input_dim))
        self.log_vars = nn.Parameter(torch.zeros(n_states, input_dim))

    def compute_emission_log_probs(self, x):
        """ Calculates Gaussian Log-Likelihood: (T, N) """
        # x: (T, Dim) -> (T, 1, Dim)
        # means: (N, Dim) -> (1, N, Dim)
        # log_prob = -0.5 * (log(2*pi) + log_var + (x - mu)^2 / var)

        diff = x.unsqueeze(1) - self.means.unsqueeze(0)
        variances = self.log_vars.exp().unsqueeze(0)
        log_vars = self.log_vars.unsqueeze(0)

        log_prob = -0.5 * (np.log(2 * np.pi) + log_vars + (diff ** 2) / variances)
        return log_prob.sum(dim=-1)  # Sum over feature dimensions

    def get_masked_transitions(self):
        """ Enforces A_ii = -inf (no self-transitions allowed in the bigram) """
        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
        return self.trans_logits.masked_fill(mask, -float('inf'))

    def forward(self, x):
        """ Returns Negative Log Likelihood (scalar loss) """
        T = x.shape[0]

        # 1. Precompute static probabilities
        log_emit = self.compute_emission_log_probs(x)
        log_trans = F.log_softmax(self.get_masked_transitions(), dim=1)
        log_dur = F.log_softmax(self.dur_logits, dim=1)
        log_pi = F.log_softmax(self.pi_logits, dim=0)

        # 2. Run JIT Loop
        total_ll = hsmm_forward_loop(T, self.n_states, self.max_dur,
                                     log_emit, log_trans, log_dur, log_pi)

        return -total_ll
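
For context, a minimal training sketch using this module might look like the snippet below. The state count, feature dimension, sequence length, learning rate, and step count are illustrative assumptions, not values taken from this commit; the grounded fact is that forward(x) returns the negative log-likelihood of a single (T, input_dim) sequence, so it can be minimized directly with a standard optimizer.

# Training sketch (illustrative values)
model = GaussianHSMM(n_states=4, input_dim=16, max_dur=20)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

x = torch.randn(200, 16)  # one observation sequence of shape (T, input_dim)

for step in range(100):
    optimizer.zero_grad()
    loss = model(x)       # negative log-likelihood of the whole sequence
    loss.backward()
    optimizer.step()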