testing
@@ -3,51 +3,52 @@ import torch.nn.functional as F
 
 def viterbi_decode(model, x):
     """
-    Returns the optimal sequence of states (path).
+    Returns the optimal sequence of states (path) using the Viterbi algorithm.
+
+    x: (Time, Dim) or (1, Time, Dim)
     """
     with torch.no_grad():
-        T = x.shape[0]
+        # Handle batch dimension if missing
+        if x.dim() == 2:
+            x = x.unsqueeze(0)  # (1, T, D)
+
+        T = x.shape[1]
         N = model.n_states
         D_max = model.max_dur
 
-        # 1. Setup Probs
-        log_emit = model.compute_emission_log_probs(x)
-        log_trans = F.log_softmax(model.get_masked_transitions(), dim=1)
+        # 1. Get Probs (Using Batched Model Function)
+        log_emit = model.compute_emission_log_probs(x)
+        log_emit = log_emit.squeeze(0)  # (T, N)
+
+        # Get other probs
+        mask = torch.eye(N, device=x.device).bool()
+        log_trans = F.log_softmax(model.trans_logits.masked_fill(mask, -float('inf')), dim=1)
         log_dur = F.log_softmax(model.dur_logits, dim=1)
         log_pi = F.log_softmax(model.pi_logits, dim=0)
 
         # 2. Viterbi Tables
         # max_prob[t, s] = Best log-prob ending at t in state s
         max_prob = torch.full((T, N), -float('inf'), device=x.device)
         # backpointers[t, s] = (previous_state, duration_used)
         backpointers = {}
 
-        # 3. Dynamic Programming
+        # 3. Dynamic Programming Loop
         for t in range(T):
             for d in range(1, D_max + 1):
                 if t - d + 1 < 0: continue
 
                 # Emission sum for segment
                 seg_emit = log_emit[t-d+1 : t+1].sum(dim=0)
                 dur_prob = log_dur[:, d-1]
 
                 if t - d + 1 == 0:
                     # Init
                     score = log_pi + dur_prob + seg_emit
                     for s in range(N):
                         if score[s] > max_prob[t, s]:
                             max_prob[t, s] = score[s]
-                            backpointers[(t, s)] = (-1, d)  # -1 is Start
+                            backpointers[(t, s)] = (-1, d)
                 else:
                     # Transition
-                    prev_scores = max_prob[t-d]  # (N,)
+                    # Find best transition for each target state s
+                    # (N, 1) + (N, N) -> (N, N)
+                    prev_scores = max_prob[t-d]
                     trans_scores = prev_scores.unsqueeze(1) + log_trans
-                    best_prev_score, best_prev_idx = trans_scores.max(dim=0)  # (N,)
+                    best_prev_score, best_prev_idx = trans_scores.max(dim=0)
 
                     current_score = best_prev_score + dur_prob + seg_emit
 
                     for s in range(N):
                         if current_score[s] > max_prob[t, s]:
                             max_prob[t, s] = current_score[s]
@@ -62,10 +63,8 @@ def viterbi_decode(model, x):
         while curr_t >= 0:
             if (curr_t, curr_s) not in backpointers: break
             prev_s, d = backpointers[(curr_t, curr_s)]
 
             # Append this state 'd' times
             path = [curr_s] * d + path
             curr_t -= d
             curr_s = prev_s
 
         return path
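
For reference, the rewritten loop is the textbook explicit-duration (hidden semi-Markov) Viterbi recursion. The correspondence below is an interpretation of the code above, not text from the commit: max_prob plays the role of \delta, log_trans of \log A, log_dur of \log p, and log_emit of \log b.

    % \delta(t, s): best log-probability of a segmentation of x_{0..t}
    % whose final segment ends at time t in state s (= max_prob[t, s]).
    \delta(t, s) = \max_{1 \le d \le D_{\max}} \left[
        \max_{s' \ne s} \bigl( \delta(t - d, s') + \log A_{s' s} \bigr)
        + \log p_s(d)
        + \sum_{\tau = t - d + 1}^{t} \log b_s(x_\tau)
    \right]

The `t - d + 1 == 0` branch supplies the base case \delta(d - 1, s) = \log \pi_s + \log p_s(d) + \sum_{\tau = 0}^{d - 1} \log b_s(x_\tau), and the s' \ne s constraint is exactly what the diagonal masked_fill on trans_logits enforces.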
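
As a quick smoke test, viterbi_decode only needs a model exposing n_states, max_dur, trans_logits, dur_logits, pi_logits, and compute_emission_log_probs. The sketch below runs it against a stub: StubHSMM and its linear proj emission head are invented here purely for illustration (the real emission model is not in this diff), and it assumes the lines elided between the two hunks start backtracking from the best-scoring state at t = T - 1, as is standard.

import torch
import torch.nn.functional as F

class StubHSMM(torch.nn.Module):
    """Hypothetical stand-in exposing the attributes viterbi_decode uses."""
    def __init__(self, n_states=3, max_dur=4, dim=8):
        super().__init__()
        self.n_states = n_states
        self.max_dur = max_dur
        self.trans_logits = torch.nn.Parameter(torch.randn(n_states, n_states))
        self.dur_logits = torch.nn.Parameter(torch.randn(n_states, max_dur))
        self.pi_logits = torch.nn.Parameter(torch.randn(n_states))
        self.proj = torch.nn.Linear(dim, n_states)  # toy emission scorer

    def compute_emission_log_probs(self, x):
        # (B, T, D) -> (B, T, N); a linear head stands in for the real model.
        return F.log_softmax(self.proj(x), dim=-1)

# viterbi_decode is the function changed in this commit.
model = StubHSMM()
x = torch.randn(50, 8)          # (Time, Dim); the batch dim is added inside
path = viterbi_decode(model, x)
assert len(path) == 50          # segment durations tile the whole sequence
print(path[:10])                # first ten frame labels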