This commit is contained in:
jbkzi
2026-01-22 13:50:41 +01:00
parent e5cb97d2e5
commit 251fd3e9be
6 changed files with 323 additions and 159 deletions

View File

@@ -1,48 +1,33 @@
import torch
import numpy as np
def generate_toy_data(n_samples=50, seq_len=300, n_clusters=10, dim=5):
def generate_toy_data(n_samples=500, seq_len=300, n_clusters=10, dim=5):
"""
Generates sequences where the hidden states are clusters in D-dimensional space.
Args:
n_samples: Number of audio 'files'
seq_len: Length of each file
n_clusters: Number of states (phonemes)
dim: Number of features (e.g. 30 for Wav2Vec PCA, 5 for testing)
"""
data_list = []
# 1. Generate Random Cluster Centers
# We multiply by 10 to ensure they are far apart in space
# Shape: (10, 5)
# Cluster Centers (Spread out)
centers = np.random.randn(n_clusters, dim) * 10.0
print(f"Generated {n_clusters} cluster centers in {dim}D space.")
print(f"Example Center 0: {np.round(centers[0], 2)}")
for _ in range(n_samples):
seq = []
state = np.random.randint(0, n_clusters)
t = 0
while t < seq_len:
# Random duration
dur = np.random.randint(10, 30)
# 2. Generate Segment
# Shape: (Duration, Dim)
# Center[state] + Gaussian Noise
noise = np.random.randn(dur, dim) # Standard normal noise
# Segment: Center + Noise
noise = np.random.randn(dur, dim)
segment = noise + centers[state]
seq.append(segment)
# 3. Transition (No self-loops)
# Transition (No self-loops)
next_state = state
while next_state == state:
next_state = np.random.randint(0, n_clusters)
state = next_state
t += dur
full_seq = np.concatenate(seq)[:seq_len]