initial
This commit is contained in:
51
hsmm/toy_data.py
Normal file
51
hsmm/toy_data.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def generate_toy_data(n_samples=50, seq_len=300, n_clusters=10, dim=5):
|
||||
"""
|
||||
Generates sequences where the hidden states are clusters in D-dimensional space.
|
||||
|
||||
Args:
|
||||
n_samples: Number of audio 'files'
|
||||
seq_len: Length of each file
|
||||
n_clusters: Number of states (phonemes)
|
||||
dim: Number of features (e.g. 30 for Wav2Vec PCA, 5 for testing)
|
||||
"""
|
||||
data_list = []
|
||||
|
||||
# 1. Generate Random Cluster Centers
|
||||
# We multiply by 10 to ensure they are far apart in space
|
||||
# Shape: (10, 5)
|
||||
centers = np.random.randn(n_clusters, dim) * 10.0
|
||||
|
||||
print(f"Generated {n_clusters} cluster centers in {dim}D space.")
|
||||
print(f"Example Center 0: {np.round(centers[0], 2)}")
|
||||
|
||||
for _ in range(n_samples):
|
||||
seq = []
|
||||
state = np.random.randint(0, n_clusters)
|
||||
|
||||
t = 0
|
||||
while t < seq_len:
|
||||
# Random duration
|
||||
dur = np.random.randint(10, 30)
|
||||
|
||||
# 2. Generate Segment
|
||||
# Shape: (Duration, Dim)
|
||||
# Center[state] + Gaussian Noise
|
||||
noise = np.random.randn(dur, dim) # Standard normal noise
|
||||
segment = noise + centers[state]
|
||||
seq.append(segment)
|
||||
|
||||
# 3. Transition (No self-loops)
|
||||
next_state = state
|
||||
while next_state == state:
|
||||
next_state = np.random.randint(0, n_clusters)
|
||||
state = next_state
|
||||
|
||||
t += dur
|
||||
|
||||
full_seq = np.concatenate(seq)[:seq_len]
|
||||
data_list.append(torch.tensor(full_seq, dtype=torch.float32))
|
||||
|
||||
return data_list
|
||||
Reference in New Issue
Block a user