JJJHolscher 2023-11-01 09:44:45 +01:00
parent 63048b1915
commit 1bae37419a
3 changed files with 55 additions and 21 deletions

View File

@@ -38,8 +38,8 @@ else:
     sae = train_sae(
         key,
         model,
-        lambda m: m.layers[4],
-        512,
+        lambda m: m.layers[6],
+        64,
         O.batch_size,
         O.learning_rate,
         O.steps,

View File

@@ -149,6 +149,8 @@ def train_cnn(
     trainloader, testloader = mnist.load(batch_size=batch_size)
     model = CNN(key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
+    model = train_loop(
+        model, trainloader, testloader, optim, steps, print_every
+    )
     eqx.tree_serialise_leaves(model_storage, model)
     return model

View File

@@ -5,13 +5,15 @@
 """
+from typing import Callable, List
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
-from jo3util.eqx import sow
+from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
@@ -33,8 +35,13 @@ class SAE(eqx.Module):
         # decoder bias
         self.bd = jax.random.uniform(k3, (in_size,))
 
+    def __call__(self, x):
+        x = self.encode(x)
+        x = self.decode(x)
+        return x
+
     def encode(self, x):
-        x = (x - self.bd) @ self.we + self.be
+        x = ((x - self.bd) @ self.we) + self.be
         return jax.nn.relu(x)
 
     def decode(self, fx):
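
For orientation, here is a minimal usage sketch of the `__call__` added above; it is not part of the commit. The constructor signature `SAE(in_size, hidden_size, key)` is inferred from the `SAE(activ_size, 1000, key)` call further down, and the sizes are only illustrative:

    import jax

    key = jax.random.PRNGKey(0)
    sae = SAE(64, 1000, key)            # 64 = activation size of the hooked layer, per the updated call site
    x = jax.random.normal(key, (64,))   # a single activation vector
    fx = sae.encode(x)                  # sparse features, shape (1000,)
    x_hat = sae(x)                      # encode then decode; same shape as x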
@@ -45,23 +52,42 @@ class SAE(eqx.Module):
 def loss(sae, x, λ):
     fx = jax.vmap(sae.encode)(x)
     x_ = jax.vmap(sae.decode)(fx)
-    sq_err = jnp.dot((x - x_), (x - x_))
-    l1 = λ * jnp.dot(fx, fx)
-    return jnp.mean(sq_err + l1)
+    sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
+    l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
+    out = jnp.mean(sq_err + l1)
+    return out
+
+
+def evaluate(model: eqx.Module, testloader: DataLoader):
+    """Evaluate the model on the test dataset, returning the average accuracy."""
+    avg_acc = 0
+    for x, y in testloader:
+        x = x.numpy()
+        y = y.numpy()
+        # All the JAX operations run on the whole batch at once via `jax.vmap`.
+        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        avg_acc += jnp.mean(y == pred_y)
+    return avg_acc / len(testloader)
 
 
 def train_loop(
     sae: SAE,
     model: eqx.Module,
+    sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
 ) -> eqx.Module:
-    # Just like earlier: It only makes sense to train the arrays in our model,
-    # so filter out everything else.
-    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
+
+    model = sow(sae_pos, model)
+    print(f"test_accuracy={evaluate(model, testloader).item()}")
 
     # Always wrap everything -- computing gradients, running the optimiser, updating
     # the model -- into a single JIT region. This ensures things run as fast as
@@ -92,12 +118,13 @@ def train_loop(
         y = y.numpy()
         sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
-            test_loss, test_accuracy = evaluate(sae, model, testloader)
+            model_with_sae = insert_after(sae_pos, model, sae)
+            test_accuracy = evaluate(model_with_sae, testloader)
             print(
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}"
             )
-    return model
+    return sae
 
 
 def train_sae(
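
Taken together with the `sow` call earlier in `train_loop`, the evaluation path at print time reads roughly as below. This is a sketch assembled from the calls visible in this diff, assuming `sow(sae_pos, model)` exposes the hooked layer's activations and `insert_after(sae_pos, model, sae)` splices the SAE into the forward pass directly after that layer; `cnn`, `sae`, `evaluate` and `testloader` come from the surrounding code:

    from jo3util.eqx import insert_after, sow

    sae_pos = lambda m: m.layers[6]                       # hook position, as in the first file
    model = sow(sae_pos, cnn)                             # instrument the CNN at that layer
    model_with_sae = insert_after(sae_pos, model, sae)    # run the SAE right after it
    test_accuracy = evaluate(model_with_sae, testloader)  # end-to-end accuracy with the SAE in the path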
@@ -109,15 +136,20 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae.eqx",
+    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = mnist.load(batch_size=batch_size)
-    model = sow(sae_pos, cnn)
-    print(model)
     sae = SAE(activ_size, 1000, key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(
-        sae, model, trainloader, testloader, optim, steps, print_every
+    sae = train_loop(
+        sae,
+        cnn,
+        sae_pos,
+        trainloader,
+        testloader,
+        optim,
+        steps,
+        print_every,
     )
-    # eqx.tree_serialise_leaves(model_storage, model)
-    return model
+    eqx.tree_serialise_leaves(sae_storage, sae)
+    return sae
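
With the serialisation now active, the trained SAE ends up at `sae_storage`. A minimal sketch of loading it back, assuming Equinox's usual round-trip of deserialising into a skeleton module with the same shapes (sizes taken from the calls above):

    import equinox as eqx
    import jax

    key = jax.random.PRNGKey(0)
    skeleton = SAE(64, 1000, key)   # same shapes as the trained SAE; weights are placeholders
    sae = eqx.tree_deserialise_leaves("./res/sae_layer_6.eqx", skeleton)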