gradually getting there, have yet to get a single trained sae
This commit is contained in:
parent ffb976d94e
commit 63048b1915
.gitignore (vendored), new file, +13 lines
@@ -0,0 +1,13 @@
# JO3'S DEFAULT IGNORE RULES
/tmp*
## python
/build/
/dist/
/*.egg-info/
/.venv/
__pycache__
**.egg-info/
**.ipynb_checkpoints
/Untitle*.ipynb
## rust
/target/
config.toml, new file, +7 lines
@@ -0,0 +1,7 @@
batch_size = 64
learning_rate = 3e-4
steps = 300
print_every = 30
seed = 5678
cnn_storage = "./res/cnn.eqx"
sae_storage = "./res/sae.eqx"
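These values are consumed in the package's __init__ module shown further down via argtoml (O = argtoml.parse_args()). A minimal sketch of that access pattern, under the assumption that argtoml loads ./config.toml by default, exposes keys as attributes, and converts path-valued strings to pathlib.Path objects (the __init__ code calls O.cnn_storage.exists()):

```python
# Sketch only: argtoml's defaults are assumed here, not shown in this commit.
import argtoml

O = argtoml.parse_args()            # assumed to read ./config.toml
print(O.batch_size, O.seed)         # 64 5678
if not O.cnn_storage.exists():      # path strings assumed to become pathlib.Path
    print("no trained CNN yet at", O.cnn_storage)
```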
Deleted file (a pip requirements list; name not shown in this view), 24 lines removed:
@@ -1,24 +0,0 @@
build
debugpy
equinox
jax>=0.4.14
jaxtyping
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
torch
torchvision
tqdm
twine
typeguard
git+file://../jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118
pyproject.toml (file name not shown in this view; inferred from the [project] and setuptools sections):
@@ -1,6 +1,6 @@

 [project]
-name = "NAME"
+name = "sparse_autoencoder"
 version = "0.0.0" # TODO; automatically update versions by looking at git
 description = ""
 dependencies = []
@@ -19,7 +19,7 @@ demo = []
 github = "JJJHolscher"

 [project.urls]
-homepage = "https://github.com/JJJHolscher/NAME"
+homepage = "https://github.com/JJJHolscher/sparse_autoencoder"

 [[project.authors]]
 name = "Jochem Hölscher"
@@ -32,7 +32,7 @@ requires = [
 build-backend = "setuptools.build_meta"

 [tool.setuptools.packages.find]
-include = ["NAME"]
+include = ["sparse_autoencoder"]

 [tool.setuptools.dynamic]
 readme = {file = ["README.md"], content-type = "text/markdown"}
Pip requirements file (name not shown in this view):
@@ -1,11 +1,26 @@
 argtoml
 build
 debugpy
-jo3util
+equinox
+jax>=0.4.14
+jaxtyping
+matplotlib
 nbclassic
 notebook
+optax
+pandas
 pyright
+scikit-learn
+tensorboard
+tensorboardX
+torch
+torchvision
+tqdm
 twine
 typeguard
+git+file:///mnt/nas/git/jo3util
 git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
 git+https://github.com/JJJHolscher/jupytools
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+-f https://download.pytorch.org/whl/cu118
+git+file:///mnt/nas/git/mnist
BIN res/MNIST/MNIST/raw/t10k-images-idx3-ubyte (new binary file, not shown)
BIN res/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz (new binary file, not shown)
BIN res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte (new binary file, not shown)
BIN res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz (new binary file, not shown)
BIN res/MNIST/MNIST/raw/train-images-idx3-ubyte (new binary file, not shown)
BIN res/MNIST/MNIST/raw/train-images-idx3-ubyte.gz (new binary file, not shown)
BIN res/MNIST/MNIST/raw/train-labels-idx1-ubyte (new binary file, not shown)
BIN res/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz (new binary file, not shown)
BIN res/cnn.eqx (new binary file, not shown)
sparse_autoencoder, new symbolic link, +1 line
@@ -0,0 +1 @@
./src
Modified file (name not shown in this view; likely src/__init__.py, given the relative .cnn and .sae imports):
@@ -5,3 +5,44 @@
"""
|
"""
|
||||||
|
|
||||||
|
import argtoml
|
||||||
|
import equinox as eqx
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import mnist
|
||||||
|
import optax # https://github.com/deepmind/optax
|
||||||
|
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||||
|
from jaxtyping import Array, Int, PyTree
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from .cnn import CNN, train_cnn
|
||||||
|
from .sae import SAE, train_sae
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
O = argtoml.parse_args()
|
||||||
|
key = jax.random.PRNGKey(O.seed)
|
||||||
|
key, subkey = jax.random.split(key)
|
||||||
|
|
||||||
|
if O.cnn_storage.exists():
|
||||||
|
model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
|
||||||
|
else:
|
||||||
|
model = train_cnn(
|
||||||
|
subkey,
|
||||||
|
O.batch_size,
|
||||||
|
O.learning_rate,
|
||||||
|
O.steps,
|
||||||
|
O.print_every,
|
||||||
|
O.cnn_storage,
|
||||||
|
)
|
||||||
|
|
||||||
|
sae = train_sae(
|
||||||
|
key,
|
||||||
|
model,
|
||||||
|
lambda m: m.layers[4],
|
||||||
|
512,
|
||||||
|
O.batch_size,
|
||||||
|
O.learning_rate,
|
||||||
|
O.steps,
|
||||||
|
O.print_every,
|
||||||
|
O.sae_storage,
|
||||||
|
)
|
||||||
|
|
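Two arguments of the train_sae call above deserve a note: lambda m: m.layers[4] selects the eqx.nn.Linear(1728, 512) layer of the CNN defined in src/cnn.py below, and 512 is that layer's output width, i.e. the SAE's input size. jo3util.eqx.sow (imported in src/sae.py) is not part of this commit; judging only from its call sites (model = sow(sae_pos, cnn), then _, activ = jax.vmap(model)(x)), it wraps the network so that a forward pass also returns the selected layer's activation. A hypothetical stand-in with that contract, purely to illustrate the assumed behaviour:

```python
# Hypothetical stand-in for jo3util.eqx.sow, inferred from its call sites in this commit.
def sow(select, model):
    target = select(model)              # e.g. model.layers[4], the Linear(1728, 512)

    def sown(x):
        activ = None
        for layer in model.layers:
            x = layer(x)
            if layer is target:
                activ = x               # "sow" the intermediate activation
        return x, [activ]               # src/sae.py trains on activ[0]

    return sown
```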
src/cnn.py, new file, +154 lines
@@ -0,0 +1,154 @@
#! /usr/bin/env python3
# vim:fenc=utf-8

"""

"""

import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax  # https://github.com/deepmind/optax
from jaxtyping import Float  # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader


class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x


@eqx.filter_jit
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
    # a single input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


def evaluate(model: CNN, testloader: DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)


def train_loop(
    model: CNN,
    trainloader: DataLoader,
    testloader: DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model


def train_cnn(
    key,
    batch_size,
    learning_rate,
    steps,
    print_every,
    model_storage="./res/cnn.eqx",
):
    trainloader, testloader = mnist.load(batch_size=batch_size)
    model = CNN(key)
    optim = optax.adamw(learning_rate)
    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
    eqx.tree_serialise_leaves(model_storage, model)
    return model
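The 1728 in the first Linear layer follows from the shapes: Conv2d(1, 3, kernel_size=4) turns a 1x28x28 image into 3x25x25, and MaxPool2d(kernel_size=2) shrinks that to 3x24x24 assuming Equinox's default pool stride of 1 (an assumption about the library default, but the only stride for which 3*24*24 = 1728 works out). A quick sanity check, not part of the commit:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Reproduce the shape bookkeeping behind eqx.nn.Linear(1728, 512) in CNN above.
conv = eqx.nn.Conv2d(1, 3, kernel_size=4, key=jax.random.PRNGKey(0))
pool = eqx.nn.MaxPool2d(kernel_size=2)          # stride assumed to default to 1

h = pool(conv(jnp.zeros((1, 28, 28))))
print(h.shape)                                  # (3, 24, 24)
print(jnp.ravel(h).shape)                       # (1728,)
```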
src/sae.py, new file, +123 lines
@@ -0,0 +1,123 @@
#! /usr/bin/env python3
# vim:fenc=utf-8

"""

"""

import equinox as eqx
import jax
import jax.numpy as jnp
import mnist
import optax
from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import sow
from torch.utils.data import DataLoader


class SAE(eqx.Module):
    we: Float
    wd: Float
    be: Float
    bd: Float

    def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
        k0, k1, k2, k3 = jax.random.split(key, 4)

        # encoder weight matrix
        self.we = jax.random.uniform(k0, (in_size, hidden_size))
        # decoder weight matrix
        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
        # encoder bias
        self.be = jax.random.uniform(k2, (hidden_size,))
        # decoder bias
        self.bd = jax.random.uniform(k3, (in_size,))

    def encode(self, x):
        x = (x - self.bd) @ self.we + self.be
        return jax.nn.relu(x)

    def decode(self, fx):
        return fx @ self.wd + self.bd

    @staticmethod
    @eqx.filter_value_and_grad
    def loss(sae, x, λ):
        fx = jax.vmap(sae.encode)(x)
        x_ = jax.vmap(sae.decode)(fx)
        sq_err = jnp.dot((x - x_), (x - x_))
        l1 = λ * jnp.dot(fx, fx)
        return jnp.mean(sq_err + l1)


def train_loop(
    sae: SAE,
    model: eqx.Module,
    trainloader: DataLoader,
    testloader: DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> eqx.Module:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        sae: SAE,
        model: eqx.Module,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        _, activ = jax.vmap(model)(x)
        loss_value, grads = SAE.loss(sae, activ[0], 1)
        updates, opt_state = optim.update(grads, opt_state, sae)
        sae = eqx.apply_updates(sae, updates)
        return sae, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss, test_accuracy = evaluate(sae, model, testloader)
            print(
                f"{step=}, train_loss={train_loss.item()}, "
                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
            )
    return model


def train_sae(
    key,
    cnn,
    sae_pos,
    activ_size,
    batch_size,
    learning_rate,
    steps,
    print_every,
    sae_storage="./res/sae.eqx",
):
    trainloader, testloader = mnist.load(batch_size=batch_size)
    model = sow(sae_pos, cnn)
    print(model)
    sae = SAE(activ_size, 1000, key)
    optim = optax.adamw(learning_rate)
    model = train_loop(
        sae, model, trainloader, testloader, optim, steps, print_every
    )
    # eqx.tree_serialise_leaves(model_storage, model)
    return model
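Consistent with the commit message, src/sae.py cannot train yet as written: train_loop calls evaluate, which is neither defined nor imported in this module; the optimiser state is initialised from the frozen CNN (model) rather than from the SAE's parameters; train_loop and train_sae return the CNN instead of the trained SAE; and in SAE.loss, jnp.dot on the batched 2-D arrays is a matrix product rather than a per-example squared error, while the λ term penalises jnp.dot(fx, fx), a squared norm rather than an L1 sparsity term. A shape-correct sketch of the same reconstruction-plus-sparsity objective, offered only as a possible direction, not as the author's intended fix:

```python
import jax
import jax.numpy as jnp

# Sketch: per-example squared reconstruction error plus an L1 sparsity penalty.
def sae_loss(sae, x, lam):
    fx = jax.vmap(sae.encode)(x)                  # (batch, hidden)
    x_ = jax.vmap(sae.decode)(fx)                 # (batch, in_size)
    sq_err = jnp.sum((x - x_) ** 2, axis=-1)      # per-example reconstruction error
    l1 = lam * jnp.sum(jnp.abs(fx), axis=-1)      # per-example sparsity penalty
    return jnp.mean(sq_err + l1)
```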