import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List


# --- JIT Compiled Forward Loop for Speed ---
@torch.jit.script
def hsmm_forward_loop(T: int, N: int, D_max: int,
                      log_emit: torch.Tensor, log_trans: torch.Tensor,
                      log_dur: torch.Tensor, log_pi: torch.Tensor) -> torch.Tensor:
    # We store each alpha step in a List and append as we go.
    # Building new tensors each step (instead of writing into a preallocated
    # matrix) avoids "in-place" autograd errors.
    alpha_list: List[torch.Tensor] = []

    for t in range(T):
        # Initialize current step with -inf (log of zero probability)
        current_alpha = torch.full((N,), -float('inf'), device=log_emit.device)

        for d in range(1, D_max + 1):
            if t - d + 1 < 0:
                continue

            # 1. Emission Score: sum of per-frame log-likelihoods over the segment
            seg_emit = log_emit[t-d+1 : t+1].sum(dim=0)

            # 2. Duration Score
            dur_score = log_dur[:, d-1]

            # 3. Transition Score
            if t - d + 1 == 0:
                # First segment covers frames 0..t: use the initial state distribution
                score = log_pi + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, score)
            else:
                # Recursion: look back at alpha[t-d]
                # In a list, alpha[t-d] is just alpha_list[t-d]
                prev_alpha = alpha_list[t-d]
                trans_score = torch.logsumexp(prev_alpha.unsqueeze(1) + log_trans, dim=0)
                score = trans_score + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, score)

        # Save this step to the history list
        alpha_list.append(current_alpha)

    # The result is the logsumexp over states of the very last step
    return torch.logsumexp(alpha_list[-1], dim=0)


class GaussianHSMM(nn.Module):
    def __init__(self, n_states, input_dim, max_dur=20):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur

        # --- Parameters ---
        # 1. Start Probabilities
        self.pi_logits = nn.Parameter(torch.randn(n_states))

        # 2. Transition Matrix (Bigram)
        #    We manually mask the diagonal later to forbid self-transitions
        self.trans_logits = nn.Parameter(torch.randn(n_states, n_states))

        # 3. Duration Model (Categorical weights for 1..max_dur)
        self.dur_logits = nn.Parameter(torch.randn(n_states, max_dur))

        # 4. Emissions (Gaussian Means & LogVars)
        self.means = nn.Parameter(torch.randn(n_states, input_dim))
        self.log_vars = nn.Parameter(torch.zeros(n_states, input_dim))

    def compute_emission_log_probs(self, x):
        """ Calculates Gaussian Log-Likelihood: (T, N) """
        # x: (T, Dim) -> (T, 1, Dim)
        # means: (N, Dim) -> (1, N, Dim)
        # log_prob = -0.5 * (log(2pi) + log_var + (x-mu)^2/var)
        diff = x.unsqueeze(1) - self.means.unsqueeze(0)
        variances = self.log_vars.exp().unsqueeze(0)
        log_vars = self.log_vars.unsqueeze(0)

        log_prob = -0.5 * (np.log(2 * np.pi) + log_vars + (diff**2) / variances)
        return log_prob.sum(dim=-1)  # Sum over feature dimensions

    def get_masked_transitions(self):
        """ Enforces A_ii = -inf (No self-transitions allowed in Bigram) """
        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
        return self.trans_logits.masked_fill(mask, -float('inf'))

    def forward(self, x):
        """ Returns Negative Log Likelihood (Scalar Loss) """
        T = x.shape[0]

        # 1. Precompute static probabilities
        log_emit = self.compute_emission_log_probs(x)
        log_trans = F.log_softmax(self.get_masked_transitions(), dim=1)
        log_dur = F.log_softmax(self.dur_logits, dim=1)
        log_pi = F.log_softmax(self.pi_logits, dim=0)

        # 2. Run JIT Loop
        total_ll = hsmm_forward_loop(T, self.n_states, self.max_dur,
                                     log_emit, log_trans, log_dur, log_pi)
        return -total_ll
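

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how GaussianHSMM might be fit by gradient descent on the
# negative log-likelihood returned by forward(). The observation sequence is
# synthetic, and the hyperparameters (n_states=3, input_dim=2, max_dur=10,
# learning rate, number of steps) are arbitrary assumptions for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Fake observation sequence: 100 frames of 2-D features.
    x = torch.randn(100, 2)

    model = GaussianHSMM(n_states=3, input_dim=2, max_dur=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    for step in range(50):
        optimizer.zero_grad()
        nll = model(x)  # scalar negative log-likelihood of the whole sequence
        nll.backward()
        optimizer.step()
        if step % 10 == 0:
            print(f"step {step:3d}  NLL = {nll.item():.3f}")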