i6_setups/hsmm/real_data.py

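"""Dataset utilities for a flat .npy audio-feature dump.

The dump is assumed to be one large (total_frames, feature_dim) array with a
companion .lengths file listing the number of frames per utterance. Utterances
are returned as per-sentence float32 tensors, optionally random-cropped for
training and optionally PCA-reduced.
"""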

import os

import numpy as np
import torch
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader, Dataset


class RealAudioDataset(Dataset):
    def __init__(self, npy_path, len_path=None, crop_len=None, pca_dim=None, pca_model=None):
        """
        npy_path: Path to the huge .npy file
        len_path: Path to the .lengths file (optional, inferred from npy_path if None)
        crop_len: If set (e.g., 200), we randomly crop sequences to this length for training.
        pca_dim: If set (e.g., 30), we learn/apply PCA reduction to this dimension.
        pca_model: An already-fitted PCA model to reuse (e.g., sharing the training PCA with validation).
        """
        # 1. Load Data (Memory Mapped to save RAM)
        if not os.path.exists(npy_path):
            raise FileNotFoundError(f"Could not find {npy_path}")
        self.data = np.load(npy_path, mmap_mode='r')
        self.input_dim = self.data.shape[1]

        # 2. Load Lengths
        if len_path is None:
            # Assume .lengths is next to .npy
            len_path = npy_path.replace('.npy', '.lengths')
        if not os.path.exists(len_path):
            raise FileNotFoundError(f"Could not find length file: {len_path}")
        with open(len_path, 'r') as f:
            self.lengths = [int(x) for x in f.read().strip().split()]

        # Create offsets (where each sentence starts in the flat file)
        self.offsets = np.cumsum([0] + self.lengths[:-1])
        self.n_samples = len(self.lengths)
        self.crop_len = crop_len
        print(f"Loaded Dataset: {self.n_samples} files. Dim: {self.input_dim}")

        # 3. Handle PCA
        self.pca = pca_model
        if pca_dim is not None and self.input_dim > pca_dim:
            if self.pca is None:
                print(f"Fitting PCA to reduce dim from {self.input_dim} -> {pca_dim}...")
                # Fit on a subset (first 100k frames) to keep it fast
                subset_size = min(len(self.data), 100000)
                subset = self.data[:subset_size]
                self.pca = PCA(n_components=pca_dim)
                self.pca.fit(subset)
                print("PCA Fit Complete.")
            else:
                print("Using provided PCA model.")

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # 1. Locate the sentence
        start = self.offsets[idx]
        length = self.lengths[idx]

        # 2. Extract Data
        # If training (crop_len set), pick a random window
        if self.crop_len and length > self.crop_len:
            # Random offset into the sentence
            max_start = length - self.crop_len
            offset = np.random.randint(0, max_start + 1)
            # Slice the mmap array
            raw_seq = self.data[start + offset : start + offset + self.crop_len]
        else:
            # Validation/Inference (return the full sequence)
            # Note: Batch size must be 1 for variable lengths
            # (or use a padding collate_fn such as pad_collate below).
            raw_seq = self.data[start : start + length]

        # 3. Apply PCA (on the fly)
        if self.pca is not None:
            raw_seq = self.pca.transform(raw_seq)

        # 4. Convert to Tensor
        return torch.tensor(raw_seq, dtype=torch.float32)
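

# Optional sketch (an assumption, not part of the original pipeline): a collate_fn
# that zero-pads variable-length full sequences so validation could run with
# batch_size > 1. Downstream code would have to consume (padded, lengths) pairs.
def pad_collate(batch):
    lengths = torch.tensor([seq.shape[0] for seq in batch])
    padded = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True)
    return padded, lengths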


def get_real_dataloaders(npy_path, batch_size, crop_len=200, pca_dim=30):
    # 1. Training set: fixed-length random crops, so the default collate can batch them
    #    (assumes every utterance has at least crop_len frames).
    train_ds = RealAudioDataset(npy_path, crop_len=crop_len, pca_dim=pca_dim)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # 2. Validation set: full sequences, reusing the PCA fitted on the training set.
    #    batch_size=1 because lengths vary!
    val_ds = RealAudioDataset(npy_path, crop_len=None, pca_dim=pca_dim, pca_model=train_ds.pca)
    val_dl = DataLoader(val_ds, batch_size=1, shuffle=False)
    return train_dl, val_dl
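

# Minimal usage sketch. The feature path below is hypothetical; point it at your
# own .npy dump (with a matching .lengths file next to it).
if __name__ == "__main__":
    train_dl, val_dl = get_real_dataloaders(
        npy_path="features/train.npy",  # hypothetical path
        batch_size=32,
        crop_len=200,
        pca_dim=30,
    )
    # Training batches: (batch_size, crop_len, pca_dim), assuming every utterance
    # has at least crop_len frames.
    print(next(iter(train_dl)).shape)
    # Validation batches: (1, length, pca_dim) with varying length.
    print(next(iter(val_dl)).shape)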