Files
i6_setups/hsmm/hsmm_inference.py
2026-01-22 13:50:41 +01:00

71 lines
2.5 KiB
Python

import torch
import torch.nn.functional as F
def viterbi_decode(model, x):
    """
    Return the best-scoring state path for one sequence using segmental
    (explicit-duration) Viterbi decoding.

    Args:
        model: object exposing
            - ``n_states`` (int) and ``max_dur`` (int),
            - ``compute_emission_log_probs(x)`` which, for a ``(1, T, D)``
              input, must yield per-frame log-probs that become ``(T, N)``
              after ``squeeze(0)``,
            - ``trans_logits`` (rows are source states; the diagonal is
              masked out so a state never follows itself),
            - ``dur_logits`` (indexed ``[state, duration - 1]``),
            - ``pi_logits`` (initial-state logits).
        x: features of shape ``(Time, Dim)`` or ``(1, Time, Dim)``.

    Returns:
        list[int]: one state index per frame. The list is empty when the
        sequence is empty or when no finite-score segmentation ends exactly
        on the last frame (e.g. a single-state model with ``T > max_dur``).

    Raises:
        ValueError: if ``x`` carries more than one sequence in its batch
            dimension; this routine decodes a single sequence at a time.
    """
    with torch.no_grad():
        # Handle batch dimension if missing.
        if x.dim() == 2:
            x = x.unsqueeze(0)  # (1, T, D)
        if x.shape[0] != 1:
            raise ValueError(
                "viterbi_decode expects a single sequence, got batch size "
                f"{x.shape[0]}"
            )

        T = x.shape[1]
        if T == 0:
            # Nothing to decode; indexing max_prob[T - 1] below would fail.
            return []
        N = model.n_states
        D_max = model.max_dur

        # 1. Log-probabilities.
        log_emit = model.compute_emission_log_probs(x).squeeze(0)  # (T, N)
        # Forbid self-transitions: a state change is what ends a segment.
        diag = torch.eye(N, device=x.device).bool()
        log_trans = F.log_softmax(
            model.trans_logits.masked_fill(diag, -float('inf')), dim=1
        )
        log_dur = F.log_softmax(model.dur_logits, dim=1)
        log_pi = F.log_softmax(model.pi_logits, dim=0)

        # 2. Viterbi tables. max_prob[t, s] is the best score of any
        #    segmentation whose last segment is in state s and ends at t.
        max_prob = torch.full((T, N), -float('inf'), device=x.device)
        # Previous state of the winning segment (-1 = segment starts at 0).
        bp_prev = torch.full((T, N), -1, dtype=torch.long, device=x.device)
        # Duration of the winning segment (0 = no finite-score entry yet).
        bp_dur = torch.zeros((T, N), dtype=torch.long, device=x.device)

        # 3. Dynamic programming over (end frame, segment duration).
        for t in range(T):
            # Durations longer than t + 1 would start before frame 0.
            for d in range(1, min(D_max, t + 1) + 1):
                start = t - d + 1
                seg_emit = log_emit[start:t + 1].sum(dim=0)
                dur_prob = log_dur[:, d - 1]

                if start == 0:
                    # First segment of the sequence: use the initial prior.
                    score = log_pi + dur_prob + seg_emit
                    prev_idx = torch.full_like(bp_prev[t], -1)
                else:
                    # trans_scores[i, j]: best path ending in i at start - 1,
                    # followed by a switch from i to j.
                    trans_scores = max_prob[start - 1].unsqueeze(1) + log_trans
                    best_prev_score, prev_idx = trans_scores.max(dim=0)
                    score = best_prev_score + dur_prob + seg_emit

                # Strict '>' keeps the shortest duration on ties and leaves
                # unreachable (-inf / NaN) candidates unrecorded.
                improved = score > max_prob[t]
                max_prob[t] = torch.where(
                    improved, score.to(max_prob.dtype), max_prob[t]
                )
                bp_prev[t] = torch.where(improved, prev_idx, bp_prev[t])
                bp_dur[t] = torch.where(
                    improved, torch.full_like(prev_idx, d), bp_dur[t]
                )

        # 4. Backtracking. Pull the tables to Python once instead of
        #    synchronising with the device for every lookup.
        best_end_state = torch.argmax(max_prob[T - 1]).item()
        prev_table = bp_prev.tolist()
        dur_table = bp_dur.tolist()

        segments = []  # (state, duration), collected last-to-first
        curr_t = T - 1
        curr_s = best_end_state
        while curr_t >= 0:
            d = dur_table[curr_t][curr_s]
            if d == 0:
                # No recorded segment ends here: the frame is unreachable.
                break
            segments.append((curr_s, d))
            curr_s = prev_table[curr_t][curr_s]
            curr_t -= d

        path = []
        for state, d in reversed(segments):
            path.extend([state] * d)
        return path