deepsync

parent 63048b1915
commit 1bae37419a
@@ -38,8 +38,8 @@ else:
     sae = train_sae(
         key,
         model,
-        lambda m: m.layers[4],
-        512,
+        lambda m: m.layers[6],
+        64,
         O.batch_size,
         O.learning_rate,
         O.steps,
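This change retargets the SAE from the activations of model.layers[4] (width 512) to model.layers[6] (width 64). The layer is addressed with an accessor function rather than a bare index, which is what lets train_sae both sow the activations and later splice the SAE back in at the same spot. A minimal sketch of the accessor pattern; the Model class here is hypothetical, not from this repo:

import equinox as eqx
import jax

class Model(eqx.Module):
    """Hypothetical stand-in for the CNN; for illustration only."""
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 8)
        self.layers = [eqx.nn.Linear(64, 64, key=k) for k in keys]

sae_pos = lambda m: m.layers[6]       # the layer whose output the SAE models
model = Model(jax.random.PRNGKey(0))
print(sae_pos(model))                 # resolves to the concrete submodule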
@@ -149,6 +149,8 @@ def train_cnn(
     trainloader, testloader = mnist.load(batch_size=batch_size)
     model = CNN(key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
+    model = train_loop(
+        model, trainloader, testloader, optim, steps, print_every
+    )
     eqx.tree_serialise_leaves(model_storage, model)
     return model
src/sae.py (68 lines changed)
@@ -5,13 +5,15 @@
 """
 
 from typing import Callable, List
 
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
-from jo3util.eqx import sow
+from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
+
+
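The newly imported insert_after is the core of this commit: it splices the trained SAE into the CNN directly after the sown layer, so the combined model can be evaluated end to end. jo3util's actual implementation is not shown in this diff; the following is only a conceptual sketch of such a helper, built on eqx.tree_at, with an assumed wrapper class:

import equinox as eqx

class _Spliced(eqx.Module):
    """Run the original submodule, then pass its output through the SAE."""
    layer: eqx.Module
    sae: eqx.Module

    def __call__(self, x):
        return self.sae(self.layer(x))

def insert_after(where, model, sae):
    # Sketch only: swap the submodule that `where` points at for a wrapped
    # version, leaving the rest of the model's PyTree untouched.
    return eqx.tree_at(where, model, replace=_Spliced(where(model), sae))

Because the SAE is trained to reconstruct the activations it consumes, the spliced model's test accuracy should approach the plain CNN's as training progresses, which is what the accuracy prints in train_loop let you compare.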
@@ -33,8 +35,13 @@ class SAE(eqx.Module):
         # decoder bias
         self.bd = jax.random.uniform(k3, (in_size,))
 
+    def __call__(self, x):
+        x = self.encode(x)
+        x = self.decode(x)
+        return x
+
     def encode(self, x):
-        x = (x - self.bd) @ self.we + self.be
+        x = ((x - self.bd) @ self.we) + self.be
         return jax.nn.relu(x)
 
     def decode(self, fx):
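Two remarks on this hunk. The extra parentheses in encode are purely cosmetic: @ binds tighter than + in Python, so (x - self.bd) @ self.we + self.be already evaluated as the new version; behaviour is unchanged. More substantively, the input is centred on the decoder bias bd before encoding, the usual convention in sparse-autoencoder implementations. For reference, a self-contained sketch of the whole module as the visible pieces suggest it; only the bd initialiser is confirmed by the hunk, and decode is cut off in the diff, so its body below is an assumption:

import equinox as eqx
import jax

class SAE(eqx.Module):
    we: jax.Array  # encoder weights, (in_size, hidden_size)
    be: jax.Array  # encoder bias, (hidden_size,)
    wd: jax.Array  # decoder weights, (hidden_size, in_size)
    bd: jax.Array  # decoder bias, (in_size,)

    def __init__(self, in_size, hidden_size, key):
        k0, k1, k2, k3 = jax.random.split(key, 4)
        # only the bd line is visible in the hunk; the rest is assumed
        self.we = jax.random.uniform(k0, (in_size, hidden_size))
        self.be = jax.random.uniform(k1, (hidden_size,))
        self.wd = jax.random.uniform(k2, (hidden_size, in_size))
        self.bd = jax.random.uniform(k3, (in_size,))

    def __call__(self, x):
        return self.decode(self.encode(x))

    def encode(self, x):
        # centre on the decoder bias, project up, rectify
        return jax.nn.relu(((x - self.bd) @ self.we) + self.be)

    def decode(self, fx):
        # assumed body: project back down and un-centre
        return fx @ self.wd + self.bd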
@@ -45,23 +52,42 @@ class SAE(eqx.Module):
 def loss(sae, x, λ):
     fx = jax.vmap(sae.encode)(x)
     x_ = jax.vmap(sae.decode)(fx)
-    sq_err = jnp.dot((x - x_), (x - x_))
-    l1 = λ * jnp.dot(fx, fx)
-    return jnp.mean(sq_err + l1)
+    sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
+    l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
+    out = jnp.mean(sq_err + l1)
+    return out
+
+
+def evaluate(model: eqx.Module, testloader: DataLoader):
+    """Evaluate the model on the test dataset, computing the
+    average accuracy.
+    """
+    avg_acc = 0
+    for x, y in testloader:
+        x = x.numpy()
+        y = y.numpy()
+        # All the JAX operations happen inside the vmapped model call,
+        # so this loop stays fast; [0] selects the logits.
+        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        avg_acc += jnp.mean(y == pred_y)
+    return avg_acc / len(testloader)
 
 
 def train_loop(
     sae: SAE,
     model: eqx.Module,
+    sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
 ) -> eqx.Module:
     # Just like earlier: It only makes sense to train the arrays in our model,
     # so filter out everything else.
-    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
+
+    model = sow(sae_pos, model)
+
+    print(f"test_accuracy={evaluate(model, testloader).item()}")
 
     # Always wrap everything -- computing gradients, running the optimiser, updating
     # the model -- into a single JIT region. This ensures things run as fast as
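The loss change above fixes a real bug, not just style: jnp.dot on two 2-D arrays is matrix multiplication, so for a batch of shape (batch, features) the old jnp.dot((x - x_), (x - x_)) either raises a shape error (whenever batch ≠ features) or silently computes a Gram matrix. Mapping jnp.dot over the leading axis with jax.vmap(jnp.dot, (0, 0)) gives one squared norm per sample, which is what the mean expects. (The sparsity term is still named l1, though λ·⟨fx, fx⟩ is a squared L2 penalty rather than an L1 one.) A quick illustration with made-up shapes:

import jax
import jax.numpy as jnp

err = jnp.ones((8, 64))  # pretend batch of reconstruction errors

# Old behaviour: jnp.dot(err, err) is (8, 64) @ (8, 64), which fails
# outright because the inner dimensions (64 vs 8) do not match.

# New behaviour: one squared error per sample.
sq_err = jax.vmap(jnp.dot, (0, 0))(err, err)
print(sq_err.shape)  # (8,)
print(sq_err[0])     # 64.0: squared norm of a row of ones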
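make_step, used in the next hunk, sits outside the changed lines. In the spirit of the comment above (gradients, optimiser step, and model update fused in one JIT region), here is a sketch of what it plausibly looks like; everything beyond the call signature visible below is an assumption: it presumably lives inside train_loop so it can close over optim, the sown model is assumed to return an (output, activations) pair (which would also explain the [0] in evaluate), the λ value is made up, and y is accepted but unused because the SAE trains unsupervised.

@eqx.filter_jit
def make_step(sae, model, opt_state, x, y):
    # Run the sown CNN to collect the activations the SAE trains on;
    # the [1] assumes the sown model returns an (output, activations) pair.
    activations = jax.vmap(model)(x)[1]
    # Differentiate the SAE loss w.r.t. the SAE's arrays only
    # (λ = 1e-3 is a made-up sparsity weight).
    train_loss, grads = eqx.filter_value_and_grad(loss)(sae, activations, 1e-3)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(sae, eqx.is_array))
    sae = eqx.apply_updates(sae, updates)
    return sae, opt_state, train_loss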
@@ -92,12 +118,13 @@ def train_loop(
             y = y.numpy()
         sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
-            test_loss, test_accuracy = evaluate(sae, model, testloader)
+            model_with_sae = insert_after(sae_pos, model, sae)
+            test_accuracy = evaluate(model_with_sae, testloader)
             print(
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}"
             )
-    return model
+    return sae
 
 
 def train_sae(
@@ -109,15 +136,20 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae.eqx",
+    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = mnist.load(batch_size=batch_size)
-    model = sow(sae_pos, cnn)
-    print(model)
     sae = SAE(activ_size, 1000, key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(
-        sae, model, trainloader, testloader, optim, steps, print_every
-    )
-    # eqx.tree_serialise_leaves(model_storage, model)
-    return model
+    sae = train_loop(
+        sae,
+        cnn,
+        sae_pos,
+        trainloader,
+        testloader,
+        optim,
+        steps,
+        print_every,
+    )
+    eqx.tree_serialise_leaves(sae_storage, sae)
+    return sae
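For completeness, this is how the pieces fit together from the call site in the first hunk. O is the options object already used there; the key and model setup shown here is illustrative (the surrounding script is not part of this diff), and print_every is given a made-up value because the first hunk cuts off before it:

key = jax.random.PRNGKey(0)  # illustrative; the real key comes from the caller
model = train_cnn(...)       # trained beforehand; its signature is not shown here

sae = train_sae(
    key,
    model,
    lambda m: m.layers[6],   # sow/splice point inside the CNN
    64,                      # activation width at that layer
    O.batch_size,
    O.learning_rate,
    O.steps,
    print_every=100,         # made-up value; the hunk ends before this argument
)
# train_sae also serialises the result to ./res/sae_layer_6.eqx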