diff --git a/src/__main__.py b/src/__main__.py
index 0eef5d17e..e800cb448 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -38,8 +38,8 @@ else:
     sae = train_sae(
         key,
         model,
-        lambda m: m.layers[4],
-        512,
+        lambda m: m.layers[6],
+        64,
         O.batch_size,
         O.learning_rate,
         O.steps,
diff --git a/src/cnn.py b/src/cnn.py
index 11d272d74..4aa389cd5 100644
--- a/src/cnn.py
+++ b/src/cnn.py
@@ -149,6 +149,8 @@ def train_cnn(
     trainloader, testloader = mnist.load(batch_size=batch_size)
     model = CNN(key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
+    model = train_loop(
+        model, trainloader, testloader, optim, steps, print_every
+    )
     eqx.tree_serialise_leaves(model_storage, model)
     return model
diff --git a/src/sae.py b/src/sae.py
index d2b99303c..d10a3e5fd 100644
--- a/src/sae.py
+++ b/src/sae.py
@@ -5,13 +5,15 @@
 """

+from typing import Callable, List
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
-from jo3util.eqx import sow
+from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
@@ -33,8 +35,13 @@ class SAE(eqx.Module):
         # decader bias
         self.bd = jax.random.uniform(k3, (in_size,))

+    def __call__(self, x):
+        x = self.encode(x)
+        x = self.decode(x)
+        return x
+
     def encode(self, x):
-        x = (x - self.bd) @ self.we + self.be
+        x = ((x - self.bd) @ self.we) + self.be
         return jax.nn.relu(x)

     def decode(self, fx):
@@ -45,23 +52,42 @@ class SAE(eqx.Module):
 def loss(sae, x, λ):
     fx = jax.vmap(sae.encode)(x)
     x_ = jax.vmap(sae.decode)(fx)
-    sq_err = jnp.dot((x - x_), (x - x_))
-    l1 = λ * jnp.dot(fx, fx)
-    return jnp.mean(sq_err + l1)
+    sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
+    l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
+    out = jnp.mean(sq_err + l1)
+    return out
+
+
+def evaluate(model: eqx.Module, testloader: DataLoader):
+    """This function evaluates the model on the test dataset,
+    computing the average accuracy over all test batches.
+    """
+    avg_acc = 0
+    for x, y in testloader:
+        x = x.numpy()
+        y = y.numpy()
+        # The forward pass is vmapped over the whole batch, so one call
+        # produces predictions for every example in it.
+        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        avg_acc += jnp.mean(y == pred_y)
+    return avg_acc / len(testloader)


 def train_loop(
     sae: SAE,
     model: eqx.Module,
+    sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
 ) -> eqx.Module:
-    # Just like earlier: It only makes sense to train the arrays in our model,
-    # so filter out everything else.
-    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
+
+    model = sow(sae_pos, model)
+
+    print(f"test_accuracy={evaluate(model, testloader).item()}")

     # Always wrap everything -- computing gradients, running the optimiser, updating
     # the model -- into a single JIT region. This ensures things run as fast as
@@ -92,12 +118,13 @@ def train_loop(
         y = y.numpy()
         sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
-            test_loss, test_accuracy = evaluate(sae, model, testloader)
+            model_with_sae = insert_after(sae_pos, model, sae)
+            test_accuracy = evaluate(model_with_sae, testloader)
             print(
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}"
             )
-    return model
+    return sae


 def train_sae(
@@ -109,15 +136,20 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae.eqx",
+    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = mnist.load(batch_size=batch_size)
-    model = sow(sae_pos, cnn)
-    print(model)
     sae = SAE(activ_size, 1000, key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(
-        sae, model, trainloader, testloader, optim, steps, print_every
+    sae = train_loop(
+        sae,
+        cnn,
+        sae_pos,
+        trainloader,
+        testloader,
+        optim,
+        steps,
+        print_every,
     )
-    # eqx.tree_serialise_leaves(model_storage, model)
-    return model
+    eqx.tree_serialise_leaves(sae_storage, sae)
+    return sae
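
Note on the `loss` change: applied to the full batched arrays, `jnp.dot` attempts a matrix product rather than a per-example squared error, so the reconstruction and penalty terms are now computed per sample with `jax.vmap` and averaged afterwards. A minimal standalone sketch of that pattern (shapes and values below are made up for illustration; only `jax` is assumed):

import jax
import jax.numpy as jnp

x = jnp.ones((8, 64))        # pretend batch of 8 activations
x_hat = 0.9 * x              # pretend reconstructions
diff = x - x_hat

# Row-wise dot products: one squared reconstruction error per sample, shape (8,).
per_sample_sq_err = jax.vmap(jnp.dot, (0, 0))(diff, diff)

# Sanity check: identical to summing the squared entries of each row.
assert bool(jnp.allclose(per_sample_sq_err, jnp.sum(diff * diff, axis=1)))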
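
The new `SAE.__call__` makes the autoencoder usable as a plain activation-to-activation map, which is what `insert_after` splices into the CNN for evaluation. A rough usage sketch, not part of the diff: it assumes `SAE` can be imported from `src/sae.py` and that the constructor takes an input size, a hidden size and a PRNG key, matching the `SAE(activ_size, 1000, key)` call in `train_sae`; the key and the zero activation are placeholders.

import jax
import jax.numpy as jnp

from sae import SAE  # adjust the import path to match the project layout

key = jax.random.PRNGKey(0)
sae = SAE(64, 1000, key)   # 64 matches the activation size set in src/__main__.py

act = jnp.zeros((64,))     # placeholder activation from the hooked layer
code = sae.encode(act)     # non-negative hidden code, shape (1000,)
recon = sae(act)           # encode then decode, back to shape (64,)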