Gradually getting there; have yet to get a single trained SAE.

parent ffb976d94e
commit 63048b1915

.gitignore (vendored, Normal file, 13 lines)
@@ -0,0 +1,13 @@
# JO3'S DEFAULT IGNORE RULES
/tmp*
## python
/build/
/dist/
/*.egg-info/
/.venv/
__pycache__
**.egg-info/
**.ipynb_checkpoints
/Untitle*.ipynb
## rust
/target/

config.toml (Normal file, 7 lines)
@@ -0,0 +1,7 @@
batch_size = 64
learning_rate = 3e-4
steps = 300
print_every = 30
seed = 5678
cnn_storage = "./res/cnn.eqx"
sae_storage = "./res/sae.eqx"
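
The entry module further down loads these settings with argtoml.parse_args() and treats the two *_storage values as paths (it calls O.cnn_storage.exists()). As a rough sketch of what that amounts to, using only the standard library (tomllib needs Python 3.11+; argtoml's exact behaviour is not shown in this commit):

import tomllib
from pathlib import Path

# Read the hyperparameters committed in config.toml above.
with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

# argtoml appears to hand these back as attributes / Path objects;
# here the storage entry is converted by hand.
cnn_storage = Path(cfg["cnn_storage"])
print(cfg["batch_size"], cfg["learning_rate"], cfg["seed"], cnn_storage.exists())
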

@@ -1,24 +0,0 @@
build
debugpy
equinox
jax>=0.4.14
jaxtyping
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
torch
torchvision
tqdm
twine
typeguard
git+file://../jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118

@@ -1,6 +1,6 @@

[project]
-name = "NAME"
+name = "sparse_autoencoder"
version = "0.0.0" # TODO; automatically update versions by looking at git
description = ""
dependencies = []

@@ -19,7 +19,7 @@ demo = []
github = "JJJHolscher"

[project.urls]
-homepage = "https://github.com/JJJHolscher/NAME"
+homepage = "https://github.com/JJJHolscher/sparse_autoencoder"

[[project.authors]]
name = "Jochem Hölscher"

@@ -32,7 +32,7 @@ requires = [
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
-include = ["NAME"]
+include = ["sparse_autoencoder"]

[tool.setuptools.dynamic]
readme = {file = ["README.md"], content-type = "text/markdown"}

@@ -1,11 +1,26 @@
argtoml
build
debugpy
jo3util
equinox
jax>=0.4.14
jaxtyping
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
torch
torchvision
tqdm
twine
typeguard
git+file:///mnt/nas/git/jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118
git+file:///mnt/nas/git/mnist

res/MNIST/MNIST/raw/t10k-images-idx3-ubyte (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/train-images-idx3-ubyte (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/train-images-idx3-ubyte.gz (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/train-labels-idx1-ubyte (BIN, Normal file; binary file not shown)
res/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz (BIN, Normal file; binary file not shown)
res/cnn.eqx (BIN, Normal file; binary file not shown)

sparse_autoencoder (Symbolic link, 1 line)
@@ -0,0 +1 @@
./src

@@ -5,3 +5,44 @@

"""

import argtoml
import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax  # https://github.com/deepmind/optax
from jaxtyping import Float  # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader

from .cnn import CNN, train_cnn
from .sae import SAE, train_sae

# Hyperparameters
O = argtoml.parse_args()
key = jax.random.PRNGKey(O.seed)
key, subkey = jax.random.split(key)

if O.cnn_storage.exists():
    model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
else:
    model = train_cnn(
        subkey,
        O.batch_size,
        O.learning_rate,
        O.steps,
        O.print_every,
        O.cnn_storage,
    )

sae = train_sae(
    key,
    model,
    lambda m: m.layers[4],
    512,
    O.batch_size,
    O.learning_rate,
    O.steps,
    O.print_every,
    O.sae_storage,
)
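
Note that eqx.tree_deserialise_leaves only restores array leaves, which is why a freshly constructed CNN(subkey) is passed in as a structural template above. A minimal round-trip sketch of that pattern (the path and the Linear module here are purely illustrative):

import equinox as eqx
import jax

# Serialise a module's array leaves, then load them back into a newly
# built module that has the same structure ("skeleton" template).
linear = eqx.nn.Linear(4, 2, key=jax.random.PRNGKey(0))
eqx.tree_serialise_leaves("/tmp/linear.eqx", linear)
template = eqx.nn.Linear(4, 2, key=jax.random.PRNGKey(1))  # weights get overwritten
restored = eqx.tree_deserialise_leaves("/tmp/linear.eqx", template)
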

src/cnn.py (Normal file, 154 lines)
@@ -0,0 +1,154 @@
#! /usr/bin/env python3
# vim:fenc=utf-8

"""

"""

import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax  # https://github.com/deepmind/optax
from jaxtyping import Float  # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader


class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


@eqx.filter_jit
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
    # a single input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


def evaluate(model: CNN, testloader: DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)


def train_loop(
    model: CNN,
    trainloader: DataLoader,
    testloader: DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model


def train_cnn(
    key,
    batch_size,
    learning_rate,
    steps,
    print_every,
    model_storage="./res/cnn.eqx",
):
    trainloader, testloader = mnist.load(batch_size=batch_size)
    model = CNN(key)
    optim = optax.adamw(learning_rate)
    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
    eqx.tree_serialise_leaves(model_storage, model)
    return model
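
The hand-rolled cross_entropy above uses jnp.take_along_axis to pick, for every example, the log-probability assigned to its true label. A small numeric check of that indexing trick (not part of the commit):

import jax.numpy as jnp

# Two examples, three classes; rows are already probabilities here.
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.1, 0.8]]))
labels = jnp.array([0, 2])
picked = jnp.take_along_axis(log_probs, labels[:, None], axis=1)  # shape (2, 1)
nll = -picked.mean()  # == -(log 0.7 + log 0.8) / 2
print(picked.squeeze(), nll)
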

src/sae.py (Normal file, 123 lines)
@@ -0,0 +1,123 @@
#! /usr/bin/env python3
# vim:fenc=utf-8

"""

"""

import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax
from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import sow
from torch.utils.data import DataLoader


class SAE(eqx.Module):
    we: Float
    wd: Float
    be: Float
    bd: Float

    def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
        k0, k1, k2, k3 = jax.random.split(key, 4)

        # encoder weight matrix
        self.we = jax.random.uniform(k0, (in_size, hidden_size))
        # decoder weight matrix
        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
        # encoder bias
        self.be = jax.random.uniform(k2, (hidden_size,))
        # decoder bias
        self.bd = jax.random.uniform(k3, (in_size,))

    def encode(self, x):
        x = (x - self.bd) @ self.we + self.be
        return jax.nn.relu(x)

    def decode(self, fx):
        return fx @ self.wd + self.bd

    @staticmethod
    @eqx.filter_value_and_grad
    def loss(sae, x, λ):
        fx = jax.vmap(sae.encode)(x)
        x_ = jax.vmap(sae.decode)(fx)
        # Per-example squared reconstruction error.
        sq_err = jnp.sum((x - x_) ** 2, axis=-1)
        # L1 sparsity penalty on the hidden code.
        l1 = λ * jnp.sum(jnp.abs(fx), axis=-1)
        return jnp.mean(sq_err + l1)


def train_loop(
    sae: SAE,
    model: eqx.Module,
    trainloader: DataLoader,
    testloader: DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> eqx.Module:
    # Just like earlier: It only makes sense to train the arrays in our SAE,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(sae, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        sae: SAE,
        model: eqx.Module,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        _, activ = jax.vmap(model)(x)
        loss_value, grads = SAE.loss(sae, activ[0], 1)
        updates, opt_state = optim.update(grads, opt_state, sae)
        sae = eqx.apply_updates(sae, updates)
        return sae, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(sae, model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return sae


def train_sae(
    key,
    cnn,
    sae_pos,
    activ_size,
    batch_size,
    learning_rate,
    steps,
    print_every,
    sae_storage="./res/sae.eqx",
):
    trainloader, testloader = mnist.load(batch_size=batch_size)
    model = sow(sae_pos, cnn)
    print(model)
    sae = SAE(activ_size, 1000, key)
    optim = optax.adamw(learning_rate)
    sae = train_loop(
        sae, model, trainloader, testloader, optim, steps, print_every
    )
    # eqx.tree_serialise_leaves(sae_storage, sae)
    return sae
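
train_loop above reports test metrics through evaluate, but no evaluate is defined in this module, and the one in cnn.py expects a CNN and class labels rather than an SAE. A minimal sketch of what such a helper could look like, reporting only the average SAE loss over the test set; the name evaluate_sae and its placement here are assumptions, not part of the commit:

def evaluate_sae(sae, model, testloader, lam=1.0):
    # Average the SAE loss over the held-out set, reusing the sown CNN
    # to produce the activations the SAE is trained on.
    avg_loss = 0.0
    for x, _ in testloader:
        x = x.numpy()
        _, activ = jax.vmap(model)(x)
        # SAE.loss is wrapped in eqx.filter_value_and_grad, so it returns
        # (value, grads); only the value is needed here.
        loss_value, _ = SAE.loss(sae, activ[0], lam)
        avg_loss += loss_value
    return avg_loss / len(testloader)

Since the call site in train_loop unpacks a loss and an accuracy, a helper like this would also mean trimming the print statement to report only the loss.
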
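
train_sae relies on jo3util.eqx.sow to make the CNN return the intermediate activation the SAE is trained on; that library's API is not part of this commit, beyond the fact that the sown model returns a pair whose second element is indexable (activ[0]) and that the layer is selected with a lambda. For readers without jo3util, the same effect can be sketched with a thin wrapper that runs the layers up to a chosen index and returns both the final output and the captured activation (an integer index is used here instead of the lambda, purely for illustration):

import equinox as eqx

class SowedCNN(eqx.Module):
    # Wraps a CNN and additionally returns the activation after layer
    # `layer_idx`, mimicking what jo3util.eqx.sow appears to provide.
    cnn: eqx.Module
    layer_idx: int

    def __call__(self, x):
        activ = None
        for i, layer in enumerate(self.cnn.layers):
            x = layer(x)
            if i == self.layer_idx:
                activ = x
        return x, (activ,)
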