import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


# --- Batched JIT Forward Loop ---
@torch.jit.script
def hsmm_forward_loop_batched(T: int, N: int, D_max: int,
                              log_emit: torch.Tensor,
                              log_trans: torch.Tensor,
                              log_dur: torch.Tensor,
                              log_pi: torch.Tensor) -> torch.Tensor:
    """
    Computes the marginal log-likelihood for a BATCH of sequences in parallel.
    Uses a list for the alpha history to avoid in-place modification errors
    under autograd.
    """
    batch_size = log_emit.shape[0]

    # alpha_list[t] will hold the alpha tensor for time t
    alpha_list: List[torch.Tensor] = []

    for t in range(T):
        # Init accumulator for this time step with -inf
        current_alpha = torch.full((batch_size, N), -float('inf'),
                                   device=log_emit.device)

        for d in range(1, D_max + 1):
            if t - d + 1 < 0:
                continue

            # 1. Emission sum for the segment t-d+1 ... t
            #    Slice: (Batch, d, N) -> sum over dim 1 -> (Batch, N)
            seg_emit = log_emit[:, t - d + 1 : t + 1, :].sum(dim=1)

            # 2. Duration: (N,) -> (1, N)
            dur_score = log_dur[:, d - 1].unsqueeze(0)

            # 3. Transition logic
            if t - d + 1 == 0:
                # Initialization (segment starts at t=0)
                path_score = log_pi.unsqueeze(0) + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, path_score)
            else:
                # Transition from the alpha at time t-d
                prev_alpha = alpha_list[t - d]

                # Broadcast against the transition matrix:
                #   prev:  (Batch, N, 1)
                #   trans: (1, N, N)
                #   sum:   (Batch, N, N) -> logsumexp over prev state -> (Batch, N)
                trans_score = torch.logsumexp(
                    prev_alpha.unsqueeze(2) + log_trans.unsqueeze(0), dim=1
                )
                path_score = trans_score + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, path_score)

        alpha_list.append(current_alpha)

    # Final sum over states for each batch element: (Batch,)
    return torch.logsumexp(alpha_list[-1], dim=1)


class BatchedGaussianHSMM(nn.Module):
    def __init__(self, n_states, input_dim, max_dur=20):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur

        # --- Learnable parameters ---
        self.pi_logits = nn.Parameter(torch.randn(n_states))
        self.trans_logits = nn.Parameter(torch.randn(n_states, n_states))
        self.dur_logits = nn.Parameter(torch.randn(n_states, max_dur))
        self.means = nn.Parameter(torch.randn(n_states, input_dim))
        self.log_vars = nn.Parameter(torch.zeros(n_states, input_dim))

    def compute_emission_log_probs(self, x):
        """
        Diagonal-Gaussian log-likelihood for a batch.
        x: (Batch, T, Dim)
        Returns: (Batch, T, N_States)
        """
        # x:     (Batch, T, 1, Dim)
        # means: (1, 1, N, Dim)
        diff = x.unsqueeze(2) - self.means.reshape(1, 1, self.n_states, -1)
        variances = self.log_vars.exp().reshape(1, 1, self.n_states, -1)
        log_vars = self.log_vars.reshape(1, 1, self.n_states, -1)

        # Log Gaussian PDF, per feature dimension
        log_prob = -0.5 * (math.log(2 * math.pi) + log_vars + diff**2 / variances)

        # Sum over the feature dimension -> (Batch, T, N)
        return log_prob.sum(dim=-1)

    def forward(self, x):
        """
        x: (Batch, T, Dim)
        Returns: scalar loss (mean NLL over the batch)
        """
        B, T, D = x.shape
        log_emit = self.compute_emission_log_probs(x)  # (B, T, N)

        # Mask the diagonal of the transition matrix (no self-transitions:
        # staying in a state is handled by the duration distribution)
        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
        masked_trans = self.trans_logits.masked_fill(mask, -float('inf'))

        log_trans = F.log_softmax(masked_trans, dim=1)
        log_dur = F.log_softmax(self.dur_logits, dim=1)
        log_pi = F.log_softmax(self.pi_logits, dim=0)

        # Run the batched JIT loop
        batch_log_likelihoods = hsmm_forward_loop_batched(
            T, self.n_states, self.max_dur,
            log_emit, log_trans, log_dur, log_pi,
        )

        # Mean negative log-likelihood
        return -batch_log_likelihoods.mean()
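
# --- Usage sketch (not part of the original module) ---
# A minimal smoke test, assuming synthetic data: the batch size, sequence
# length, feature dimension, and optimizer settings below are illustrative
# choices, not values from the source.
if __name__ == "__main__":
    torch.manual_seed(0)

    B, T, D = 8, 50, 4  # batch, sequence length, feature dim (hypothetical)
    model = BatchedGaussianHSMM(n_states=3, input_dim=D, max_dur=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    x = torch.randn(B, T, D)  # synthetic observations

    for step in range(100):
        optimizer.zero_grad()
        loss = model(x)  # mean NLL over the batch
        loss.backward()
        optimizer.step()
        if step % 20 == 0:
            print(f"step {step}: mean NLL = {loss.item():.3f}")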