import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


# --- Batched JIT Forward Loop ---
@torch.jit.script
def hsmm_forward_loop_batched(T: int, N: int, D_max: int,
                              log_emit: torch.Tensor,
                              log_trans: torch.Tensor,
                              log_dur: torch.Tensor,
                              log_pi: torch.Tensor) -> torch.Tensor:
    """
    Computes the marginal log-likelihood for a batch of sequences in parallel.

    Keeps the alpha history in a Python list instead of writing into a single
    tensor, which avoids in-place modification errors under autograd.

    Shapes:
        log_emit:  (Batch, T, N)   per-frame emission log-probs
        log_trans: (N, N)          state transition log-probs
        log_dur:   (N, D_max)      duration log-probs per state
        log_pi:    (N,)            initial state log-probs
    Returns:
        (Batch,) marginal log-likelihood per sequence.
    """
    batch_size = log_emit.shape[0]

    # alpha_list[t] holds the alpha tensor for time t
    alpha_list: List[torch.Tensor] = []

    for t in range(T):
        # Init the accumulator for this time step with -inf
        current_alpha = torch.full((batch_size, N), -float('inf'),
                                   device=log_emit.device)

        for d in range(1, D_max + 1):
            if t - d + 1 < 0:
                continue

            # 1. Emission sum for the segment t-d+1 ... t
            #    Slice: (Batch, d, N) -> sum over dim 1 -> (Batch, N)
            seg_emit = log_emit[:, t - d + 1 : t + 1, :].sum(dim=1)

            # 2. Duration score: (N,) -> (1, N)
            dur_score = log_dur[:, d - 1].unsqueeze(0)

            # 3. Transition logic
            if t - d + 1 == 0:
                # Initialization (segment starts at t=0)
                path_score = log_pi.unsqueeze(0) + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, path_score)
            else:
                # Transition from the alpha at time t-d
                prev_alpha = alpha_list[t - d]

                # Broadcast against the transition matrix:
                #   prev:  (Batch, N, 1)
                #   trans: (1, N, N)
                #   sum:   (Batch, N, N) -> logsumexp over prev state -> (Batch, N)
                trans_score = torch.logsumexp(
                    prev_alpha.unsqueeze(2) + log_trans.unsqueeze(0),
                    dim=1
                )

                path_score = trans_score + dur_score + seg_emit
                current_alpha = torch.logaddexp(current_alpha, path_score)

        alpha_list.append(current_alpha)

    # Final sum over states for each batch element: (Batch,)
    return torch.logsumexp(alpha_list[-1], dim=1)
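
# Example call (illustrative only, with assumed sizes B=2, T=10, N=3, D_max=4):
#   log_emit  = torch.randn(2, 10, 3)                    # (B, T, N)
#   log_trans = F.log_softmax(torch.randn(3, 3), dim=1)  # (N, N)
#   log_dur   = F.log_softmax(torch.randn(3, 4), dim=1)  # (N, D_max)
#   log_pi    = F.log_softmax(torch.randn(3), dim=0)     # (N,)
#   ll = hsmm_forward_loop_batched(10, 3, 4, log_emit, log_trans, log_dur, log_pi)  # (B,)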


class BatchedGaussianHSMM(nn.Module):
    def __init__(self, n_states, input_dim, max_dur=20):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur

        # --- Learnable Parameters ---
        self.pi_logits = nn.Parameter(torch.randn(n_states))
        self.trans_logits = nn.Parameter(torch.randn(n_states, n_states))
        self.dur_logits = nn.Parameter(torch.randn(n_states, max_dur))
        self.means = nn.Parameter(torch.randn(n_states, input_dim))
        self.log_vars = nn.Parameter(torch.zeros(n_states, input_dim))

    def compute_emission_log_probs(self, x):
        """
        Calculates the diagonal-Gaussian log-likelihood for a batch.
        x: (Batch, T, Dim)
        Returns: (Batch, T, N_States)
        """
        # x broadcast to (Batch, T, 1, Dim); means to (1, 1, N, Dim)
        diff = x.unsqueeze(2) - self.means.reshape(1, 1, self.n_states, -1)

        variances = self.log_vars.exp().reshape(1, 1, self.n_states, -1)
        log_vars = self.log_vars.reshape(1, 1, self.n_states, -1)

        # Log Gaussian PDF per dimension
        log_prob = -0.5 * (math.log(2 * math.pi) + log_vars + diff ** 2 / variances)

        # Sum over the feature dimension -> (Batch, T, N)
        return log_prob.sum(dim=-1)
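
        # For reference, the per-dimension term above is the diagonal-Gaussian
        # log-density
        #   log N(x | mu, s^2) = -0.5 * (log(2*pi) + log(s^2) + (x - mu)^2 / s^2)
        # summed over dimensions, i.e. features are modeled as independent.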

    def forward(self, x):
        """
        x: (Batch, T, Dim)
        Returns: scalar loss (mean NLL over the batch)
        """
        B, T, D = x.shape

        log_emit = self.compute_emission_log_probs(x)  # (B, T, N)

        # Mask the diagonal of the transition matrix (no self-transitions,
        # since staying in a state is modeled by the duration distribution)
        mask = torch.eye(self.n_states, device=self.trans_logits.device).bool()
        masked_trans = self.trans_logits.masked_fill(mask, -float('inf'))

        log_trans = F.log_softmax(masked_trans, dim=1)
        log_dur = F.log_softmax(self.dur_logits, dim=1)
        log_pi = F.log_softmax(self.pi_logits, dim=0)

        # Run the batched JIT forward loop
        batch_log_likelihoods = hsmm_forward_loop_batched(
            T, self.n_states, self.max_dur,
            log_emit, log_trans, log_dur, log_pi
        )

        # Return the mean negative log-likelihood
        return -batch_log_likelihoods.mean()
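

# A minimal usage sketch (an assumption, not part of the original module): fits
# the model to synthetic data with Adam. The sizes, learning rate, and step
# count below are illustrative choices only.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = BatchedGaussianHSMM(n_states=3, input_dim=2, max_dur=5)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

    x = torch.randn(4, 25, 2)  # (Batch=4, T=25, Dim=2) synthetic observations
    for step in range(10):
        optimizer.zero_grad()
        loss = model(x)  # mean NLL over the batch
        loss.backward()
        optimizer.step()
        print(f"step {step}: nll = {loss.item():.4f}")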