API Reference
Data — src.data
generate_datasets(seed, num_latents, k, n_samples, input_dim=None)
Main entry point. Returns (train, val, ood, A) where each split is (Z, Y, labels) and A is the ground-truth mixing matrix.
from src.data import generate_datasets
train, val, ood, A = generate_datasets(
    seed=0, num_latents=10, k=3, n_samples=2000
)
data_setup(Z_iid, Y_iid, val_Z, val_Y, Z_ood, Y_ood, batch_size, device)
Convert numpy arrays to torch tensors and build a DataLoader for training.
Models
SAEConfig
Dataclass for SAE experiments.
from models.saes import SAEConfig, run_sae_experiment
cfg = SAEConfig(
    num_latents=10, k=3, n_samples=2000,
    width=20, sae_type="topk",  # "relu", "topk", "jumprelu", "MP"
    epochs=1000, lr=5e-4, seed=0,
)
model, results, data, A, cfg = run_sae_experiment(cfg)
SAE(input_dim, width, sae_type, kval_topk=None, mp_kval=None)
A torch.nn.Module. Call forward(x, return_hidden=False) to get the codes and the reconstruction, or decode(codes) to decode only.
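To make the codes-plus-reconstruction contract concrete, here is a minimal NumPy sketch of what a top-k forward pass could look like. It is not the library's SAE class: the function name, the weight layout, and the choice to apply ReLU before keeping the k largest activations are all assumptions made for illustration.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Hypothetical top-k SAE forward pass: returns (codes, reconstruction)."""
    # ReLU pre-activations, shape (n_samples, width).
    pre = np.maximum(x @ W_enc + b_enc, 0.0)
    # Column indices of the k largest activations in each row.
    top_idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    codes = np.zeros_like(pre)
    rows = np.arange(pre.shape[0])[:, None]
    # Keep the top-k activations and leave every other entry at zero.
    codes[rows, top_idx] = pre[rows, top_idx]
    # Linear decoder maps the sparse codes back to input space.
    recon = codes @ W_dec + b_dec
    return codes, recon

# Illustrative shapes only: 5 samples, input_dim=10, width=20, k=3.
rng = np.random.default_rng(0)
input_dim, width, k = 10, 20, 3
x = rng.normal(size=(5, input_dim))
W_enc = rng.normal(size=(input_dim, width))
W_dec = rng.normal(size=(width, input_dim))
codes, recon = topk_sae_forward(
    x, W_enc, np.zeros(width), W_dec, np.zeros(input_dim), k
)
```

Each row of codes has at most k non-zero entries, which is the sparsity constraint the "topk" variant is named after.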
train_sae(model, train_loader, inputs_iid, inputs_ood, device, cfg)
Train a single SAE. Returns a dict with codes, reconstructions, and per-epoch metrics.
run_sae_experiment(cfg=None, **overrides)
End-to-end: generate data, train, evaluate. Returns (model, results, data, A, cfg).
fista(x, D, lam, n_iter=100, lr=None, z_init=None, nonneg=False)
FISTA: ISTA accelerated with Nesterov momentum. Pass the ground-truth dictionary as D for oracle sparse inference.
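import torch
from models.sparse_coding import fista
# Y holds the observations of one split and A is the ground-truth
# mixing matrix (see generate_datasets).
codes = fista(Y, D=torch.from_numpy(A), lam=0.1, n_iter=100)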
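For readers who want to see the iteration itself, the following is a minimal NumPy sketch of FISTA for the lasso objective. It is not the library's fista: the name fista_sketch, the row-wise layout (samples in rows, dictionary atoms in the columns of D), and the fixed step size 1/L are assumptions, and the lr, z_init, and nonneg options are left out.

```python
import numpy as np

def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def fista_sketch(x, D, lam, n_iter=100):
    """Minimise 0.5 * ||x - z @ D.T||^2 + lam * ||z||_1 over z.

    Assumed shapes: x is (n_samples, input_dim), D is (input_dim, n_atoms).
    """
    # Lipschitz constant of the gradient: largest singular value squared.
    L = np.linalg.norm(D, ord=2) ** 2
    z = np.zeros((x.shape[0], D.shape[1]))
    y, t = z.copy(), 1.0
    for _ in range(n_iter):
        grad = (y @ D.T - x) @ D                      # gradient of the smooth term at y
        z_new = soft_threshold(y - grad / L, lam / L)
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        y = z_new + ((t - 1.0) / t_new) * (z_new - z)  # Nesterov extrapolation
        z, t = z_new, t_new
    return z

# Toy problem: unit-norm random dictionary, 3 active atoms per sample.
rng = np.random.default_rng(0)
D = rng.normal(size=(10, 20))
D /= np.linalg.norm(D, axis=0, keepdims=True)
z_true = np.zeros((4, 20))
z_true[:, :3] = 1.0
x = z_true @ D.T
codes = fista_sketch(x, D, lam=0.1, n_iter=200)
```

Dropping the extrapolation step (using z_new in place of y) turns this into plain ISTA.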
ista(x, D, lam, n_iter=100, lr=None, z_init=None)
Plain ISTA (no momentum).
SparseCodingConfig
from models.sparse_coding import SparseCodingConfig
cfg = SparseCodingConfig(
    input_dim=10, num_latents=20,
    method="fista",  # "fista", "ista", "lista", "direct"
    lam=0.1, n_iter=100,
)
train_sparse_coding(X_iid, X_ood, cfg, A=None, D_init=None, device=None)
Train sparse coding with the method specified in cfg.method. Returns a dict with codes, the learned dictionary, and the training history.
LISTA(n_obs, n_latent, n_unroll=16)
Learned ISTA encoder. Call init_from_dictionary(D) to initialise, then forward(x) for fast inference.
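The idea behind initialising an unrolled encoder from a dictionary can be sketched in NumPy as follows. This is an untrained illustration, not the library's LISTA class: the ListaSketch name, the weight layout, and the extra lam argument on init_from_dictionary are assumptions. With this initialisation, n_unroll layers reproduce n_unroll ISTA steps started from zero, and training would then adjust the weights away from that starting point.

```python
import numpy as np

def soft_threshold(v, theta):
    return np.sign(v) * np.maximum(np.abs(v) - theta, 0.0)

class ListaSketch:
    """Hypothetical, untrained LISTA-style encoder."""

    def __init__(self, n_obs, n_latent, n_unroll=16):
        self.n_unroll = n_unroll
        self.W = np.zeros((n_obs, n_latent))      # input weights
        self.S = np.zeros((n_latent, n_latent))   # lateral weights
        self.theta = 0.0                          # shrinkage threshold

    def init_from_dictionary(self, D, lam=0.1):
        # lam is an assumed extra argument; D is (n_obs, n_latent).
        L = np.linalg.norm(D, ord=2) ** 2         # ISTA step-size constant
        self.W = D / L
        self.S = np.eye(D.shape[1]) - (D.T @ D) / L
        self.theta = lam / L

    def forward(self, x):
        b = x @ self.W
        z = soft_threshold(b, self.theta)          # first layer (z starts at 0)
        for _ in range(self.n_unroll - 1):
            z = soft_threshold(b + z @ self.S, self.theta)
        return z

rng = np.random.default_rng(0)
D = rng.normal(size=(5, 8))
x = rng.normal(size=(3, 5))

enc = ListaSketch(n_obs=5, n_latent=8, n_unroll=16)
enc.init_from_dictionary(D, lam=0.1)
z_lista = enc.forward(x)

# Reference: 16 plain ISTA steps from zero on the same problem.
L = np.linalg.norm(D, ord=2) ** 2
z_ista = np.zeros((3, 8))
for _ in range(16):
    z_ista = soft_threshold(z_ista - ((z_ista @ D.T - x) @ D) / L, 0.1 / L)
```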
refine_from_sae(sae_model, X, lam, n_iter=20, nonneg=True)
Extract an SAE's decoder and warm-start FISTA from its encoder output. Returns a dict with the original codes, the refined codes, the dictionary, and the MSE gap.
Metrics — utils.metrics
compute_mcc(Z_true, Z_pred, seed=42) -> float
Mean Correlation Coefficient via Hungarian matching. Returns score in [0, 1].
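As a rough illustration of the metric, the sketch below computes an MCC with NumPy. It is not the library's compute_mcc: the name mcc_sketch is hypothetical, the use of absolute Pearson correlation is an assumption consistent with the [0, 1] range, and an exhaustive search over permutations stands in for the Hungarian algorithm. Both find the optimal one-to-one assignment, but the exhaustive search is only practical for a handful of latents.

```python
from itertools import permutations

import numpy as np

def mcc_sketch(Z_true, Z_pred):
    """Mean absolute correlation under the best one-to-one latent matching."""
    n = Z_true.shape[1]
    m = Z_pred.shape[1]
    # Cross-correlation block between true (rows) and predicted (columns) latents.
    corr = np.abs(np.corrcoef(Z_true.T, Z_pred.T)[:n, n:])
    # Exhaustive assignment search; swapped in for Hungarian matching.
    best = max(
        sum(corr[i, p[i]] for i in range(n))
        for p in permutations(range(m), n)
    )
    return best / n

rng = np.random.default_rng(0)
Z_true = rng.normal(size=(200, 4))
# Predicted latents: a permuted, sign-flipped, rescaled copy of the truth.
Z_pred = -2.0 * Z_true[:, [2, 0, 3, 1]]
score = mcc_sketch(Z_true, Z_pred)
```

Because the metric is invariant to permutation, sign, and scale under these assumptions, the score for this example is 1 up to floating-point error.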
match_columns(D, A) -> dict
Match learned dictionary D to ground-truth A by cosine similarity (Hungarian algorithm). Returns matched indices, angular errors, and summary statistics.
compute_support_metrics(Z_true, codes, row_ind, col_ind, threshold=1e-4) -> dict
Sparsity pattern recovery: precision, recall, F1, L0.
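A minimal sketch of how these quantities could be computed is shown below. It is not the library's implementation: the function name, the returned key names, and the reading of row_ind and col_ind as matched (true latent, code column) index pairs are assumptions.

```python
import numpy as np

def support_metrics_sketch(Z_true, codes, row_ind, col_ind, threshold=1e-4):
    """Compare the active sets of matched true latents and learned codes."""
    # An entry counts as active when its magnitude exceeds the threshold.
    true_on = np.abs(Z_true[:, row_ind]) > threshold
    pred_on = np.abs(codes[:, col_ind]) > threshold
    tp = int(np.sum(true_on & pred_on))
    fp = int(np.sum(~true_on & pred_on))
    fn = int(np.sum(true_on & ~pred_on))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # L0: average number of active code entries per sample (all columns).
    l0 = float(np.mean(np.sum(np.abs(codes) > threshold, axis=1)))
    return {"precision": precision, "recall": recall, "f1": f1, "l0": l0}

# Toy example: true latent 0 is matched to code column 1, and so on.
Z_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 2.0, 0.0]])
codes = np.array([[0.0, 0.5, 0.0],
                  [0.3, 0.0, 0.1]])
metrics = support_metrics_sketch(Z_true, codes, row_ind=[0, 1, 2], col_ind=[1, 0, 2])
```

In this toy example both true activations are recovered and one spurious code entry is active, so recall is 1 and precision is 2/3.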
evaluate_accuracy(codes_iid, labels_iid, codes_ood, labels_ood) -> dict
Fit logistic regression on the ID split and evaluate it on both splits. Returns {acc_id, acc_ood}.
evaluate_auc(codes_iid, labels_iid, codes_ood, labels_ood) -> dict
Per-feature ROC-AUC: select the best feature on the ID split, then report its AUC on both splits. Returns {auc_id, auc_ood}.
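The selection step can be sketched with NumPy alone. This is an illustration under assumptions, not the library's evaluate_auc: labels are taken to be binary, the helper names are hypothetical, and the AUC is computed through the rank-sum identity with average ranks for ties (sparse codes contain many tied zeros). Whether the library also considers features with AUC below 0.5 is not stated, so the sketch simply takes the maximum.

```python
import numpy as np

def roc_auc(scores, labels):
    """ROC-AUC via the rank-sum (Mann-Whitney U) identity."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    order = np.argsort(scores, kind="mergesort")
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)
    # Tied scores share the average of their ranks.
    for value in np.unique(scores):
        tie = scores == value
        ranks[tie] = ranks[tie].mean()
    n_pos, n_neg = labels.sum(), (~labels).sum()
    return (ranks[labels].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

def best_feature_auc(codes_iid, labels_iid, codes_ood, labels_ood):
    """Pick the feature with the highest ID AUC, then score it on OOD."""
    aucs = np.array(
        [roc_auc(codes_iid[:, j], labels_iid) for j in range(codes_iid.shape[1])]
    )
    j = int(np.argmax(aucs))
    return {
        "auc_id": float(aucs[j]),
        "auc_ood": float(roc_auc(codes_ood[:, j], labels_ood)),
    }

# Toy data: feature 0 separates the ID labels, feature 1 is constant.
labels = np.array([0, 0, 1, 1])
codes_iid = np.array([[0.1, 0.5], [0.2, 0.5], [0.8, 0.5], [0.9, 0.5]])
codes_ood = np.array([[0.9, 0.0], [0.1, 0.0], [0.2, 0.0], [0.8, 0.0]])
result = best_feature_auc(codes_iid, labels, codes_ood, labels)
```

The feature is chosen using ID data only, so the OOD number measures how well that single feature transfers.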
evaluate_all(codes_iid, labels_iid, codes_ood, labels_ood, Z_true_iid=None, Z_true_ood=None) -> dict
Combined evaluation: accuracy + AUC + optional MCC.
Experiment Helpers — experiments._common
eval_and_tag(codes_iid, codes_ood, data, method, **extra) -> dict
Run evaluate_all and tag the result with method name and metadata.
run_all_saes(data, input_dim, width, k, num_latents, ...) -> list
Train and evaluate all SAE variants (ReLU, TopK, JumpReLU, MP).
run_sparse_coding_methods(data, A, input_dim, num_latents, ...) -> list
Train and evaluate FISTA oracle, DL-FISTA, and Softplus-Adam.
run_linear_baselines(data, k, tag) -> list
Evaluate the supervised linear probe baseline.
save_incremental(all_results, out_path)
Save the results list to JSON, creating the output directory if it does not exist.
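The behaviour described for save_incremental can be approximated with the standard library alone. The sketch below is an assumption about how such a helper might work, not the library's code: the name save_incremental_sketch is hypothetical and the result values are placeholders.

```python
import json
import tempfile
from pathlib import Path

def save_incremental_sketch(all_results, out_path):
    """Write the full results list to out_path, creating parent directories."""
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Rewrite the whole file on every call so it always holds the latest list.
    with out_path.open("w") as f:
        json.dump(all_results, f, indent=2)

# Placeholder results, written into a directory that does not exist yet.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "nested" / "results.json"
    save_incremental_sketch([{"method": "topk", "acc_id": 0.9}], target)
    loaded = json.loads(target.read_text())
```

Calling a helper like this after each method finishes keeps partial results on disk if a long run is interrupted. NumPy arrays are not JSON-serialisable, so they would need converting (for example with .tolist()) before saving.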