37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
def generate_toy_data(n_samples: int = 500, seq_len: int = 300,
                      n_clusters: int = 10, dim: int = 5) -> list:
    """Generate toy sequences whose hidden states are clusters in `dim`-D space.

    Each sequence is a concatenation of segments. A segment stays in one
    hidden state (cluster) for a random duration and emits that cluster's
    center plus unit-variance Gaussian noise. After every segment the state
    switches to a different, uniformly chosen cluster.

    Args:
        n_samples: Number of sequences to generate.
        seq_len: Number of time steps in every returned sequence. Values
            <= 0 produce empty sequences of shape ``(0, dim)``.
        n_clusters: Number of hidden states / cluster centers. With a single
            cluster there is nothing to switch to, so the state stays fixed.
        dim: Dimensionality of each observation.

    Returns:
        A list of ``n_samples`` ``torch.float32`` tensors, each of shape
        ``(seq_len, dim)``.

    Notes:
        Randomness comes from NumPy's global RNG (``np.random``); seed it
        with ``np.random.seed`` for reproducible output. The function also
        prints a one-line summary of the generated centers.
    """
    data_list = []

    # Cluster centers: standard normal scaled by 10 so they are well
    # separated relative to the unit-variance emission noise.
    centers = np.random.randn(n_clusters, dim) * 10.0
    print(f"Generated {n_clusters} cluster centers in {dim}D space.")

    for _ in range(n_samples):
        segments = []
        state = np.random.randint(0, n_clusters)
        t = 0

        while t < seq_len:
            # Dwell time in the current state: integer in [10, 29].
            dur = np.random.randint(10, 30)

            # Segment: center + noise
            noise = np.random.randn(dur, dim)
            segments.append(noise + centers[state])

            # Transition (no self-loops). Rejection sampling is kept so the
            # RNG draw order is unchanged; it is skipped for a single
            # cluster, where no other state exists and the loop would
            # otherwise spin forever.
            if n_clusters > 1:
                next_state = state
                while next_state == state:
                    next_state = np.random.randint(0, n_clusters)
                state = next_state

            t += dur

        if segments:
            # The last segment usually overshoots seq_len; trim the excess.
            full_seq = np.concatenate(segments)[:seq_len]
        else:
            # seq_len <= 0: np.concatenate rejects an empty list, so build
            # the empty sequence explicitly.
            full_seq = np.empty((0, dim))

        data_list.append(torch.tensor(full_seq, dtype=torch.float32))

    return data_list
|