import torch
import torch.nn.functional as F


def viterbi_decode(model, x):
    """
    Return the most likely state sequence (one state per frame) for an
    explicit-duration HMM (HSMM) using the Viterbi algorithm.

    x: (Time, Dim) or (1, Time, Dim)
    Returns: list of length Time with a state index for each frame.
    """
    with torch.no_grad():
        # Handle the batch dimension if it is missing.
        if x.dim() == 2:
            x = x.unsqueeze(0)  # (1, T, D)

        T = x.shape[1]
        N = model.n_states
        D_max = model.max_dur

        # 1. Emission log-probabilities (via the batched model function).
        log_emit = model.compute_emission_log_probs(x)
        log_emit = log_emit.squeeze(0)  # (T, N)

        # Transition, duration, and initial-state log-probabilities.
        # Self-transitions are forbidden (diagonal masked to -inf): in an
        # HSMM, staying in a state is handled by the duration model instead.
        mask = torch.eye(N, dtype=torch.bool, device=x.device)
        log_trans = F.log_softmax(
            model.trans_logits.masked_fill(mask, -float('inf')), dim=1
        )
        log_dur = F.log_softmax(model.dur_logits, dim=1)  # (N, D_max)
        log_pi = F.log_softmax(model.pi_logits, dim=0)    # (N,)

        # 2. Viterbi tables: max_prob[t, s] is the best score of any state
        # sequence whose last segment is in state s and ends at frame t.
        max_prob = torch.full((T, N), -float('inf'), device=x.device)
        backpointers = {}  # (t, s) -> (previous state, segment duration)

        # 3. Dynamic-programming loop over end frame t and duration d.
        for t in range(T):
            for d in range(1, D_max + 1):
                if t - d + 1 < 0:
                    continue  # segment would start before frame 0
                # Emission score of the segment covering frames [t-d+1, t].
                seg_emit = log_emit[t - d + 1 : t + 1].sum(dim=0)  # (N,)
                dur_prob = log_dur[:, d - 1]                       # (N,)
                if t - d + 1 == 0:
                    # Initial segment: starts at frame 0, no predecessor.
                    score = log_pi + dur_prob + seg_emit
                    for s in range(N):
                        if score[s] > max_prob[t, s]:
                            max_prob[t, s] = score[s]
                            backpointers[(t, s)] = (-1, d)
                else:
                    # Previous segment ends at frame t - d; pick the best
                    # predecessor state for each current state.
                    prev_scores = max_prob[t - d]                        # (N,)
                    trans_scores = prev_scores.unsqueeze(1) + log_trans  # (N, N)
                    best_prev_score, best_prev_idx = trans_scores.max(dim=0)
                    current_score = best_prev_score + dur_prob + seg_emit
                    for s in range(N):
                        if current_score[s] > max_prob[t, s]:
                            max_prob[t, s] = current_score[s]
                            backpointers[(t, s)] = (best_prev_idx[s].item(), d)

        # 4. Backtracking: follow backpointers from the best final state,
        # expanding each segment into d repeated frame labels.
        best_end_state = torch.argmax(max_prob[T - 1]).item()
        path = []
        curr_t = T - 1
        curr_s = best_end_state
        while curr_t >= 0:
            if (curr_t, curr_s) not in backpointers:
                break
            prev_s, d = backpointers[(curr_t, curr_s)]
            path = [curr_s] * d + path
            curr_t -= d
            curr_s = prev_s
        return path
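

# --- Usage sketch (illustrative; not part of the original source) ---
# A minimal smoke test assuming a hypothetical DummyHSMM that exposes the
# attributes viterbi_decode relies on: n_states, max_dur, pi_logits,
# trans_logits, dur_logits, and compute_emission_log_probs. The isotropic
# Gaussian emission below is an assumption made only so the example runs;
# any module with the same interface would work.
class DummyHSMM(torch.nn.Module):
    def __init__(self, n_states=3, max_dur=5, dim=4):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur
        self.pi_logits = torch.nn.Parameter(torch.zeros(n_states))
        self.trans_logits = torch.nn.Parameter(torch.zeros(n_states, n_states))
        self.dur_logits = torch.nn.Parameter(torch.zeros(n_states, max_dur))
        self.means = torch.nn.Parameter(torch.randn(n_states, dim))

    def compute_emission_log_probs(self, x):
        # (B, T, D) -> (B, T, N): log N(x_t; mu_s, I) up to an additive constant.
        diff = x.unsqueeze(2) - self.means  # (B, T, N, D)
        return -0.5 * (diff ** 2).sum(dim=-1)


if __name__ == "__main__":
    torch.manual_seed(0)
    model = DummyHSMM()
    x = torch.randn(20, 4)  # (T, D)
    path = viterbi_decode(model, x)
    # path has length T: one state label per frame, grouped into segments
    # of at most max_dur frames with no two adjacent segments sharing a state.
    print(len(path), path)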