sae works with deep loss

JJJHolscher 2023-12-23 15:32:18 +01:00
parent ecd7d3bd65
commit a3bad799e3
4 changed files with 143 additions and 77 deletions

View File

```diff
@@ -1,33 +1,12 @@
 batch_size = 64
-steps = 1000000
-print_every = 50000
-seed = 0
+steps = 10000
+print_every = 500
+seed = 1
 cnn_storage = "./res/cnn.eqx"
 
 [[sae]]
 layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 1e-3
-l1 = 3e-4 # from Neel Nanda's sae git
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-4
-l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
+hidden_size = 256
 input_size = 64
 learning_rate = 1e-4
 l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-5
-l1 = 3e-4
```
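For reference, `[[sae]]` is TOML's array-of-tables syntax: the old config swept four learning rates (1e-3 down to 3e-5), while the new one keeps a single entry, so the training script's `for sae_hyperparams in O.sae` loop below now runs once. A minimal sketch of the parsed shape using only the standard library; `argtoml` presumably exposes the same structure with attribute access and CLI overrides on top:

```python
# Sketch: what a [[sae]] array-of-tables parses to. This uses stdlib tomllib
# (Python 3.11+) purely to illustrate the data shape, not argtoml itself.
import tomllib

config = tomllib.loads("""
batch_size = 64
steps = 10000

[[sae]]
layer = 6
hidden_size = 256
input_size = 64
learning_rate = 1e-4
l1 = 3e-4
""")

# config["sae"] is a list with one dict per [[sae]] block.
for sae_hyperparams in config["sae"]:
    print(sae_hyperparams["layer"], sae_hyperparams["learning_rate"])
```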

View File

```diff
@@ -39,21 +39,25 @@ for sae_hyperparams in O.sae:
     sae_dir = run_dir(sae_hyperparams)
     if sae_dir.exists():
         continue
+    sae_dir.mkdir()
+    with SummaryWriter(sae_dir / "log") as tensorboard:
+        tensorboard.add_text("hyperparameters", str(sae_hyperparams))
     sae = train_sae(
         key,
         cnn,
-        lambda m: m.layers[sae_hyperparams.layer],
+        sae_hyperparams.layer,
         sae_hyperparams.input_size,
         sae_hyperparams.hidden_size,
         O.batch_size,
         sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
-        SummaryWriter(sae_dir / "log")
+        tensorboard,
+        sae_hyperparams.l1,
     )
-    sae_dir.mkdir()
     argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
     argtoml.save(O, sae_dir / "config.toml")
     jo3eqx.save(
@@ -70,7 +74,7 @@ for sae_hyperparams in O.sae:
     sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
     sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
     trainloader, testloader = jo3mnist.load(
-        batch_size=O.batch_size, shuffle=False
+        batch_size=1, shuffle=False
     )
     train_dir = sae_dir / f"train"
```
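In the new version the `SummaryWriter` is created once, up front, and handed down into `train_sae`, and the `with` block guarantees pending events are flushed and the writer closed even if training raises. A minimal sketch of the pattern (the log directory is illustrative):

```python
# Sketch: logging hyperparameter text and a scalar with TensorBoard's
# SummaryWriter used as a context manager. "runs/example" is illustrative.
from torch.utils.tensorboard import SummaryWriter

with SummaryWriter("runs/example") as tensorboard:
    tensorboard.add_text("hyperparameters", "learning_rate=1e-4, l1=3e-4")
    for step in range(3):
        tensorboard.add_scalar("loss", 1.0 / (step + 1), step)
# Exiting the block closes the writer and flushes its event file.
```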

View File

```diff
@@ -11,6 +11,7 @@ from typing import Callable, List
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 import jo3mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
@@ -18,6 +19,12 @@ from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
+from .cnn import cross_entropy
+
+
+def filter_value_and_grad_with_aux(f):
+    return eqx.filter_value_and_grad(f, has_aux=True)
+
 
 class SAE(eqx.Module):
     we: Float
```
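`filter_value_and_grad_with_aux` exists because the rewritten loss returns diagnostics next to the scalar being differentiated; `has_aux=True` tells Equinox that only the first element of the returned tuple is the loss. A self-contained sketch of that convention (the one-field model and the numbers are illustrative):

```python
# Sketch of the has_aux convention: the first return value is differentiated,
# the tuple after it is passed through as auxiliary output.
import equinox as eqx
import jax.numpy as jnp


def filter_value_and_grad_with_aux(f):
    return eqx.filter_value_and_grad(f, has_aux=True)


class Scale(eqx.Module):
    w: jnp.ndarray


@filter_value_and_grad_with_aux
def loss(model, x):
    err = jnp.mean((model.w * x - 1.0) ** 2)
    reg = 1e-3 * jnp.sum(model.w ** 2)
    return err + reg, (err, reg)


model = Scale(jnp.ones(3))
(total, (err, reg)), grads = loss(model, jnp.arange(3.0))
print(total, err, reg, grads.w)
```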
```diff
@@ -27,15 +34,16 @@ class SAE(eqx.Module):
     def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
         k0, k1, k2, k3 = jax.random.split(key, 4)
+        initializer = jax.nn.initializers.he_uniform()
         # encoder weight matrix
-        self.we = jax.random.uniform(k0, (in_size, hidden_size))
+        self.we = initializer(k0, (in_size, hidden_size))
         # decoder weight matrix
-        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
+        self.wd = initializer(k1, (hidden_size, in_size))
         # encoder bias
-        self.be = jax.random.uniform(k2, (hidden_size,))
+        self.be = jnp.zeros((hidden_size,))
         # decoder bias
-        self.bd = jax.random.uniform(k3, (in_size,))
+        self.bd = jnp.zeros((in_size,))
 
     def __call__(self, x):
         x = self.encode(x)
```
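The `he_uniform` initializer draws zero-centred weights whose spread scales with the input dimension, and the biases now start at zero; the previous `jax.random.uniform` draws were all positive in [0, 1), an unusual starting point for both weights and biases. A standalone sketch of the difference:

```python
# Sketch: old vs. new initialisation for an encoder matrix of shape (64, 256).
import jax

key = jax.random.PRNGKey(0)

old_we = jax.random.uniform(key, (64, 256))                # in [0, 1), mean ~0.5
new_we = jax.nn.initializers.he_uniform()(key, (64, 256))  # zero mean, fan_in-scaled

print(old_we.mean(), old_we.std())  # roughly 0.5 and 0.29
print(new_we.mean(), new_we.std())  # roughly 0.0 and sqrt(2/64) ~ 0.18
```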
```diff
@@ -49,22 +57,32 @@ class SAE(eqx.Module):
     def decode(self, fx):
         return fx @ self.wd + self.bd
 
+    def l1(self, x):
+        x = self.encode(x)
+        return jax.vmap(jnp.dot, (0, 0))(x, x)
+
     @staticmethod
-    @eqx.filter_value_and_grad
-    def loss(sae, x, λ):
-        fx = jax.vmap(sae.encode)(x)
-        x_ = jax.vmap(sae.decode)(fx)
-        sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
-        l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
-        out = jnp.mean(sq_err + l1)
-        return out
+    @filter_value_and_grad_with_aux
+    def loss(diff_model, static_model, sae_pos, x, y, λ):
+        model = eqx.combine(diff_model, static_model)
+        original_activ, reconstructed_activ, pred = jax.vmap(model)(x)
+        reconstruction_err = jnp.mean(jax.vmap(jnp.dot, (0, 0))(
+            (original_activ - reconstructed_activ),
+            (original_activ - reconstructed_activ)
+        ))
+        l1 = λ * jnp.mean(sae_pos(model).l1(original_activ))
+        deep_err = jnp.mean(cross_entropy(y, pred))
+        loss = reconstruction_err + l1 + deep_err
+        return loss, (reconstruction_err, l1, deep_err)
 
 
 def sample_features(cnn, sae, loader):
     for i, (x, _) in enumerate(loader):
         x = x.numpy()
         activ = jax.vmap(cnn)(x)[0]
-        yield i, sae.encode(activ)
+        yield i, sae.encode(activ)[0]
 
 
 def evaluate(model: eqx.Module, testloader: DataLoader):
```
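Spelled out, the objective that gives the commit its name is, per batch:

    loss = mean ‖a − â‖² + λ · mean ⟨f(a), f(a)⟩ + mean CE(y, ŷ)

where a is the activation sown at the chosen layer, â its reconstruction, f(a) the SAE's hidden encoding, and CE the cross-entropy of the full network's prediction; the last, "deep" term trains the SAE to preserve whatever the downstream layers actually use. Note that despite its name, `SAE.l1` computes the dot product of the encoding with itself, a squared-L2-style penalty rather than an absolute-value L1 norm; the name follows the `l1` hyperparameter.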
```diff
@@ -84,23 +102,25 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
 @eqx.filter_jit
 def make_step(
-    sae: SAE,
     model: eqx.Module,
+    freeze_spec: PyTree,
+    sae_pos: Callable,
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
+    y: Float[Array, "batch"],
     λ: float,
 ):
-    activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], λ)
-    updates, opt_state = optim.update(grads, opt_state, sae)
-    sae = eqx.apply_updates(sae, updates)
-    return sae, opt_state, loss_value
+    diff_model, static_model = eqx.partition(model, freeze_spec)
+    (loss, aux), grads = SAE.loss(diff_model, static_model, sae_pos, x, y, λ)
+    updates, opt_state = optim.update(grads, opt_state, model)
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state, loss, *aux
 
 
 def train_loop(
-    sae: SAE,
-    cnn: eqx.Module,
+    model: eqx.Module,
+    freeze_spec: PyTree,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
@@ -110,34 +130,57 @@ def train_loop(
     tensorboard,
     λ,
 ) -> eqx.Module:
-    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-    cnn = sow(sae_pos, cnn)
-    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
+    opt_state = optim.init(freeze_spec)
 
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, _) in zip(range(steps), infinite_trainloader()):
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        sae, opt_state, train_loss = make_step(
-            sae, cnn, optim, opt_state, x.numpy(), λ
+        model, opt_state, loss, reconstruction_err, l1, deep_err = make_step(
+            model,
+            freeze_spec,
+            sae_pos,
+            optim,
+            opt_state,
+            x.numpy(),
+            y.numpy(),
+            λ
         )
-        tensorboard.add_scalar("loss", train_loss.item(), step)
+        tensorboard.add_scalar("loss", loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            cnn_with_sae = insert_after(sae_pos, cnn, sae)
-            test_accuracy = evaluate(cnn_with_sae, testloader)
+            test_accuracy = evaluate(model, testloader)
             print(
                 datetime.now().strftime("%H:%M"),
-                f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}",
+                step,
+                f"{loss=:.3f}",
+                f"rec={reconstruction_err.item():.3f}",
+                f"{l1=:.3f}",
+                f"{deep_err=:.3f}",
+                f"{test_accuracy=:.3f}",
             )
             tensorboard.add_scalar("accu", test_accuracy.item(), step)
-    return sae
+    return sae_pos(model)
+
+
+def compose_model(cnn, sae, layer):
+    sae_pos = lambda m: m.layers[layer]
+    model = sow(sae_pos, cnn)
+    model = insert_after(sae_pos, model, sae)
+    model = sow(sae_pos, model)
+    sae_pos = lambda m: m.layers[layer].children[1]
+    freeze_spec = jtu.tree_map(lambda _: False, model)
+    freeze_spec = eqx.tree_at(
+        sae_pos,
+        freeze_spec,
+        replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+    )
+    return model, freeze_spec, sae_pos
```
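`compose_model` is the heart of the change: `sow` taps the activation ahead of layer `layer`, `insert_after` splices the SAE in behind that tap, and a second `sow` records the SAE's output, which is how `jax.vmap(model)(x)` can return `(original_activ, reconstructed_activ, pred)` in the loss. The boolean `freeze_spec` then marks only the SAE's array leaves as trainable. A self-contained sketch of that freeze pattern in plain Equinox (the two-field model is illustrative; `sow` and `insert_after` come from `jo3util` and are not reproduced here):

```python
# Sketch: train only one sub-tree of a model by partitioning on a boolean spec.
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu


class Net(eqx.Module):
    a: jnp.ndarray  # stays frozen
    b: jnp.ndarray  # gets trained


net = Net(jnp.ones(2), jnp.ones(2))

# False everywhere, then True only at the leaves we want gradients for.
freeze_spec = jtu.tree_map(lambda _: False, net)
freeze_spec = eqx.tree_at(lambda m: m.b, freeze_spec, replace=True)


def loss(diff_model, static_model, x):
    model = eqx.combine(diff_model, static_model)  # reassemble the full model
    return jnp.sum((model.a + model.b * x) ** 2)


diff_model, static_model = eqx.partition(net, freeze_spec)
grads = eqx.filter_grad(loss)(diff_model, static_model, jnp.ones(2))
print(grads.a, grads.b)  # None for the frozen leaf, an array for the trained one
```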
```diff
@@ -154,11 +197,14 @@ def train_sae(
     λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
+    # print(f"test_accuracy={evaluate(cnn, testloader).item()}")
     sae = SAE(activ_size, hidden_size, key)
+    model, freeze_spec, sae_pos = compose_model(cnn, sae, sae_pos)
     optim = optax.adamw(learning_rate)
     return train_loop(
-        sae,
-        cnn,
+        model,
+        freeze_spec,
         sae_pos,
         trainloader,
         testloader,
```

src/temp.py (new file, 37 lines)
View File

```python
#! /usr/bin/env python3
# vim:fenc=utf-8

import argtoml
import equinox as eqx
import jax
import jax.tree_util as jtu
from jo3util.eqx import insert_after, sow

from cnn import CNN
from sae import SAE

O = argtoml.parse_args()
key = jax.random.PRNGKey(O.seed)
key, subkey = jax.random.split(key)
cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))

sae_hyparam = O.sae[0]
sae_pos = lambda m: m.layers[sae_hyparam.layer]
sae = SAE(sae_hyparam.input_size, sae_hyparam.hidden_size, key)

model = sow(sae_pos, cnn)
model = insert_after(sae_pos, model, sae)
model = sow(sae_pos, model)

freeze_spec = jtu.tree_map(lambda _: False, model)
freeze_spec = eqx.tree_at(
    lambda m: m.layers[sae_hyparam.layer].children[1],
    freeze_spec,
    replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
)

# print(model)
print(freeze_spec)
```
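If the wiring is right, the printed `freeze_spec` mirrors the sown model's structure with `True` only at the array leaves of the inserted SAE (`layers[layer].children[1]`) and `False` everywhere else. This scratch script appears to exist just to eyeball that mask; the same construction lives in `compose_model` in the SAE module.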