diff --git a/config.toml b/config.toml
index 422fdbd87..149fb8e79 100644
--- a/config.toml
+++ b/config.toml
@@ -1,47 +1,33 @@
 batch_size = 64
-steps = 500000
-print_every = 10000
+steps = 1000000
+print_every = 50000
 seed = 0
 cnn_storage = "./res/cnn.eqx"
 
 [[sae]]
 layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 0.1
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 1e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-3
-
-[[sae]]
-layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-3
+l1 = 3e-4  # L1 coefficient, taken from Neel Nanda's SAE repo
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 3e-4
+l1 = 3e-4
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-4
+l1 = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 300
+input_size = 64
+learning_rate = 3e-5
+l1 = 3e-4
diff --git a/src/__main__.py b/src/__main__.py
index e9772f5fc..97ec07146 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -10,6 +10,7 @@ import jax.numpy as jnp
 import jo3mnist
 from jo3util import eqx as jo3eqx
 from jo3util.root import run_dir
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from .cnn import CNN, train_cnn
@@ -49,6 +50,9 @@ for sae_hyperparams in O.sae:
         sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
+        SummaryWriter(sae_dir / "log"),
+        sae_hyperparams.l1,
     )
 
-    sae_dir.mkdir()
+    # The SummaryWriter already created sae_dir/log, so tolerate an existing dir.
+    sae_dir.mkdir(exist_ok=True)
diff --git a/src/sae.py b/src/sae.py
index 7afc9190e..2b8585fa5 100644
--- a/src/sae.py
+++ b/src/sae.py
@@ -5,6 +5,7 @@
 
 """
 
+from datetime import datetime
 from typing import Callable, List
 
 import equinox as eqx
@@ -15,6 +16,7 @@ import optax
 from jaxtyping import Array, Float, Int, PyTree
 from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
 
 
 class SAE(eqx.Module):
@@ -87,10 +89,10 @@ def make_step(
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
-    y: Int[Array, " batch"],
+    λ: float,
 ):
     activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    loss_value, grads = SAE.loss(sae, activ[0], λ)
     updates, opt_state = optim.update(grads, opt_state, sae)
     sae = eqx.apply_updates(sae, updates)
     return sae, opt_state, loss_value
@@ -98,39 +100,43 @@ def make_step(
 
 def train_loop(
     sae: SAE,
-    model: eqx.Module,
+    cnn: eqx.Module,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
+    tensorboard: SummaryWriter,
+    λ: float,
 ) -> eqx.Module:
     opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-    model = sow(sae_pos, model)
+    cnn = sow(sae_pos, cnn)
 
-    print(f"test_accuracy={evaluate(model, testloader).item()}")
+    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
 
-    # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+    for step, (x, _) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        x = x.numpy()
-        y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(
+            sae, cnn, optim, opt_state, x.numpy(), λ
+        )
+        tensorboard.add_scalar("loss", train_loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            model_with_sae = insert_after(sae_pos, model, sae)
-            test_accuracy = evaluate(model_with_sae, testloader)
+            cnn_with_sae = insert_after(sae_pos, cnn, sae)
+            test_accuracy = evaluate(cnn_with_sae, testloader)
             print(
+                datetime.now().strftime("%H:%M"),
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}",
             )
+            tensorboard.add_scalar("accu", test_accuracy.item(), step)
     return sae
@@ -144,6 +150,8 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
+    tensorboard,
+    λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
     sae = SAE(activ_size, hidden_size, key)
@@ -157,4 +165,6 @@ def train_sae(
         optim,
         steps,
         print_every,
+        tensorboard,
+        λ,
     )
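Note on λ: this patch threads the new `l1` coefficient from config.toml through train_sae and make_step into SAE.loss, but the body of SAE.loss is outside the diff. The following is only a minimal sketch of the kind of objective such a coefficient usually weights, assuming the SAE maps one activation vector to a (reconstruction, hidden code) pair; the function name and interface here are illustrative, not the repo's actual SAE.loss.

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    @eqx.filter_value_and_grad
    def l1_sae_loss(sae, activ, λ):
        # Hypothetical interface: the SAE returns the reconstruction and
        # the sparse hidden code for each activation vector in the batch.
        recon, hidden = jax.vmap(sae)(activ)
        mse = jnp.mean((recon - activ) ** 2)        # reconstruction term
        sparsity = λ * jnp.mean(jnp.abs(hidden))    # L1 penalty on the code
        return mse + sparsity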
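With the "loss" and "accu" scalars now written per run, training can be watched live with

    tensorboard --logdir res

assuming the run directories produced by jo3util's run_dir land under ./res (as cnn_storage suggests); point --logdir at the actual run root otherwise.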