diff --git a/config.toml b/config.toml
index ca947a9eb..422fdbd87 100644
--- a/config.toml
+++ b/config.toml
@@ -1,7 +1,47 @@
 batch_size = 64
-learning_rate = 3e-4
-steps = 300
-print_every = 30
-seed = 5678
+steps = 500000
+print_every = 10000
+seed = 0
 cnn_storage = "./res/cnn.eqx"
-sae_storage = "./res/sae.eqx"
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 0.1
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-4
diff --git a/requirements.txt b/requirements.txt
index 400331ae0..06717fe76 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,25 +1,21 @@
-argtoml
+argtoml[save]
 build
 debugpy
 equinox
-jax>=0.4.14
+jax
 jaxtyping
+jo3mnist
+jo3util
 matplotlib
-nbclassic
-notebook
 optax
 pandas
 pyright
 scikit-learn
-tensorboard
-tensorboardX
+tensorflow
 torch
 torchvision
 tqdm
 twine
 typeguard
-git+file:///mnt/nas/git/jo3util
-git+https://github.com/JJJHolscher/jupytools
 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--f https://download.pytorch.org/whl/cu118
-git+file:///mnt/nas/git/mnist
+--extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/src/__main__.py b/src/__main__.py
index c38859731..e9772f5fc 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -1,30 +1,30 @@
 #! /usr/bin/env python3
 # vim:fenc=utf-8
 
-"""
-
-"""
+from pathlib import Path
 
 import argtoml
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import Float # https://github.com/google/jaxtyping
-from jaxtyping import Array, Int, PyTree
-from torch.utils.data import DataLoader
+import jo3mnist
+from jo3util import eqx as jo3eqx
+from jo3util.root import run_dir
+from tqdm import tqdm
 
 from .cnn import CNN, train_cnn
-from .sae import SAE, train_sae
+from .sae import SAE, sample_features, train_sae
 
 # Hyperparameters
 O = argtoml.parse_args()
+
 key = jax.random.PRNGKey(O.seed)
 key, subkey = jax.random.split(key)
 
-if O.cnn_storage.exists():
-    model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+if (Path(".") / O.cnn_storage).exists():
+    cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
 else:
-    model = train_cnn(
+    cnn = train_cnn(
         subkey,
         O.batch_size,
         O.learning_rate,
@@ -32,15 +32,59 @@ else:
         O.print_every,
         O.cnn_storage,
     )
+    eqx.tree_serialise_leaves(O.cnn_storage, cnn)
 
-sae = train_sae(
-    key,
-    model,
-    lambda m: m.layers[6],
-    64,
-    O.batch_size,
-    O.learning_rate,
-    O.steps,
-    O.print_every,
-    O.sae_storage,
-)
+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    if sae_dir.exists():
+        continue
+
+    sae = train_sae(
+        key,
+        cnn,
+        lambda m: m.layers[sae_hyperparams.layer],
+        sae_hyperparams.input_size,
+        sae_hyperparams.hidden_size,
+        O.batch_size,
+        sae_hyperparams.learning_rate,
+        O.steps,
+        O.print_every,
+    )
+
+    sae_dir.mkdir()
+    argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
+    argtoml.save(O, sae_dir / "config.toml")
+    jo3eqx.save(
+        sae_dir / f"sae.eqx",
+        sae,
+        {
+            "in_size": sae_hyperparams.input_size,
+            "hidden_size": sae_hyperparams.hidden_size,
+        },
+    )
+
+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
+    sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
+    trainloader, testloader = jo3mnist.load(
+        batch_size=O.batch_size, shuffle=False
+    )
+
+    train_dir = sae_dir / f"train"
+    if not train_dir.exists():
+        print("saving features from the training set")
+        train_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, trainloader), total=len(trainloader)
+        ):
+            jnp.save(train_dir / f"{i}.npy", features, allow_pickle=False)
+
+    test_dir = sae_dir / f"test"
+    if not test_dir.exists():
+        print("saving features from the test set")
+        test_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, testloader), total=len(testloader)
+        ):
+            jnp.save(test_dir / f"{i}.npy", features, allow_pickle=False)
diff --git a/src/cnn.py b/src/cnn.py
index fdf02d0e2..d161d085d 100644
--- a/src/cnn.py
+++ b/src/cnn.py
@@ -152,5 +152,4 @@ def train_cnn(
     model = train_loop(
         model, trainloader, testloader, optim, steps, print_every
     )
-    eqx.tree_serialise_leaves(model_storage, model)
     return model
diff --git a/src/sae.py b/src/sae.py
index a336d26cc..7afc9190e 100644
--- a/src/sae.py
+++ b/src/sae.py
@@ -58,6 +58,13 @@ class SAE(eqx.Module):
         return out
 
 
+def sample_features(cnn, sae, loader):
+    for i, (x, _) in enumerate(loader):
+        x = x.numpy()
+        activ = jax.vmap(cnn)(x)[0]
+        yield i, sae.encode(activ)
+
+
 def evaluate(model: eqx.Module, testloader: DataLoader):
     """This function evaluates the model on the test dataset, computing both
     the average loss and the average accuracy.
@@ -68,11 +75,27 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
         y = y.numpy()
         # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
         # and both have JIT wrappers, so this is fast.
-        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        pred_y = jnp.argmax(jax.vmap(model)(x)[-1], axis=1)
         avg_acc += jnp.mean(y == pred_y)
     return avg_acc / len(testloader)
 
 
+@eqx.filter_jit
+def make_step(
+    sae: SAE,
+    model: eqx.Module,
+    optim,
+    opt_state: PyTree,
+    x: Float[Array, "batch 1 28 28"],
+    y: Int[Array, " batch"],
+):
+    activ = jax.vmap(model)(x)[0]
+    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    updates, opt_state = optim.update(grads, opt_state, sae)
+    sae = eqx.apply_updates(sae, updates)
+    return sae, opt_state, loss_value
+
+
 def train_loop(
     sae: SAE,
     model: eqx.Module,
@@ -89,23 +112,7 @@ def train_loop(
 
     print(f"test_accuracy={evaluate(model, testloader).item()}")
 
-    # Always wrap everything -- computing gradients, running the optimiser, updating
-    # the model -- into a single JIT region. This ensures things run as fast as
-    # possible.
-    @eqx.filter_jit
-    def make_step(
-        sae: SAE,
-        model: eqx.Module,
-        opt_state: PyTree,
-        x: Float[Array, "batch 1 28 28"],
-        y: Int[Array, " batch"],
-    ):
-        _, activ = jax.vmap(model)(x)
-        loss_value, grads = SAE.loss(sae, activ[0], 1)
-        updates, opt_state = optim.update(grads, opt_state, sae)
-        sae = eqx.apply_updates(sae, updates)
-        return sae, opt_state, loss_value
-
+
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
@@ -116,7 +123,7 @@
            # so convert them to NumPy arrays.
             x = x.numpy()
             y = y.numpy()
-            sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
+            sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
             if (step % print_every) == 0 or (step == steps - 1):
                 model_with_sae = insert_after(sae_pos, model, sae)
                 test_accuracy = evaluate(model_with_sae, testloader)
@@ -132,16 +139,16 @@ def train_sae(
     cnn,
     sae_pos,
     activ_size,
+    hidden_size,
     batch_size,
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
-    sae = SAE(activ_size, 1000, key)
+    sae = SAE(activ_size, hidden_size, key)
     optim = optax.adamw(learning_rate)
-    sae = train_loop(
+    return train_loop(
         sae,
         cnn,
         sae_pos,
@@ -151,5 +158,3 @@
         steps,
         print_every,
     )
-    eqx.tree_serialise_leaves(sae_storage, sae)
-    return sae