JJJHolscher 2023-11-01 09:44:45 +01:00
parent 63048b1915
commit 1bae37419a
3 changed files with 55 additions and 21 deletions

View File

@@ -38,8 +38,8 @@ else:
sae = train_sae(
key,
model,
lambda m: m.layers[4],
512,
lambda m: m.layers[6],
64,
O.batch_size,
O.learning_rate,
O.steps,
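For orientation, the updated call above can be read against the train_sae definition further down in this diff; the parameter names in the comments are taken from that definition, and this annotated spelling is illustrative rather than part of the commit:

sae = train_sae(
    key,                    # PRNG key used to initialise the SAE
    model,                  # the trained CNN whose activations the SAE reconstructs
    lambda m: m.layers[6],  # sae_pos: hook the output of layer 6 (previously layer 4)
    64,                     # activ_size: width of that activation (previously 512)
    O.batch_size,
    O.learning_rate,
    O.steps,
)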

View File

@@ -149,6 +149,8 @@ def train_cnn(
trainloader, testloader = mnist.load(batch_size=batch_size)
model = CNN(key)
optim = optax.adamw(learning_rate)
model = train_loop(model, trainloader, testloader, optim, steps, print_every)
model = train_loop(
model, trainloader, testloader, optim, steps, print_every
)
eqx.tree_serialise_leaves(model_storage, model)
return model

View File

@@ -5,13 +5,15 @@
"""
from typing import Callable, List
import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax
from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import sow
from jo3util.eqx import insert_after, sow
from torch.utils.data import DataLoader
@@ -33,8 +35,13 @@ class SAE(eqx.Module):
# decoder bias
self.bd = jax.random.uniform(k3, (in_size,))
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
def encode(self, x):
x = (x - self.bd) @ self.we + self.be
x = ((x - self.bd) @ self.we) + self.be
return jax.nn.relu(x)
def decode(self, fx):
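Taken together, the hunk above amounts to the forward pass sketched below. The encoder line matches the diff; the decode body falls outside the hunk, so the fx @ wd + bd form (with a shared decoder bias and an assumed decoder weight wd) is a sketch of the usual sparse-autoencoder layout, not a quote from the file:

def sae_forward(sae: SAE, x):
    # encode: remove the decoder bias, project into the wide feature space, apply ReLU
    fx = jax.nn.relu(((x - sae.bd) @ sae.we) + sae.be)
    # decode (assumed): project back down and re-add the decoder bias
    x_hat = fx @ sae.wd + sae.bd
    return x_hat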
@@ -45,23 +52,42 @@
def loss(sae, x, λ):
fx = jax.vmap(sae.encode)(x)
x_ = jax.vmap(sae.decode)(fx)
sq_err = jnp.dot((x - x_), (x - x_))
l1 = λ * jnp.dot(fx, fx)
return jnp.mean(sq_err + l1)
sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
out = jnp.mean(sq_err + l1)
return out
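The change here replaces a whole-batch jnp.dot with a per-example dot via jax.vmap, so each example contributes its own squared reconstruction error and feature penalty. An equivalent spelling for reference only (note that the term named l1 is dot(fx, fx), i.e. a squared-L2 penalty as written):

def loss_reference(sae, x, λ):
    fx = jax.vmap(sae.encode)(x)             # (batch, hidden)
    x_ = jax.vmap(sae.decode)(fx)            # (batch, in_size)
    sq_err = jnp.sum((x - x_) ** 2, axis=1)  # per-example squared reconstruction error
    penalty = λ * jnp.sum(fx ** 2, axis=1)   # what the diff calls `l1`
    return jnp.mean(sq_err + penalty)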
def evaluate(model: eqx.Module, testloader: DataLoader):
"""This function evaluates the model on the test dataset,
computing both the average loss and the average accuracy.
"""
avg_acc = 0
for x, y in testloader:
x = x.numpy()
y = y.numpy()
# The vmapped forward pass and the argmax/mean below are JAX operations;
# only the torch-to-NumPy conversion above runs on the host.
pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
avg_acc += jnp.mean(y == pred_y)
return avg_acc / len(testloader)
def train_loop(
sae: SAE,
model: eqx.Module,
sae_pos: Callable,
trainloader: DataLoader,
testloader: DataLoader,
optim: optax.GradientTransformation,
steps: int,
print_every: int,
) -> eqx.Module:
# Just like earlier: It only makes sense to train the arrays in our model,
# so filter out everything else.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
opt_state = optim.init(eqx.filter(sae, eqx.is_array))
model = sow(sae_pos, model)
print(f"test_accuracy={evaluate(model, testloader).item()}")
# Always wrap everything -- computing gradients, running the optimiser, updating
# the model -- into a single JIT region. This ensures things run as fast as
@@ -92,12 +118,13 @@ def train_loop(
y = y.numpy()
sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
if (step % print_every) == 0 or (step == steps - 1):
test_loss, test_accuracy = evaluate(sae, model, testloader)
model_with_sae = insert_after(sae_pos, model, sae)
test_accuracy = evaluate(model_with_sae, testloader)
print(
f"{step=}, train_loss={train_loss.item()}, "
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
f"test_accuracy={test_accuracy.item()}"
)
return model
return sae
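As a usage note, the two calls this hunk introduces can also be combined to see how much task accuracy the SAE bottleneck costs; a hypothetical check using only names already in scope inside train_loop:

baseline_acc = evaluate(model, testloader)                         # sown CNN, activations untouched
sae_acc = evaluate(insert_after(sae_pos, model, sae), testloader)  # CNN with the SAE spliced in after the hook
print(f"accuracy drop from inserting the SAE: {(baseline_acc - sae_acc).item()}")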
def train_sae(
@@ -109,15 +136,20 @@ def train_sae(
learning_rate,
steps,
print_every,
sae_storage="./res/sae.eqx",
sae_storage="./res/sae_layer_6.eqx",
):
trainloader, testloader = mnist.load(batch_size=batch_size)
model = sow(sae_pos, cnn)
print(model)
sae = SAE(activ_size, 1000, key)
optim = optax.adamw(learning_rate)
model = train_loop(
sae, model, trainloader, testloader, optim, steps, print_every
sae = train_loop(
sae,
cnn,
sae_pos,
trainloader,
testloader,
optim,
steps,
print_every,
)
# eqx.tree_serialise_leaves(model_storage, model)
return model
eqx.tree_serialise_leaves(sae_storage, sae)
return sae
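Since the commit now serialises the SAE rather than the whole model, a later analysis script has to rebuild the structure before loading the weights. A minimal sketch, assuming the same sizes used above (activ_size 64 for layer 6, 1000 hidden features):

skeleton = SAE(64, 1000, jax.random.PRNGKey(0))  # structure only; its leaves are overwritten on load
sae = eqx.tree_deserialise_leaves("./res/sae_layer_6.eqx", skeleton)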