Gradually getting there; have yet to get a single trained SAE

JJJHolscher 2023-10-19 17:45:43 +02:00
parent ffb976d94e
commit 63048b1915
18 changed files with 358 additions and 28 deletions

13
.gitignore vendored Normal file

@ -0,0 +1,13 @@
# JO3'S DEFAULT IGNORE RULES
/tmp*
## python
/build/
/dist/
/*.egg-info/
/.venv/
__pycache__
**.egg-info/
**.ipynb_checkpoints
/Untitle*.ipynb
## rust
/target/

7
config.toml Normal file

@ -0,0 +1,7 @@
batch_size = 64
learning_rate = 3e-4
steps = 300
print_every = 30
seed = 5678
cnn_storage = "./res/cnn.eqx"
sae_storage = "./res/sae.eqx"

View File

@ -1,24 +0,0 @@
build
debugpy
equinox
jax>=0.4.14
jaxtyping
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
torch
torchvision
tqdm
twine
typeguard
git+file://../jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118

View File

@ -1,6 +1,6 @@
[project]
-name = "NAME"
+name = "sparse_autoencoder"
version = "0.0.0" # TODO; automatically update versions by looking at git
description = ""
dependencies = []
@ -19,7 +19,7 @@ demo = []
github = "JJJHolscher"
[project.urls]
-homepage = "https://github.com/JJJHolscher/NAME"
+homepage = "https://github.com/JJJHolscher/sparse_autoencoder"
[[project.authors]]
name = "Jochem Hölscher"
@ -32,7 +32,7 @@ requires = [
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
-include = ["NAME"]
+include = ["sparse_autoencoder"]
[tool.setuptools.dynamic]
readme = {file = ["README.md"], content-type = "text/markdown"}

View File

@ -1,11 +1,26 @@
argtoml
build
debugpy
-jo3util
+equinox
+jax>=0.4.14
+jaxtyping
+matplotlib
nbclassic
notebook
+optax
+pandas
pyright
+scikit-learn
+tensorboard
+tensorboardX
+torch
+torchvision
+tqdm
twine
typeguard
+git+file:///mnt/nas/git/jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+-f https://download.pytorch.org/whl/cu118
+git+file:///mnt/nas/git/mnist

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
res/cnn.eqx Normal file

Binary file not shown.

1
sparse_autoencoder Symbolic link

@ -0,0 +1 @@
./src

View File

@ -5,3 +5,44 @@
""" """
import argtoml
import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader
from .cnn import CNN, train_cnn
from .sae import SAE, train_sae
# Hyperparameters
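# argtoml is assumed to pick up config.toml from the repository root, so that
# O.batch_size, O.learning_rate, O.steps, O.print_every, O.seed, O.cnn_storage
# and O.sae_storage below take the values defined in that file.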
O = argtoml.parse_args()
key = jax.random.PRNGKey(O.seed)
key, subkey = jax.random.split(key)
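# Reuse a previously serialised CNN if one exists on disk; otherwise train one
# from scratch (train_cnn also serialises it to O.cnn_storage).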
if O.cnn_storage.exists():
model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
else:
model = train_cnn(
subkey,
O.batch_size,
O.learning_rate,
O.steps,
O.print_every,
O.cnn_storage,
)
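# Train a sparse autoencoder on the 512-dimensional activations of
# model.layers[4], the first Linear layer of the CNN.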
sae = train_sae(
key,
model,
lambda m: m.layers[4],
512,
O.batch_size,
O.learning_rate,
O.steps,
O.print_every,
O.sae_storage,
)

154
src/cnn.py Normal file

@ -0,0 +1,154 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""
import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader
class CNN(eqx.Module):
layers: list
def __init__(self, key):
key1, key2, key3, key4 = jax.random.split(key, 4)
# Standard CNN setup: convolutional layer, followed by flattening,
# with a small MLP on top.
self.layers = [
eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
eqx.nn.MaxPool2d(kernel_size=2),
jax.nn.relu,
jnp.ravel,
eqx.nn.Linear(1728, 512, key=key2),
jax.nn.sigmoid,
eqx.nn.Linear(512, 64, key=key3),
jax.nn.relu,
eqx.nn.Linear(64, 10, key=key4),
jax.nn.log_softmax,
]
def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
for layer in self.layers:
x = layer(x)
return x
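# Example: jax.vmap(CNN(key))(x) maps a batch x of shape (batch, 1, 28, 28) to
# (batch, 10) log-probabilities; loss and compute_accuracy below rely on exactly that.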
@eqx.filter_jit
def loss(
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
# Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
# a single input image of shape (1, 28, 28).
#
# Therefore, we have to use jax.vmap, which in this case maps our model over the
# leading (batch) axis.
pred_y = jax.vmap(model)(x)
return cross_entropy(y, pred_y)
def cross_entropy(
y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
# y are the true targets, and should be integers 0-9.
# pred_y are the log-softmax'd predictions.
pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
return -jnp.mean(pred_y)
@eqx.filter_jit
def compute_accuracy(
model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
"""This function takes as input the current model
and computes the average accuracy on a batch.
"""
pred_y = jax.vmap(model)(x)
pred_y = jnp.argmax(pred_y, axis=1)
return jnp.mean(y == pred_y)
def evaluate(model: CNN, testloader: DataLoader):
"""This function evaluates the model on the test dataset,
computing both the average loss and the average accuracy.
"""
avg_loss = 0
avg_acc = 0
for x, y in testloader:
x = x.numpy()
y = y.numpy()
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
# and both have JIT wrappers, so this is fast.
avg_loss += loss(model, x, y)
avg_acc += compute_accuracy(model, x, y)
return avg_loss / len(testloader), avg_acc / len(testloader)
def train_loop(
model: CNN,
trainloader: DataLoader,
testloader: DataLoader,
optim: optax.GradientTransformation,
steps: int,
print_every: int,
) -> CNN:
# Just like earlier: It only makes sense to train the arrays in our model,
# so filter out everything else.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
# Always wrap everything -- computing gradients, running the optimiser, updating
# the model -- into a single JIT region. This ensures things run as fast as
# possible.
@eqx.filter_jit
def make_step(
model: CNN,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, " batch"],
):
loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
updates, opt_state = optim.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss_value
# Loop over our training dataset as many times as we need.
def infinite_trainloader():
while True:
yield from trainloader
for step, (x, y) in zip(range(steps), infinite_trainloader()):
# PyTorch dataloaders give PyTorch tensors by default,
# so convert them to NumPy arrays.
x = x.numpy()
y = y.numpy()
model, opt_state, train_loss = make_step(model, opt_state, x, y)
if (step % print_every) == 0 or (step == steps - 1):
test_loss, test_accuracy = evaluate(model, testloader)
print(
f"{step=}, train_loss={train_loss.item()}, "
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
)
return model
def train_cnn(
key,
batch_size,
learning_rate,
steps,
print_every,
model_storage="./res/cnn.eqx",
):
trainloader, testloader = mnist.load(batch_size=batch_size)
model = CNN(key)
optim = optax.adamw(learning_rate)
model = train_loop(model, trainloader, testloader, optim, steps, print_every)
eqx.tree_serialise_leaves(model_storage, model)
return model

123
src/sae.py Normal file

@ -0,0 +1,123 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""
import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax
from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import sow
from torch.utils.data import DataLoader
class SAE(eqx.Module):
we: Float
wd: Float
be: Float
bd: Float
def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
k0, k1, k2, k3 = jax.random.split(key, 4)
# encoder weight matrix
self.we = jax.random.uniform(k0, (in_size, hidden_size))
# decoder weight matrix
self.wd = jax.random.uniform(k1, (hidden_size, in_size))
# encoder bias
self.be = jax.random.uniform(k2, (hidden_size,))
# decoder bias
self.bd = jax.random.uniform(k3, (in_size,))
def encode(self, x):
x = (x - self.bd) @ self.we + self.be
return jax.nn.relu(x)
def decode(self, fx):
return fx @ self.wd + self.bd
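# encode maps an (in_size,) activation vector to a non-negative (hidden_size,)
# code and decode maps it back to (in_size,); subtracting self.bd before
# encoding mirrors adding it back when decoding.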
@staticmethod
@eqx.filter_value_and_grad
def loss(sae, x, λ):
fx = jax.vmap(sae.encode)(x)
x_ = jax.vmap(sae.decode)(fx)
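# Per-example squared reconstruction error plus an L1 sparsity penalty on the
# hidden activations, averaged over the batch.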
sq_err = jnp.sum((x - x_) ** 2, axis=-1)
l1 = λ * jnp.sum(jnp.abs(fx), axis=-1)
return jnp.mean(sq_err + l1)
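# Minimal evaluation helper for train_loop below, a sketch mirroring evaluate in
# src/cnn.py: it averages the SAE loss over the sown activations of the test set
# (an autoencoder has no accuracy to report) and assumes the same λ = 1 sparsity
# coefficient that make_step uses.
def evaluate(sae: SAE, model: eqx.Module, testloader: DataLoader):
avg_loss = 0
for x, _ in testloader:
x = x.numpy()
_, activ = jax.vmap(model)(x)
loss_value, _ = SAE.loss(sae, activ[0], 1)
avg_loss += loss_value
return avg_loss / len(testloader)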
def train_loop(
sae: SAE,
model: eqx.Module,
trainloader: DataLoader,
testloader: DataLoader,
optim: optax.GradientTransformation,
steps: int,
print_every: int,
) -> eqx.Module:
# Just like earlier: It only makes sense to train the arrays in our model,
# so filter out everything else.
opt_state = optim.init(eqx.filter(sae, eqx.is_array))
# Always wrap everything -- computing gradients, running the optimiser, updating
# the model -- into a single JIT region. This ensures things run as fast as
# possible.
@eqx.filter_jit
def make_step(
sae: SAE,
model: eqx.Module,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, " batch"],
):
_, activ = jax.vmap(model)(x)
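# activ[0] holds the activations captured at the sown layer; the last argument
# of SAE.loss is the sparsity coefficient λ, hard-coded to 1 here.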
loss_value, grads = SAE.loss(sae, activ[0], 1)
updates, opt_state = optim.update(grads, opt_state, sae)
sae = eqx.apply_updates(sae, updates)
return sae, opt_state, loss_value
# Loop over our training dataset as many times as we need.
def infinite_trainloader():
while True:
yield from trainloader
for step, (x, y) in zip(range(steps), infinite_trainloader()):
# PyTorch dataloaders give PyTorch tensors by default,
# so convert them to NumPy arrays.
x = x.numpy()
y = y.numpy()
sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
if (step % print_every) == 0 or (step == steps - 1):
test_loss = evaluate(sae, model, testloader)
print(
f"{step=}, train_loss={train_loss.item()}, "
f"test_loss={test_loss.item()}"
)
return sae
def train_sae(
key,
cnn,
sae_pos,
activ_size,
batch_size,
learning_rate,
steps,
print_every,
sae_storage="./res/sae.eqx",
):
trainloader, testloader = mnist.load(batch_size=batch_size)
model = sow(sae_pos, cnn)
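# sow (from jo3util) wraps the CNN so that a call returns its usual output
# together with the activations captured at sae_pos; make_step above unpacks
# that pair as `_, activ`.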
print(model)
sae = SAE(activ_size, 1000, key)
optim = optax.adamw(learning_rate)
sae = train_loop(
sae, model, trainloader, testloader, optim, steps, print_every
)
# eqx.tree_serialise_leaves(sae_storage, sae)
return sae