deepsync

parent 63048b1915
commit 1bae37419a
@@ -38,8 +38,8 @@ else:
     sae = train_sae(
         key,
         model,
-        lambda m: m.layers[4],
-        512,
+        lambda m: m.layers[6],
+        64,
         O.batch_size,
         O.learning_rate,
         O.steps,
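Note: the SAE now hooks `m.layers[6]` instead of `m.layers[4]`, and the activation size passed to `train_sae` drops from 512 to 64 to match that layer's output width. A quick sanity check of the new size (illustrative only; it assumes the CNN's layers compose as plain callables on a single 1x28x28 MNIST image):

    import jax.numpy as jnp

    x = jnp.zeros((1, 28, 28))      # one MNIST image, channels first
    for layer in model.layers[:7]:  # everything up to and including layers[6]
        x = layer(x)
    print(x.shape)                  # expected to end in 64 if the committed size is right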
@@ -149,6 +149,8 @@ def train_cnn(
     trainloader, testloader = mnist.load(batch_size=batch_size)
     model = CNN(key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
+    model = train_loop(
+        model, trainloader, testloader, optim, steps, print_every
+    )
     eqx.tree_serialise_leaves(model_storage, model)
     return model
src/sae.py (68 changed lines)
@@ -5,13 +5,15 @@
 """

+from typing import Callable, List
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import mnist
 import optax
 from jaxtyping import Array, Float, Int, PyTree
-from jo3util.eqx import sow
+from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader

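Note: `insert_after` joins `sow` in the jo3util.eqx imports. From how this commit uses them, `sow(sae_pos, model)` makes the model also emit the activation at `sae_pos`, and `insert_after(sae_pos, model, sae)` splices the SAE in right after that layer. A rough sketch of how such a splice can be built on plain Equinox (a guess for illustration, not jo3util's actual implementation):

    import equinox as eqx

    class Chain(eqx.Module):
        # Runs `second` on the output of `first`.
        first: eqx.Module
        second: eqx.Module

        def __call__(self, x):
            return self.second(self.first(x))

    def insert_after_sketch(where, model, extra):
        # `where` is a selector such as `lambda m: m.layers[6]`;
        # eqx.tree_at functionally replaces that node of the model PyTree.
        return eqx.tree_at(where, model, Chain(where(model), extra))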
@@ -33,8 +35,13 @@ class SAE(eqx.Module):
         # decoder bias
         self.bd = jax.random.uniform(k3, (in_size,))

+    def __call__(self, x):
+        x = self.encode(x)
+        x = self.decode(x)
+        return x
+
     def encode(self, x):
-        x = (x - self.bd) @ self.we + self.be
+        x = ((x - self.bd) @ self.we) + self.be
         return jax.nn.relu(x)

     def decode(self, fx):
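Note: the new `__call__` makes the SAE usable as an ordinary layer (encode, then decode), and encode subtracts the decoder bias before projecting: the usual pre-bias SAE parameterization, so the decoder reconstructs relative to `bd`. Shape sketch under the committed sizes (in_size=64, hidden=1000; the decoder weight name `wd` is assumed by symmetry with `we`, since `decode`'s body is not shown in this hunk):

    we: (64, 1000)   encoder weights     be: (1000,)  encoder bias
    wd: (1000, 64)   decoder weights     bd: (64,)    decoder bias

    f  = relu((x - bd) @ we + be)    # encode: (64,)   -> (1000,)
    x_ = f @ wd + bd                 # decode: (1000,) -> (64,)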
@ -45,23 +52,42 @@ class SAE(eqx.Module):
|
||||||
def loss(sae, x, λ):
|
def loss(sae, x, λ):
|
||||||
fx = jax.vmap(sae.encode)(x)
|
fx = jax.vmap(sae.encode)(x)
|
||||||
x_ = jax.vmap(sae.decode)(fx)
|
x_ = jax.vmap(sae.decode)(fx)
|
||||||
sq_err = jnp.dot((x - x_), (x - x_))
|
sq_err = jax.vmap(jnp.dot, (0, 0))((x - x_), (x - x_))
|
||||||
l1 = λ * jnp.dot(fx, fx)
|
l1 = λ * jax.vmap(jnp.dot, (0, 0))(fx, fx)
|
||||||
return jnp.mean(sq_err + l1)
|
out = jnp.mean(sq_err + l1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(model: eqx.Module, testloader: DataLoader):
|
||||||
|
"""This function evaluates the model on the test dataset,
|
||||||
|
computing both the average loss and the average accuracy.
|
||||||
|
"""
|
||||||
|
avg_acc = 0
|
||||||
|
for x, y in testloader:
|
||||||
|
x = x.numpy()
|
||||||
|
y = y.numpy()
|
||||||
|
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
|
||||||
|
# and both have JIT wrappers, so this is fast.
|
||||||
|
pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
|
||||||
|
avg_acc += jnp.mean(y == pred_y)
|
||||||
|
return avg_acc / len(testloader)
|
||||||
|
|
||||||
|
|
||||||
def train_loop(
|
def train_loop(
|
||||||
sae: SAE,
|
sae: SAE,
|
||||||
model: eqx.Module,
|
model: eqx.Module,
|
||||||
|
sae_pos: Callable,
|
||||||
trainloader: DataLoader,
|
trainloader: DataLoader,
|
||||||
testloader: DataLoader,
|
testloader: DataLoader,
|
||||||
optim: optax.GradientTransformation,
|
optim: optax.GradientTransformation,
|
||||||
steps: int,
|
steps: int,
|
||||||
print_every: int,
|
print_every: int,
|
||||||
) -> eqx.Module:
|
) -> eqx.Module:
|
||||||
# Just like earlier: It only makes sense to train the arrays in our model,
|
opt_state = optim.init(eqx.filter(sae, eqx.is_array))
|
||||||
# so filter out everything else.
|
|
||||||
opt_state = optim.init(eqx.filter(model, eqx.is_array))
|
model = sow(sae_pos, model)
|
||||||
|
|
||||||
|
print(f"test_accuracy={evaluate(model, testloader).item()}")
|
||||||
|
|
||||||
# Always wrap everything -- computing gradients, running the optimiser, updating
|
# Always wrap everything -- computing gradients, running the optimiser, updating
|
||||||
# the model -- into a single JIT region. This ensures things run as fast as
|
# the model -- into a single JIT region. This ensures things run as fast as
|
||||||
|
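Note: the old `jnp.dot` calls were applied to whole (batch, features) arrays, where dot means matrix multiplication; vmapping over axis 0 instead yields one scalar per example, which is what the final mean expects. A minimal check (sizes are illustrative):

    import jax
    import jax.numpy as jnp

    diff = jnp.ones((8, 64))                        # 8 residuals of width 64
    sq_err = jax.vmap(jnp.dot, (0, 0))(diff, diff)  # per-example self-dot
    assert sq_err.shape == (8,)

Also worth flagging: despite its name, `l1 = λ * dot(fx, fx)` is a squared-L2 penalty on the code; a literal L1 sparsity term would be `λ * jnp.abs(fx).sum(axis=-1)`.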
@@ -92,12 +118,13 @@ def train_loop(
         y = y.numpy()
         sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
-            test_loss, test_accuracy = evaluate(sae, model, testloader)
+            model_with_sae = insert_after(sae_pos, model, sae)
+            test_accuracy = evaluate(model_with_sae, testloader)
             print(
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}"
             )
-    return model
+    return sae


 def train_sae(
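Note: at every print interval the current SAE is spliced into the frozen CNN with `insert_after`, so test_accuracy measures the CNN running on the SAE's reconstructions rather than the SAE in isolation; `test_loss` is dropped accordingly. `make_step` is referenced but not shown in this diff; a typical Equinox step consistent with this call site might look like the following (the sown-activation indexing, the captured `optim`, and the λ value are assumptions, not the repo's code):

    @eqx.filter_jit
    def make_step(sae, model, opt_state, x, y):
        # Hypothetical: take the hooked activations from the sown model.
        acts = jax.vmap(model)(x)[1]
        loss_value, grads = eqx.filter_value_and_grad(loss)(sae, acts, 0.01)
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(sae, eqx.is_array)
        )
        sae = eqx.apply_updates(sae, updates)
        return sae, opt_state, loss_value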
@@ -109,15 +136,20 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae.eqx",
+    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = mnist.load(batch_size=batch_size)
-    model = sow(sae_pos, cnn)
-    print(model)
     sae = SAE(activ_size, 1000, key)
     optim = optax.adamw(learning_rate)
-    model = train_loop(
-        sae, model, trainloader, testloader, optim, steps, print_every
+    sae = train_loop(
+        sae,
+        cnn,
+        sae_pos,
+        trainloader,
+        testloader,
+        optim,
+        steps,
+        print_every,
     )
-    # eqx.tree_serialise_leaves(model_storage, model)
-    return model
+    eqx.tree_serialise_leaves(sae_storage, sae)
+    return sae
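Note: train_sae now saves and returns the SAE itself, under a layer-specific filename. Reloading it later follows the standard Equinox round-trip (sizes taken from this commit; the key only seeds the throwaway template):

    import equinox as eqx
    import jax

    sae_template = SAE(64, 1000, jax.random.PRNGKey(0))  # same shapes as at save time
    sae = eqx.tree_deserialise_leaves("./res/sae_layer_6.eqx", sae_template)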