sae works with deep loss
commit a3bad799e3 (parent ecd7d3bd65)
config.toml (29 lines changed)

@@ -1,33 +1,12 @@
 batch_size = 64
-steps = 1000000
-print_every = 50000
-seed = 0
+steps = 10000
+print_every = 500
+seed = 1
 cnn_storage = "./res/cnn.eqx"
 
 [[sae]]
 layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 1e-3
-l1 = 3e-4 # from Neel Nanda's sae git
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-4
-l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
+hidden_size = 256
 input_size = 64
 learning_rate = 1e-4
 l1 = 3e-4
-
-[[sae]]
-layer = 6
-hidden_size = 300
-input_size = 64
-learning_rate = 3e-5
-l1 = 3e-4
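The learning-rate sweep is collapsed to a single `[[sae]]` entry with `learning_rate = 1e-4`, a smaller `hidden_size` of 256, and a much shorter run (10000 steps). In TOML, `[[sae]]` is an array of tables, so `O.sae` in the training script stays a list even with one entry. A minimal stdlib illustration of that, independent of argtoml (tomllib ships with Python 3.11+):

import tomllib  # standard library since Python 3.11

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

# [[sae]] parses to a list of dicts, so iterating still works
# when only one table is left.
for sae_hyperparams in cfg["sae"]:
    print(sae_hyperparams["learning_rate"])  # 1e-4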
@@ -39,21 +39,25 @@ for sae_hyperparams in O.sae:
     sae_dir = run_dir(sae_hyperparams)
     if sae_dir.exists():
         continue
 
-    sae = train_sae(
-        key,
-        cnn,
-        lambda m: m.layers[sae_hyperparams.layer],
-        sae_hyperparams.input_size,
-        sae_hyperparams.hidden_size,
-        O.batch_size,
-        sae_hyperparams.learning_rate,
-        O.steps,
-        O.print_every,
-        SummaryWriter(sae_dir / "log")
-    )
-
     sae_dir.mkdir()
 
+    with SummaryWriter(sae_dir / "log") as tensorboard:
+        tensorboard.add_text("hyperparameters", str(sae_hyperparams))
+
+        sae = train_sae(
+            key,
+            cnn,
+            sae_hyperparams.layer,
+            sae_hyperparams.input_size,
+            sae_hyperparams.hidden_size,
+            O.batch_size,
+            sae_hyperparams.learning_rate,
+            O.steps,
+            O.print_every,
+            tensorboard,
+            sae_hyperparams.l1,
+        )
+
     argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
     argtoml.save(O, sae_dir / "config.toml")
     jo3eqx.save(

@@ -70,7 +74,7 @@ for sae_hyperparams in O.sae:
     sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
     sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
     trainloader, testloader = jo3mnist.load(
-        batch_size=O.batch_size, shuffle=False
+        batch_size=1, shuffle=False
     )
 
     train_dir = sae_dir / f"train"
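The driver now opens the `SummaryWriter` in a `with` block, logs the hyperparameters as text, and threads the writer (plus the new `l1` coefficient) into `train_sae`, instead of handing `train_sae` a writer nobody closes. `torch.utils.tensorboard.SummaryWriter` supports the context-manager protocol; a minimal sketch, with an illustrative log path:

from torch.utils.tensorboard import SummaryWriter

# __exit__ flushes and closes the event file even if training raises.
with SummaryWriter("runs/example") as tensorboard:
    tensorboard.add_text("hyperparameters", "layer=6 hidden_size=256 lr=1e-4")
    for step in range(3):
        tensorboard.add_scalar("loss", 1.0 / (step + 1), step)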
src/sae.py (120 lines changed)
@@ -11,6 +11,7 @@ from typing import Callable, List
 import equinox as eqx
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 import jo3mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
@@ -18,6 +19,12 @@ from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
+from .cnn import cross_entropy
+
+
+def filter_value_and_grad_with_aux(f):
+    return eqx.filter_value_and_grad(f, has_aux=True)
+
 
 class SAE(eqx.Module):
     we: Float
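`filter_value_and_grad_with_aux` is a thin alias for `eqx.filter_value_and_grad` with `has_aux=True`, needed because the rewritten loss returns auxiliary values alongside the scalar. With `has_aux=True` the wrapped function must return a `(value, aux)` pair, and the transform returns `((value, aux), grads)`; a minimal sketch:

import equinox as eqx
import jax.numpy as jnp

@eqx.filter_value_and_grad(has_aux=True)
def loss(w, x):
    err = jnp.sum((w * x - 1.0) ** 2)
    return err, {"err": err}  # (value, aux)

(value, aux), grads = loss(jnp.ones(3), jnp.full(3, 2.0))
# value == 3.0, aux["err"] == 3.0, grads has the same shape as w.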
@@ -27,15 +34,16 @@ class SAE(eqx.Module):
 
     def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
         k0, k1, k2, k3 = jax.random.split(key, 4)
+        initializer = jax.nn.initializers.he_uniform()
+
         # encoder weight matrix
-        self.we = jax.random.uniform(k0, (in_size, hidden_size))
+        self.we = initializer(k0, (in_size, hidden_size))
         # decoder weight matrix
-        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
+        self.wd = initializer(k1, (hidden_size, in_size))
         # encoder bias
-        self.be = jax.random.uniform(k2, (hidden_size,))
+        self.be = jnp.zeros((hidden_size,))
         # decoder bias
-        self.bd = jax.random.uniform(k3, (in_size,))
+        self.bd = jnp.zeros((in_size,))
 
     def __call__(self, x):
         x = self.encode(x)
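The old initialization drew every weight and bias from `jax.random.uniform`, i.e. U(0, 1), so all parameters started positive. The commit switches to He-uniform weights with zero biases, the standard pairing for ReLU-style activations. `jax.nn.initializers.he_uniform()` returns a callable taking `(key, shape)`:

import jax
import jax.numpy as jnp

initializer = jax.nn.initializers.he_uniform()

# He-uniform draws from U(-limit, limit) with limit = sqrt(6 / fan_in),
# where fan_in is read off the shape (64 here).
we = initializer(jax.random.PRNGKey(0), (64, 256))
be = jnp.zeros((256,))
print(we.min(), we.max())  # roughly ±sqrt(6 / 64) ≈ ±0.31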
@@ -49,22 +57,32 @@ class SAE(eqx.Module):
     def decode(self, fx):
         return fx @ self.wd + self.bd
 
+    def l1(self, x):
+        x = self.encode(x)
+        return jax.vmap(jnp.dot, (0, 0))(x, x)
+
     @staticmethod
-    @eqx.filter_value_and_grad
-    def loss(sae, x, λ):
-        fx = jax.vmap(sae.encode)(x)
-        x_ = jax.vmap(sae.decode)(fx)
-        sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
-        l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
-        out = jnp.mean(sq_err + l1)
-        return out
+    @filter_value_and_grad_with_aux
+    def loss(diff_model, static_model, sae_pos, x, y, λ):
+        model = eqx.combine(diff_model, static_model)
+        original_activ, reconstructed_activ, pred = jax.vmap(model)(x)
+        reconstruction_err = jnp.mean(jax.vmap(jnp.dot, (0, 0))(
+            (original_activ - reconstructed_activ),
+            (original_activ - reconstructed_activ)
+        ))
+        l1 = λ * jnp.mean(sae_pos(model).l1(original_activ))
+        deep_err = jnp.mean(cross_entropy(y, pred))
+
+        loss = reconstruction_err + l1 + deep_err
+        return loss, (reconstruction_err, l1, deep_err)
 
 
 def sample_features(cnn, sae, loader):
     for i, (x, _) in enumerate(loader):
        x = x.numpy()
        activ = jax.vmap(cnn)(x)[0]
-        yield i, sae.encode(activ)
+        yield i, sae.encode(activ)[0]
 
 
 def evaluate(model: eqx.Module, testloader: DataLoader):
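The loss now sums three terms: the squared error between the sown layer's original activation and the SAE's reconstruction, a λ-weighted sparsity term, and the "deep" term, the cross-entropy of the full model's prediction, so gradients also reflect downstream classification quality. Note that the new `SAE.l1` method returns `dot(fx, fx)` of the hidden code, which is a squared L2 norm; a literal L1 penalty would sum absolute values. A small self-contained comparison of the two:

import jax.numpy as jnp

def sparsity_terms(fx):
    l2_squared = jnp.dot(fx, fx)   # what SAE.l1 above computes
    l1 = jnp.abs(fx).sum()         # a literal L1 penalty
    return l2_squared, l1

fx = jnp.array([0.5, -0.25, 0.0])
print(sparsity_terms(fx))  # (0.3125, 0.75)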
@@ -84,23 +102,25 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
 
 @eqx.filter_jit
 def make_step(
-    sae: SAE,
     model: eqx.Module,
+    freeze_spec: PyTree,
+    sae_pos: Callable,
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
+    y: Float[Array, "batch"],
     λ: float,
 ):
-    activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], λ)
-    updates, opt_state = optim.update(grads, opt_state, sae)
-    sae = eqx.apply_updates(sae, updates)
-    return sae, opt_state, loss_value
+    diff_model, static_model = eqx.partition(model, freeze_spec)
+    (loss, aux), grads = SAE.loss(diff_model, static_model, sae_pos, x, y, λ)
+    updates, opt_state = optim.update(grads, opt_state, model)
+    model = eqx.apply_updates(model, updates)
+    return model, opt_state, loss, *aux
 
 
 def train_loop(
-    sae: SAE,
-    cnn: eqx.Module,
+    model: eqx.Module,
+    freeze_spec: PyTree,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
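`make_step` now updates the whole sown model rather than a standalone SAE: `eqx.partition` splits it by the boolean `freeze_spec` so that only leaves marked True (the spliced-in SAE) are differentiated, and `eqx.combine` reassembles the full model inside the loss. A minimal sketch of this partition/combine freezing pattern, with a toy module standing in for the CNN-plus-SAE:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

class Toy(eqx.Module):
    frozen: jax.Array
    trainable: jax.Array

    def __call__(self, x):
        return jnp.sum(self.frozen * x + self.trainable * x)

model = Toy(jnp.ones(2), jnp.ones(2))

# Mark only `trainable` as differentiable.
spec = jtu.tree_map(lambda _: False, model)
spec = eqx.tree_at(lambda m: m.trainable, spec, replace=True)

diff_model, static_model = eqx.partition(model, spec)

@eqx.filter_value_and_grad
def loss(diff_model, static_model, x):
    m = eqx.combine(diff_model, static_model)
    return m(x)

value, grads = loss(diff_model, static_model, jnp.ones(2))
# grads.trainable holds gradients; grads.frozen is None.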
@@ -110,34 +130,57 @@ def train_loop(
     tensorboard,
     λ,
 ) -> eqx.Module:
-    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-
-    cnn = sow(sae_pos, cnn)
-
-    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
+    opt_state = optim.init(freeze_spec)
 
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, _) in zip(range(steps), infinite_trainloader()):
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        sae, opt_state, train_loss = make_step(
-            sae, cnn, optim, opt_state, x.numpy(), λ
+        model, opt_state, loss, reconstruction_err, l1, deep_err = make_step(
+            model,
+            freeze_spec,
+            sae_pos,
+            optim,
+            opt_state,
+            x.numpy(),
+            y.numpy(),
+            λ
         )
-        tensorboard.add_scalar("loss", train_loss.item(), step)
+        tensorboard.add_scalar("loss", loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            cnn_with_sae = insert_after(sae_pos, cnn, sae)
-            test_accuracy = evaluate(cnn_with_sae, testloader)
+            test_accuracy = evaluate(model, testloader)
             print(
                 datetime.now().strftime("%H:%M"),
-                f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}",
+                step,
+                f"{loss=:.3f}",
+                f"rec={reconstruction_err.item():.3f}",
+                f"{l1=:.3f}",
+                f"{deep_err=:.3f}",
+                f"{test_accuracy=:.3f}",
             )
             tensorboard.add_scalar("accu", test_accuracy.item(), step)
-    return sae
+    return sae_pos(model)
+
+
+def compose_model(cnn, sae, layer):
+    sae_pos = lambda m: m.layers[layer]
+    model = sow(sae_pos, cnn)
+    model = insert_after(sae_pos, model, sae)
+    model = sow(sae_pos, model)
+
+    sae_pos = lambda m: m.layers[layer].children[1]
+    freeze_spec = jtu.tree_map(lambda _: False, model)
+    freeze_spec = eqx.tree_at(
+        sae_pos,
+        freeze_spec,
+        replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+    )
+
+    return model, freeze_spec, sae_pos
 
 
 def train_sae(
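`compose_model` sows the chosen layer (capturing its activation), inserts the SAE after it, sows again so the reconstruction is captured too, then rebinds `sae_pos` to the inserted SAE (`children[1]` of the wrapped layer) for the freeze mask; this is why `jax.vmap(model)(x)` in the loss yields the `(original_activ, reconstructed_activ, pred)` triple. For the optimizer wiring, the common Equinox/Optax step initializes on the filtered parameters; a minimal sketch of that usual pattern, independent of this repo's helpers:

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.Linear(2, 2, key=jax.random.PRNGKey(0))
optim = optax.adamw(1e-4)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_value_and_grad
def loss(model, x, y):
    return jnp.mean((jax.vmap(model)(x) - y) ** 2)

x, y = jnp.ones((4, 2)), jnp.zeros((4, 2))
value, grads = loss(model, x, y)
# adamw needs the params as a third argument for weight decay.
updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)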
@@ -154,11 +197,14 @@ def train_sae(
     λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
+    # print(f"test_accuracy={evaluate(cnn, testloader).item()}")
 
     sae = SAE(activ_size, hidden_size, key)
+    model, freeze_spec, sae_pos = compose_model(cnn, sae, sae_pos)
     optim = optax.adamw(learning_rate)
     return train_loop(
-        sae,
-        cnn,
+        model,
+        freeze_spec,
         sae_pos,
         trainloader,
         testloader,
src/temp.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+#! /usr/bin/env python3
+# vim:fenc=utf-8
+
+import argtoml
+import equinox as eqx
+import jax
+import jax.tree_util as jtu
+from jo3util.eqx import insert_after, sow
+
+from cnn import CNN
+from sae import SAE
+
+
+O = argtoml.parse_args()
+
+key = jax.random.PRNGKey(O.seed)
+key, subkey = jax.random.split(key)
+
+cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+
+sae_hyparam = O.sae[0]
+sae_pos = lambda m: m.layers[sae_hyparam.layer]
+
+sae = SAE(sae_hyparam.input_size, sae_hyparam.hidden_size, key)
+
+model = sow(sae_pos, cnn)
+model = insert_after(sae_pos, model, sae)
+model = sow(sae_pos, model)
+
+freeze_spec = jtu.tree_map(lambda _: False, model)
+freeze_spec = eqx.tree_at(
+    lambda m: m.layers[sae_hyparam.layer].children[1],
+    freeze_spec,
+    replace=jtu.tree_map(lambda leaf: eqx.is_array(leaf), sae)
+)
+# print(model)
+print(freeze_spec)
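`src/temp.py` is a scratch script that rebuilds the sown model and prints the resulting `freeze_spec`: an all-False pytree with True exactly at the array leaves of the inserted SAE, which is what `eqx.partition` later consumes. The mask construction in isolation, on a plain dict pytree for clarity (names illustrative):

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"cnn": jnp.ones(3), "sae": {"we": jnp.ones((2, 2)), "be": jnp.zeros(2)}}

# Start all-False, then flip the `sae` subtree to True leaf-by-leaf.
mask = jtu.tree_map(lambda _: False, tree)
mask = eqx.tree_at(
    lambda t: t["sae"],
    mask,
    replace=jtu.tree_map(lambda leaf: True, tree["sae"]),
)
print(mask)  # mask["cnn"] is False; both leaves under mask["sae"] are True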