From a3bad799e38a63e2aaffe58bc703dd1fbf4fd3ec Mon Sep 17 00:00:00 2001
From: JJJHolscher
Date: Sat, 23 Dec 2023 15:32:18 +0100
Subject: [PATCH] sae werks with deep loss

---
 config.toml     |  29 ++----
 src/__main__.py |  34 ++++++++------
 src/sae.py      | 120 +++++++++++++++++++++++++++++++++---------
 src/temp.py     |  37 +++++++++++++++
 4 files changed, 143 insertions(+), 77 deletions(-)
 create mode 100644 src/temp.py

diff --git a/config.toml b/config.toml
index 149fb8e79..7158aeb8a 100644
--- a/config.toml
+++ b/config.toml
@@ -1,33 +1,12 @@
 batch_size = 64
-steps = 1000000
-print_every = 50000
-seed = 0
+steps = 10000
+print_every = 500
+seed = 1
 cnn_storage = "./res/cnn.eqx"
 
 [[sae]]
 layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 1e-3
-l1 = 3e-4  # from Neel Nanda's sae git
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-4
-l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
+hidden_size = 256
 input_size = 64
 learning_rate = 1e-4
 l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-5
-l1 = 3e-4
diff --git a/src/__main__.py b/src/__main__.py
index 97ec07146..0a409eb34 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -39,21 +39,25 @@
 for sae_hyperparams in O.sae:
     sae_dir = run_dir(sae_hyperparams)
     if sae_dir.exists():
         continue
-
-    sae = train_sae(
-        key,
-        cnn,
-        lambda m: m.layers[sae_hyperparams.layer],
-        sae_hyperparams.input_size,
-        sae_hyperparams.hidden_size,
-        O.batch_size,
-        sae_hyperparams.learning_rate,
-        O.steps,
-        O.print_every,
-        SummaryWriter(sae_dir / "log")
-    )
-    sae_dir.mkdir()
+
+    with SummaryWriter(sae_dir / "log") as tensorboard:
+        tensorboard.add_text("hyperparameters", str(sae_hyperparams))
+
+        sae = train_sae(
+            key,
+            cnn,
+            sae_hyperparams.layer,
+            sae_hyperparams.input_size,
+            sae_hyperparams.hidden_size,
+            O.batch_size,
+            sae_hyperparams.learning_rate,
+            O.steps,
+            O.print_every,
+            tensorboard,
+            sae_hyperparams.l1,
+        )
+
     argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
     argtoml.save(O, sae_dir / "config.toml")
     jo3eqx.save(
@@ -70,7 +74,7 @@ for sae_hyperparams in O.sae:
     sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
     sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
     trainloader, testloader = jo3mnist.load(
-        batch_size=O.batch_size, shuffle=False
+        batch_size=1, shuffle=False
     )
 
     train_dir = sae_dir / f"train"
diff --git a/src/sae.py b/src/sae.py
index 2b8585fa5..293c7ef34 100644
--- a/src/sae.py
+++ b/src/sae.py
@@ -11,6 +11,7 @@ from typing import Callable, List
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 import jo3mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
@@ -18,6 +19,12 @@ from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
+from .cnn import cross_entropy
+
+
+def filter_value_and_grad_with_aux(f):
+    return eqx.filter_value_and_grad(f, has_aux=True)
+
 
 class SAE(eqx.Module):
     we: Float
@@ -27,15 +34,16 @@
 
     def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
         k0, k1, k2, k3 = jax.random.split(key, 4)
+        initializer = jax.nn.initializers.he_uniform()
 
         # encoder weight matrix
-        self.we = jax.random.uniform(k0, (in_size, hidden_size))
+        self.we = initializer(k0, (in_size, hidden_size))
         # decoder weight matrix
-        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
+        self.wd = initializer(k1, (hidden_size, in_size))
         # encoder bias
-        self.be = jax.random.uniform(k2, (hidden_size,))
+        self.be = jnp.zeros((hidden_size,))
         # decoder bias
-        self.bd = jax.random.uniform(k3, (in_size,))
+        self.bd = jnp.zeros((in_size,))
 
     def __call__(self, x):
         x = self.encode(x)
@@ -49,22 +57,32 @@
     def decode(self, fx):
         return fx @ self.wd + self.bd
 
+    def l1(self, x):
+        x = self.encode(x)
+        return jax.vmap(jnp.dot, (0, 0))(x, x)
+
     @staticmethod
-    @eqx.filter_value_and_grad
-    def loss(sae, x, λ):
-        fx = jax.vmap(sae.encode)(x)
-        x_ = jax.vmap(sae.decode)(fx)
-        sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
-        l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
-        out = jnp.mean(sq_err + l1)
-        return out
+    @filter_value_and_grad_with_aux
+    def loss(diff_model, static_model, sae_pos, x, y, λ):
+        model = eqx.combine(diff_model, static_model)
+        original_activ, reconstructed_activ, pred = jax.vmap(model)(x)
+
+        reconstruction_err = jnp.mean(jax.vmap(jnp.dot, (0, 0))(
+            (original_activ - reconstructed_activ),
+            (original_activ - reconstructed_activ)
+        ))
+        l1 = λ * jnp.mean(sae_pos(model).l1(original_activ))
+        deep_err = jnp.mean(cross_entropy(y, pred))
+
+        loss = reconstruction_err + l1 + deep_err
+        return loss, (reconstruction_err, l1, deep_err)
 
 
 def sample_features(cnn, sae, loader):
     for i, (x, _) in enumerate(loader):
         x = x.numpy()
         activ = jax.vmap(cnn)(x)[0]
-        yield i, sae.encode(activ)
+        yield i, sae.encode(activ)[0]
 
 
 def evaluate(model: eqx.Module, testloader: DataLoader):
@@ -84,23 +102,25 @@
 
 @eqx.filter_jit
 def make_step(
-    sae: SAE,
     model: eqx.Module,
+    freeze_spec: PyTree,
+    sae_pos: Callable,
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
+    y: Float[Array, "batch"],
     λ: float,
 ):
-    activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], λ)
-    updates, opt_state = optim.update(grads, opt_state, sae)
-    sae = eqx.apply_updates(sae, updates)
-    return sae, opt_state, loss_value
+    diff_model, static_model = eqx.partition(model, freeze_spec)
+    (loss, aux), grads = SAE.loss(diff_model, static_model, sae_pos, x, y, λ)
+    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, freeze_spec))
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state, loss, *aux
 
 
 def train_loop(
-    sae: SAE,
-    cnn: eqx.Module,
+    model: eqx.Module,
+    freeze_spec: PyTree,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
@@ -110,34 +130,57 @@
     tensorboard,
     λ,
 ) -> eqx.Module:
-    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-
-    cnn = sow(sae_pos, cnn)
-
-    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
+    opt_state = optim.init(eqx.filter(model, freeze_spec))
 
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, _) in zip(range(steps), infinite_trainloader()):
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
-        sae, opt_state, train_loss = make_step(
-            sae, cnn, optim, opt_state, x.numpy(), λ
+        model, opt_state, loss, reconstruction_err, l1, deep_err = make_step(
+            model,
+            freeze_spec,
+            sae_pos,
+            optim,
+            opt_state,
+            x.numpy(),
+            y.numpy(),
+            λ
         )
-        tensorboard.add_scalar("loss", train_loss.item(), step)
+        tensorboard.add_scalar("loss", loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            cnn_with_sae = insert_after(sae_pos, cnn, sae)
-            test_accuracy = evaluate(cnn_with_sae, testloader)
+            test_accuracy = evaluate(model, testloader)
             print(
                 datetime.now().strftime("%H:%M"),
-                f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}",
+                step,
+                f"{loss=:.3f}",
+                f"rec={reconstruction_err.item():.3f}",
+                f"{l1=:.3f}",
+                f"{deep_err=:.3f}",
+                f"{test_accuracy=:.3f}",
             )
             tensorboard.add_scalar("accu", test_accuracy.item(), step)
-    return sae
+    return sae_pos(model)
+
+
+def compose_model(cnn, sae, layer):
+    sae_pos = lambda m: m.layers[layer]
+    model = sow(sae_pos, cnn)
+    model = insert_after(sae_pos, model, sae)
+    model = sow(sae_pos, model)
+
+    sae_pos = lambda m: m.layers[layer].children[1]
+    freeze_spec = jtu.tree_map(lambda _: False, model)
+    freeze_spec = eqx.tree_at(
+        sae_pos,
+        freeze_spec,
+        replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+    )
+
+    return model, freeze_spec, sae_pos
 
 
 def train_sae(
@@ -154,11 +197,14 @@
     λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
+    # print(f"test_accuracy={evaluate(cnn, testloader).item()}")
+
     sae = SAE(activ_size, hidden_size, key)
+    model, freeze_spec, sae_pos = compose_model(cnn, sae, sae_pos)
     optim = optax.adamw(learning_rate)
     return train_loop(
-        sae,
-        cnn,
+        model,
+        freeze_spec,
         sae_pos,
         trainloader,
         testloader,
diff --git a/src/temp.py b/src/temp.py
new file mode 100644
index 000000000..c2263adb8
--- /dev/null
+++ b/src/temp.py
@@ -0,0 +1,37 @@
+#! /usr/bin/env python3
+# vim:fenc=utf-8
+
+import argtoml
+import equinox as eqx
+import jax
+import jax.tree_util as jtu
+from jo3util.eqx import insert_after, sow
+
+from cnn import CNN
+from sae import SAE
+
+
+O = argtoml.parse_args()
+
+key = jax.random.PRNGKey(O.seed)
+key, subkey = jax.random.split(key)
+
+cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+
+sae_hyparam = O.sae[0]
+sae_pos = lambda m: m.layers[sae_hyparam.layer]
+
+sae = SAE(sae_hyparam.input_size, sae_hyparam.hidden_size, key)
+
+model = sow(sae_pos, cnn)
+model = insert_after(sae_pos, model, sae)
+model = sow(sae_pos, model)
+
+freeze_spec = jtu.tree_map(lambda _: False, model)
+freeze_spec = eqx.tree_at(
+    lambda m: m.layers[sae_hyparam.layer].children[1],
+    freeze_spec,
+    replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+)
+# print(model)
+print(freeze_spec)
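
Note on the pattern this patch introduces: compose_model builds a boolean
filter spec over the combined CNN+SAE tree, SAE.loss recombines the
eqx.partition'ed halves so gradients only flow into the SAE, and the
optimizer is initialised from the same spec. This is Equinox's standard
frozen-parameter recipe. Below is a minimal self-contained sketch of that
recipe on a stand-in two-layer MLP; the layer sizes, names, and the adamw
learning rate are illustrative assumptions, not values from this repo.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import optax

k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
# Stand-in model; the patch uses the repo's CNN with an SAE inserted instead.
model = eqx.nn.Sequential([
    eqx.nn.Linear(4, 8, key=k0),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(8, 2, key=k1),
])

# Freeze everything except the weight and bias of the last Linear layer.
freeze_spec = jtu.tree_map(lambda _: False, model)
freeze_spec = eqx.tree_at(
    lambda m: (m.layers[2].weight, m.layers[2].bias),
    freeze_spec,
    replace=(True, True),
)

@eqx.filter_value_and_grad
def loss_fn(diff_model, static_model, x, y):
    # Recombine so the forward pass sees the full model, while gradients
    # are taken only with respect to diff_model's array leaves.
    m = eqx.combine(diff_model, static_model)
    pred = jax.vmap(m)(x)
    return jnp.mean((pred - y) ** 2)

optim = optax.adamw(1e-3)
opt_state = optim.init(eqx.filter(model, freeze_spec))

x = jax.random.normal(k2, (16, 4))
y = jnp.zeros((16, 2))

# One training step: only the unfrozen leaves receive non-None gradients,
# so adamw keeps state for those leaves alone and apply_updates leaves the
# rest of the model untouched.
diff_model, static_model = eqx.partition(model, freeze_spec)
loss, grads = loss_fn(diff_model, static_model, x, y)
updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, freeze_spec))
model = eqx.apply_updates(model, updates)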