sae works with deep loss
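Train the SAE jointly with the downstream classification loss instead of on reconstruction alone: compose_model splices the SAE into the CNN after the sown layer, every non-SAE parameter is frozen through a boolean freeze_spec PyTree, and the optimizer minimizes reconstruction error plus a λ-weighted sparsity penalty plus the cross-entropy ("deep") error. The config is also cut down to a single [[sae]] run with shorter training.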
commit a3bad799e3 (parent ecd7d3bd65)
config.toml | 29
@@ -1,33 +1,12 @@
 batch_size = 64
-steps = 1000000
-print_every = 50000
-seed = 0
+steps = 10000
+print_every = 500
+seed = 1
 cnn_storage = "./res/cnn.eqx"
 
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 1e-3
-l1 = 3e-4 # from Neel Nanda's sae git
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-4
-l1 = 3e-4
-
 [[sae]]
 layer = 6
-hidden_size = 300
+hidden_size = 256
 input_size = 64
 learning_rate = 1e-4
 l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-5
-l1 = 3e-4
@@ -39,21 +39,25 @@ for sae_hyperparams in O.sae:
     sae_dir = run_dir(sae_hyperparams)
     if sae_dir.exists():
         continue
+    sae_dir.mkdir()
 
-    sae = train_sae(
-        key,
-        cnn,
-        lambda m: m.layers[sae_hyperparams.layer],
-        sae_hyperparams.input_size,
-        sae_hyperparams.hidden_size,
-        O.batch_size,
-        sae_hyperparams.learning_rate,
-        O.steps,
-        O.print_every,
-        SummaryWriter(sae_dir / "log"),
-        sae_hyperparams.l1,
-    )
+    with SummaryWriter(sae_dir / "log") as tensorboard:
+        tensorboard.add_text("hyperparameters", str(sae_hyperparams))
+
+        sae = train_sae(
+            key,
+            cnn,
+            sae_hyperparams.layer,
+            sae_hyperparams.input_size,
+            sae_hyperparams.hidden_size,
+            O.batch_size,
+            sae_hyperparams.learning_rate,
+            O.steps,
+            O.print_every,
+            tensorboard,
+            sae_hyperparams.l1,
+        )
 
-    sae_dir.mkdir()
     argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
     argtoml.save(O, sae_dir / "config.toml")
     jo3eqx.save(
@@ -70,7 +74,7 @@ for sae_hyperparams in O.sae:
     sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
     sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
     trainloader, testloader = jo3mnist.load(
-        batch_size=O.batch_size, shuffle=False
+        batch_size=1, shuffle=False
     )
 
     train_dir = sae_dir / f"train"
src/sae.py | 120
@@ -11,6 +11,7 @@ from typing import Callable, List
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 import jo3mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
@@ -18,6 +19,12 @@ from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
+from .cnn import cross_entropy
+
+
+def filter_value_and_grad_with_aux(f):
+    return eqx.filter_value_and_grad(f, has_aux=True)
+
 
 class SAE(eqx.Module):
     we: Float
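For context, a minimal sketch (toy function and values, not from this repo) of what the filter_value_and_grad_with_aux wrapper added above does: with has_aux=True, the wrapped function must return (scalar_loss, aux), and the transform returns ((loss, aux), grads), which is exactly the shape unpacked in make_step further down.

import equinox as eqx
import jax.numpy as jnp

def filter_value_and_grad_with_aux(f):
    return eqx.filter_value_and_grad(f, has_aux=True)

@filter_value_and_grad_with_aux
def f(w, x):
    # Differentiated w.r.t. the first argument only; aux rides along untouched.
    residual = w * x - 1.0
    return jnp.sum(residual ** 2), (residual,)

(value, (residual,)), grad = f(jnp.array(2.0), jnp.array(3.0))
print(value, grad)  # 25.0, and d/dw (w*x - 1)^2 = 2*(w*x - 1)*x = 30.0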
@@ -27,15 +34,16 @@ class SAE(eqx.Module):
 
     def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
         k0, k1, k2, k3 = jax.random.split(key, 4)
+        initializer = jax.nn.initializers.he_uniform()
 
         # encoder weight matrix
-        self.we = jax.random.uniform(k0, (in_size, hidden_size))
+        self.we = initializer(k0, (in_size, hidden_size))
         # decoder weight matrix
-        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
+        self.wd = initializer(k1, (hidden_size, in_size))
         # encoder bias
-        self.be = jax.random.uniform(k2, (hidden_size,))
+        self.be = jnp.zeros((hidden_size,))
         # decoder bias
-        self.bd = jax.random.uniform(k3, (in_size,))
+        self.bd = jnp.zeros((in_size,))
 
     def __call__(self, x):
         x = self.encode(x)
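The weights move from plain uniform noise to He initialization with zero biases. jax.nn.initializers.he_uniform() returns a callable taking (key, shape); a quick illustration with hypothetical shapes:

import jax

init = jax.nn.initializers.he_uniform()
# Samples from U(-b, b) with b = sqrt(6 / fan_in) = sqrt(6 / 64) here.
w = init(jax.random.PRNGKey(0), (64, 300))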
@@ -49,22 +57,32 @@ class SAE(eqx.Module):
     def decode(self, fx):
         return fx @ self.wd + self.bd
 
+    def l1(self, x):
+        x = self.encode(x)
+        return jax.vmap(jnp.dot, (0, 0))(x, x)
+
     @staticmethod
-    @eqx.filter_value_and_grad
-    def loss(sae, x, λ):
-        fx = jax.vmap(sae.encode)(x)
-        x_ = jax.vmap(sae.decode)(fx)
-        sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
-        l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
-        out = jnp.mean(sq_err + l1)
-        return out
+    @filter_value_and_grad_with_aux
+    def loss(diff_model, static_model, sae_pos, x, y, λ):
+        model = eqx.combine(diff_model, static_model)
+        original_activ, reconstructed_activ, pred = jax.vmap(model)(x)
+
+        reconstruction_err = jnp.mean(jax.vmap(jnp.dot, (0, 0))(
+            (original_activ - reconstructed_activ),
+            (original_activ - reconstructed_activ)
+        ))
+        l1 = λ * jnp.mean(sae_pos(model).l1(original_activ))
+        deep_err = jnp.mean(cross_entropy(y, pred))
+
+        loss = reconstruction_err + l1 + deep_err
+        return loss, (reconstruction_err, l1, deep_err)
 
 
 def sample_features(cnn, sae, loader):
     for i, (x, _) in enumerate(loader):
         x = x.numpy()
         activ = jax.vmap(cnn)(x)[0]
-        yield i, sae.encode(activ)
+        yield i, sae.encode(activ)[0]
 
 
 def evaluate(model: eqx.Module, testloader: DataLoader):
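A back-of-the-envelope sketch of the three loss terms above, with toy arrays standing in for activations and a hypothetical stand-in for .cnn.cross_entropy (integer labels against log-probabilities). Note that, as written, the l1 method penalizes dot(f, f), i.e. a squared L2 norm of the codes rather than a sum of absolute values:

import jax
import jax.numpy as jnp

def cross_entropy(y, log_probs):
    # Hypothetical stand-in for .cnn.cross_entropy.
    return -jnp.take_along_axis(log_probs, y[:, None], axis=1)[:, 0]

k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
activ = jax.random.normal(k0, (8, 64))                  # original activations
recon = activ + 0.1 * jax.random.normal(k1, (8, 64))    # reconstructed activations
codes = jax.nn.relu(jax.random.normal(k2, (8, 300)))    # SAE codes f(a)
y = jnp.zeros(8, dtype=jnp.int32)
pred = jax.nn.log_softmax(jax.random.normal(k0, (8, 10)))

λ = 3e-4
reconstruction_err = jnp.mean(jax.vmap(jnp.dot)(activ - recon, activ - recon))
sparsity = λ * jnp.mean(jax.vmap(jnp.dot)(codes, codes))  # squared L2, despite the name "l1"
deep_err = jnp.mean(cross_entropy(y, pred))
loss = reconstruction_err + sparsity + deep_err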
@@ -84,23 +102,25 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
 
 @eqx.filter_jit
 def make_step(
-    sae: SAE,
     model: eqx.Module,
+    freeze_spec: PyTree,
+    sae_pos: Callable,
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
+    y: Float[Array, "batch"],
     λ: float,
 ):
-    activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], λ)
-    updates, opt_state = optim.update(grads, opt_state, sae)
-    sae = eqx.apply_updates(sae, updates)
-    return sae, opt_state, loss_value
+    diff_model, static_model = eqx.partition(model, freeze_spec)
+    (loss, aux), grads = SAE.loss(diff_model, static_model, sae_pos, x, y, λ)
+    updates, opt_state = optim.update(grads, opt_state, model)
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state, loss, *aux
 
 
 def train_loop(
-    sae: SAE,
-    cnn: eqx.Module,
+    model: eqx.Module,
+    freeze_spec: PyTree,
+    sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
@@ -110,34 +130,57 @@ def train_loop(
     tensorboard,
     λ,
 ) -> eqx.Module:
-    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-
-    cnn = sow(sae_pos, cnn)
-
-    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
+    opt_state = optim.init(freeze_spec)
 
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, _) in zip(range(steps), infinite_trainloader()):
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        sae, opt_state, train_loss = make_step(
-            sae, cnn, optim, opt_state, x.numpy(), λ
-        )
-        tensorboard.add_scalar("loss", train_loss.item(), step)
+        model, opt_state, loss, reconstruction_err, l1, deep_err = make_step(
+            model,
+            freeze_spec,
+            sae_pos,
+            optim,
+            opt_state,
+            x.numpy(),
+            y.numpy(),
+            λ
+        )
+        tensorboard.add_scalar("loss", loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            cnn_with_sae = insert_after(sae_pos, cnn, sae)
-            test_accuracy = evaluate(cnn_with_sae, testloader)
+            test_accuracy = evaluate(model, testloader)
             print(
                 datetime.now().strftime("%H:%M"),
-                f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}",
+                step,
+                f"{loss=:.3f}",
+                f"rec={reconstruction_err.item():.3f}",
+                f"{l1=:.3f}",
+                f"{deep_err=:.3f}",
+                f"{test_accuracy=:.3f}",
             )
             tensorboard.add_scalar("accu", test_accuracy.item(), step)
-    return sae
+    return sae_pos(model)
 
 
+def compose_model(cnn, sae, layer):
+    sae_pos = lambda m: m.layers[layer]
+    model = sow(sae_pos, cnn)
+    model = insert_after(sae_pos, model, sae)
+    model = sow(sae_pos, model)
+
+    sae_pos = lambda m: m.layers[layer].children[1]
+    freeze_spec = jtu.tree_map(lambda _: False, model)
+    freeze_spec = eqx.tree_at(
+        sae_pos,
+        freeze_spec,
+        replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+    )
+
+    return model, freeze_spec, sae_pos
+
+
 def train_sae(
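A self-contained sketch (toy two-matrix model, hypothetical shapes) of the freeze mechanism compose_model and make_step rely on: eqx.partition splits the model by the boolean freeze_spec, eqx.filter_value_and_grad differentiates only the trainable half, and eqx.combine reassembles the full model inside the loss:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

class Tiny(eqx.Module):
    w1: jax.Array  # stands in for the frozen CNN weights
    w2: jax.Array  # stands in for the spliced-in SAE weights

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.w1 = jax.random.normal(k1, (4, 4))
        self.w2 = jax.random.normal(k2, (4, 4))

    def __call__(self, x):
        return x @ self.w1 @ self.w2

model = Tiny(jax.random.PRNGKey(0))

# Mark only w2 as trainable, mirroring how compose_model builds freeze_spec.
freeze_spec = jtu.tree_map(lambda _: False, model)
freeze_spec = eqx.tree_at(lambda m: m.w2, freeze_spec, replace=True)

@eqx.filter_value_and_grad
def loss(diff_model, static_model, x):
    m = eqx.combine(diff_model, static_model)  # reassemble the full model
    out = jax.vmap(m)(x)
    return jnp.mean(out * out)

diff_model, static_model = eqx.partition(model, freeze_spec)
_, grads = loss(diff_model, static_model, jnp.ones((2, 4)))
print(grads.w1)        # None: frozen, carried in the static half
print(grads.w2.shape)  # (4, 4): the only trainable leaf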
@@ -154,11 +197,14 @@ def train_sae(
     λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
+    # print(f"test_accuracy={evaluate(cnn, testloader).item()}")
 
     sae = SAE(activ_size, hidden_size, key)
+    model, freeze_spec, sae_pos = compose_model(cnn, sae, sae_pos)
     optim = optax.adamw(learning_rate)
     return train_loop(
-        sae,
-        cnn,
+        model,
+        freeze_spec,
+        sae_pos,
         trainloader,
         testloader,
src/temp.py | 37 (new file)
@@ -0,0 +1,37 @@
+#! /usr/bin/env python3
+# vim:fenc=utf-8
+
+import argtoml
+import equinox as eqx
+import jax
+import jax.tree_util as jtu
+from jo3util.eqx import insert_after, sow
+
+from cnn import CNN
+from sae import SAE
+
+
+O = argtoml.parse_args()
+
+key = jax.random.PRNGKey(O.seed)
+key, subkey = jax.random.split(key)
+
+cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+
+sae_hyparam = O.sae[0]
+sae_pos = lambda m: m.layers[sae_hyparam.layer]
+
+sae = SAE(sae_hyparam.input_size, sae_hyparam.hidden_size, key)
+
+model = sow(sae_pos, cnn)
+model = insert_after(sae_pos, model, sae)
+model = sow(sae_pos, model)
+
+freeze_spec = jtu.tree_map(lambda _: False, model)
+freeze_spec = eqx.tree_at(
+    lambda m: m.layers[sae_hyparam.layer].children[1],
+    freeze_spec,
+    replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+)
+# print(model)
+print(freeze_spec)
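Assuming argtoml resolves its defaults from the repository's config.toml, this scratch script just builds the composed model by hand and prints the freeze specification, which is a quick way to eyeball that only the inserted SAE's arrays are marked trainable before launching a real run.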