diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..6f2d6d619
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+# JO3'S DEFAULT IGNORE RULES
+/tmp*
+## python
+/build/
+/dist/
+/*.egg-info/
+/.venv/
+__pycache__
+**.egg-info/
+**.ipynb_checkpoints
+/Untitle*.ipynb
+## rust
+/target/
diff --git a/config.toml b/config.toml
new file mode 100644
index 000000000..ca947a9eb
--- /dev/null
+++ b/config.toml
@@ -0,0 +1,7 @@
+batch_size = 64
+learning_rate = 3e-4
+steps = 300
+print_every = 30
+seed = 5678
+cnn_storage = "./res/cnn.eqx"
+sae_storage = "./res/sae.eqx"
diff --git a/ml_requirements.txt b/ml_requirements.txt
deleted file mode 100644
index d4eefd00c..000000000
--- a/ml_requirements.txt
+++ /dev/null
@@ -1,24 +0,0 @@
-build
-debugpy
-equinox
-jax>=0.4.14
-jaxtyping
-matplotlib
-nbclassic
-notebook
-optax
-pandas
-pyright
-scikit-learn
-tensorboard
-tensorboardX
-torch
-torchvision
-tqdm
-twine
-typeguard
-git+file://../jo3util
-git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
-git+https://github.com/JJJHolscher/jupytools
--f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--f https://download.pytorch.org/whl/cu118
diff --git a/pyproject.toml b/pyproject.toml
index 355e6bd0a..020b84ce5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
-name = "NAME"
+name = "sparse_autoencoder"
 version = "0.0.0" # TODO; automatically update versions by looking at git
 description = ""
 dependencies = []
 
@@ -19,7 +19,7 @@ demo = []
 github = "JJJHolscher"
 
 [project.urls]
-homepage = "https://github.com/JJJHolscher/NAME"
+homepage = "https://github.com/JJJHolscher/sparse_autoencoder"
 
 [[project.authors]]
 name = "Jochem Hölscher"
@@ -32,7 +32,7 @@ requires = [
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools.packages.find]
-include = ["NAME"]
+include = ["sparse_autoencoder"]
 
 [tool.setuptools.dynamic]
 readme = {file = ["README.md"], content-type = "text/markdown"}
diff --git a/requirements.txt b/requirements.txt
index 97462a890..9b90410c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,26 @@
 argtoml
 build
 debugpy
-jo3util
+equinox
+jax>=0.4.14
+jaxtyping
+matplotlib
 nbclassic
 notebook
+optax
+pandas
 pyright
+scikit-learn
+tensorboard
+tensorboardX
+torch
+torchvision
+tqdm
 twine
 typeguard
+git+file:///mnt/nas/git/jo3util
 git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
 git+https://github.com/JJJHolscher/jupytools
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+-f https://download.pytorch.org/whl/cu118
+git+file:///mnt/nas/git/mnist
diff --git a/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte b/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte
new file mode 100644
index 000000000..1170b2cae
Binary files /dev/null and b/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte differ
diff --git a/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz b/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz
new file mode 100644
index 000000000..5ace8ea93
Binary files /dev/null and b/res/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz differ
diff --git a/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte b/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 000000000..d1c3a9706
Binary files /dev/null and b/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz b/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 000000000..a7e141541
Binary files /dev/null and b/res/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/res/MNIST/MNIST/raw/train-images-idx3-ubyte b/res/MNIST/MNIST/raw/train-images-idx3-ubyte
new file mode 100644
index 000000000..bbce27659
Binary files /dev/null and b/res/MNIST/MNIST/raw/train-images-idx3-ubyte differ
diff --git a/res/MNIST/MNIST/raw/train-images-idx3-ubyte.gz b/res/MNIST/MNIST/raw/train-images-idx3-ubyte.gz
new file mode 100644
index 000000000..b50e4b6bc
Binary files /dev/null and b/res/MNIST/MNIST/raw/train-images-idx3-ubyte.gz differ
diff --git a/res/MNIST/MNIST/raw/train-labels-idx1-ubyte b/res/MNIST/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 000000000..d6b4c5db3
Binary files /dev/null and b/res/MNIST/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/res/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz b/res/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 000000000..707a576bb
Binary files /dev/null and b/res/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/res/cnn.eqx b/res/cnn.eqx
new file mode 100644
index 000000000..b30d54fbe
Binary files /dev/null and b/res/cnn.eqx differ
diff --git a/sparse_autoencoder b/sparse_autoencoder
new file mode 120000
index 000000000..f7ffeddf4
--- /dev/null
+++ b/sparse_autoencoder
@@ -0,0 +1 @@
+./src
\ No newline at end of file
diff --git a/src/__main__.py b/src/__main__.py
index 9789cc129..0eef5d17e 100644
--- a/src/__main__.py
+++ b/src/__main__.py
@@ -5,3 +5,44 @@
 
 """
 
+import argtoml
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import mnist
+import optax  # https://github.com/deepmind/optax
+from jaxtyping import Float  # https://github.com/google/jaxtyping
+from jaxtyping import Array, Int, PyTree
+from torch.utils.data import DataLoader
+
+from .cnn import CNN, train_cnn
+from .sae import SAE, train_sae
+
+# Hyperparameters
+O = argtoml.parse_args()
+key = jax.random.PRNGKey(O.seed)
+key, subkey = jax.random.split(key)
+
+if O.cnn_storage.exists():
+    model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+else:
+    model = train_cnn(
+        subkey,
+        O.batch_size,
+        O.learning_rate,
+        O.steps,
+        O.print_every,
+        O.cnn_storage,
+    )
+
+sae = train_sae(
+    key,
+    model,
+    lambda m: m.layers[4],
+    512,
+    O.batch_size,
+    O.learning_rate,
+    O.steps,
+    O.print_every,
+    O.sae_storage,
+)
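Note on src/__main__.py above: O = argtoml.parse_args() is assumed to read config.toml (merged with any command-line overrides) and expose its keys as attributes, with path-valued entries such as cnn_storage turned into pathlib.Path objects, otherwise O.cnn_storage.exists() would fail. A rough standard-library stand-in for that access pattern (hypothetical, not part of this diff; argtoml's real behaviour may differ):

    import tomllib  # Python 3.11+
    from pathlib import Path
    from types import SimpleNamespace

    with open("config.toml", "rb") as f:
        raw = tomllib.load(f)
    # treat "./..." strings as paths, mirroring how __main__.py uses O.cnn_storage
    O = SimpleNamespace(**{
        k: Path(v) if isinstance(v, str) and v.startswith("./") else v
        for k, v in raw.items()
    })
    assert isinstance(O.cnn_storage, Path) and O.seed == 5678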
diff --git a/src/cnn.py b/src/cnn.py
new file mode 100644
index 000000000..11d272d74
--- /dev/null
+++ b/src/cnn.py
@@ -0,0 +1,154 @@
+#! /usr/bin/env python3
+# vim:fenc=utf-8
+
+"""
+Train a small CNN classifier on MNIST with Equinox.
+"""
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import mnist
+import optax  # https://github.com/deepmind/optax
+from jaxtyping import Float  # https://github.com/google/jaxtyping
+from jaxtyping import Array, Int, PyTree
+from torch.utils.data import DataLoader
+
+
+class CNN(eqx.Module):
+    layers: list
+
+    def __init__(self, key):
+        key1, key2, key3, key4 = jax.random.split(key, 4)
+        # Standard CNN setup: convolutional layer, followed by flattening,
+        # with a small MLP on top.
+        self.layers = [
+            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
+            eqx.nn.MaxPool2d(kernel_size=2),
+            jax.nn.relu,
+            jnp.ravel,
+            eqx.nn.Linear(1728, 512, key=key2),
+            jax.nn.sigmoid,
+            eqx.nn.Linear(512, 64, key=key3),
+            jax.nn.relu,
+            eqx.nn.Linear(64, 10, key=key4),
+            jax.nn.log_softmax,
+        ]
+
+    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+@eqx.filter_jit
+def loss(
+    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
+) -> Float[Array, ""]:
+    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operates on
+    # a single input image of shape (1, 28, 28).
+    #
+    # Therefore, we have to use jax.vmap, which in this case maps our model over the
+    # leading (batch) axis.
+    pred_y = jax.vmap(model)(x)
+    return cross_entropy(y, pred_y)
+
+
+def cross_entropy(
+    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
+) -> Float[Array, ""]:
+    # y are the true targets, and should be integers 0-9.
+    # pred_y are the log-softmax'd predictions.
+    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
+    return -jnp.mean(pred_y)
+
+
+@eqx.filter_jit
+def compute_accuracy(
+    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
+) -> Float[Array, ""]:
+    """This function takes as input the current model
+    and computes the average accuracy on a batch.
+    """
+    pred_y = jax.vmap(model)(x)
+    pred_y = jnp.argmax(pred_y, axis=1)
+    return jnp.mean(y == pred_y)
+
+
+def evaluate(model: CNN, testloader: DataLoader):
+    """This function evaluates the model on the test dataset,
+    computing both the average loss and the average accuracy.
+    """
+    avg_loss = 0
+    avg_acc = 0
+    for x, y in testloader:
+        x = x.numpy()
+        y = y.numpy()
+        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
+        # and both have JIT wrappers, so this is fast.
+        avg_loss += loss(model, x, y)
+        avg_acc += compute_accuracy(model, x, y)
+    return avg_loss / len(testloader), avg_acc / len(testloader)
+
+
+def train_loop(
+    model: CNN,
+    trainloader: DataLoader,
+    testloader: DataLoader,
+    optim: optax.GradientTransformation,
+    steps: int,
+    print_every: int,
+) -> CNN:
+    # Just like earlier: It only makes sense to train the arrays in our model,
+    # so filter out everything else.
+    opt_state = optim.init(eqx.filter(model, eqx.is_array))
+
+    # Always wrap everything -- computing gradients, running the optimiser, updating
+    # the model -- into a single JIT region. This ensures things run as fast as
+    # possible.
+    @eqx.filter_jit
+    def make_step(
+        model: CNN,
+        opt_state: PyTree,
+        x: Float[Array, "batch 1 28 28"],
+        y: Int[Array, " batch"],
+    ):
+        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
+        updates, opt_state = optim.update(grads, opt_state, model)
+        model = eqx.apply_updates(model, updates)
+        return model, opt_state, loss_value
+
+    # Loop over our training dataset as many times as we need.
+    def infinite_trainloader():
+        while True:
+            yield from trainloader
+
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+        # PyTorch dataloaders give PyTorch tensors by default,
+        # so convert them to NumPy arrays.
+        x = x.numpy()
+        y = y.numpy()
+        model, opt_state, train_loss = make_step(model, opt_state, x, y)
+        if (step % print_every) == 0 or (step == steps - 1):
+            test_loss, test_accuracy = evaluate(model, testloader)
+            print(
+                f"{step=}, train_loss={train_loss.item()}, "
+                f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
+            )
+    return model
+
+
+def train_cnn(
+    key,
+    batch_size,
+    learning_rate,
+    steps,
+    print_every,
+    model_storage="./res/cnn.eqx",
+):
+    trainloader, testloader = mnist.load(batch_size=batch_size)
+    model = CNN(key)
+    optim = optax.adamw(learning_rate)
+    model = train_loop(model, trainloader, testloader, optim, steps, print_every)
+    eqx.tree_serialise_leaves(model_storage, model)
+    return model
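Note: mnist.load is provided by the private git+file:///mnt/nas/git/mnist dependency, so its implementation is not part of this diff. From its use in train_cnn and train_sae it is assumed to return a (trainloader, testloader) pair of PyTorch DataLoaders yielding (image, label) batches with images of shape (batch, 1, 28, 28). A torchvision-based sketch consistent with the res/MNIST files added above (hypothetical; the real helper may shuffle or normalise differently):

    import torchvision
    from torch.utils.data import DataLoader

    def load(batch_size):
        # torchvision stores the files under <root>/MNIST/raw, matching res/MNIST/MNIST/raw
        to_tensor = torchvision.transforms.ToTensor()
        train = torchvision.datasets.MNIST("./res/MNIST", train=True, download=True, transform=to_tensor)
        test = torchvision.datasets.MNIST("./res/MNIST", train=False, download=True, transform=to_tensor)
        return (
            DataLoader(train, batch_size=batch_size, shuffle=True),
            DataLoader(test, batch_size=batch_size, shuffle=False),
        )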
diff --git a/src/sae.py b/src/sae.py
new file mode 100644
index 000000000..d2b99303c
--- /dev/null
+++ b/src/sae.py
@@ -0,0 +1,136 @@
+#! /usr/bin/env python3
+# vim:fenc=utf-8
+
+"""
+Train a sparse autoencoder on hidden activations of the MNIST CNN.
+"""
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import mnist
+import optax
+from jaxtyping import Array, Float, Int, PyTree
+from jo3util.eqx import sow
+from torch.utils.data import DataLoader
+
+
+class SAE(eqx.Module):
+    we: Float
+    wd: Float
+    be: Float
+    bd: Float
+
+    def __init__(self, in_size, hidden_size, key=jax.random.PRNGKey(42)):
+        k0, k1, k2, k3 = jax.random.split(key, 4)
+
+        # encoder weight matrix
+        self.we = jax.random.uniform(k0, (in_size, hidden_size))
+        # decoder weight matrix
+        self.wd = jax.random.uniform(k1, (hidden_size, in_size))
+        # encoder bias
+        self.be = jax.random.uniform(k2, (hidden_size,))
+        # decoder bias
+        self.bd = jax.random.uniform(k3, (in_size,))
+
+    def encode(self, x):
+        x = (x - self.bd) @ self.we + self.be
+        return jax.nn.relu(x)
+
+    def decode(self, fx):
+        return fx @ self.wd + self.bd
+
+    @staticmethod
+    @eqx.filter_value_and_grad
+    def loss(sae, x, λ):
+        fx = jax.vmap(sae.encode)(x)
+        x_ = jax.vmap(sae.decode)(fx)
+        # Per example: squared reconstruction error plus a λ-weighted L1
+        # sparsity penalty on the code, averaged over the batch.
+        sq_err = jnp.sum((x - x_) ** 2, axis=-1)
+        l1 = λ * jnp.sum(jnp.abs(fx), axis=-1)
+        return jnp.mean(sq_err + l1)
+
+
+def evaluate(sae: SAE, model: eqx.Module, testloader: DataLoader):
+    """Average SAE loss over the sown activations of the test set."""
+    avg_loss = 0
+    for x, _ in testloader:
+        _, activ = jax.vmap(model)(x.numpy())
+        loss_value, _ = SAE.loss(sae, activ[0], 1)
+        avg_loss += loss_value
+    return avg_loss / len(testloader)
+
+
+def train_loop(
+    sae: SAE,
+    model: eqx.Module,
+    trainloader: DataLoader,
+    testloader: DataLoader,
+    optim: optax.GradientTransformation,
+    steps: int,
+    print_every: int,
+) -> SAE:
+    # Just like earlier: It only makes sense to train the arrays in our SAE,
+    # so filter out everything else.
+    opt_state = optim.init(eqx.filter(sae, eqx.is_array))
+
+    # Always wrap everything -- computing gradients, running the optimiser, updating
+    # the model -- into a single JIT region. This ensures things run as fast as
+    # possible.
+    @eqx.filter_jit
+    def make_step(
+        sae: SAE,
+        model: eqx.Module,
+        opt_state: PyTree,
+        x: Float[Array, "batch 1 28 28"],
+        y: Int[Array, " batch"],
+    ):
+        _, activ = jax.vmap(model)(x)
+        loss_value, grads = SAE.loss(sae, activ[0], 1)
+        updates, opt_state = optim.update(grads, opt_state, sae)
+        sae = eqx.apply_updates(sae, updates)
+        return sae, opt_state, loss_value
+
+    # Loop over our training dataset as many times as we need.
+    def infinite_trainloader():
+        while True:
+            yield from trainloader
+
+    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+        # PyTorch dataloaders give PyTorch tensors by default,
+        # so convert them to NumPy arrays.
+        x = x.numpy()
+        y = y.numpy()
+        sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
+        if (step % print_every) == 0 or (step == steps - 1):
+            test_loss = evaluate(sae, model, testloader)
+            print(
+                f"{step=}, train_loss={train_loss.item()}, "
+                f"test_loss={test_loss.item()}"
+            )
+    return sae
+
+
+def train_sae(
+    key,
+    cnn,
+    sae_pos,
+    activ_size,
+    batch_size,
+    learning_rate,
+    steps,
+    print_every,
+    sae_storage="./res/sae.eqx",
+):
+    trainloader, testloader = mnist.load(batch_size=batch_size)
+    # Wrap the CNN so that calling it also returns the activations at sae_pos.
+    model = sow(sae_pos, cnn)
+    print(model)
+    sae = SAE(activ_size, 1000, key)
+    optim = optax.adamw(learning_rate)
+    sae = train_loop(
+        sae, model, trainloader, testloader, optim, steps, print_every
+    )
+    # eqx.tree_serialise_leaves(sae_storage, sae)
+    return sae
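Note on src/sae.py: sow from jo3util.eqx is another private dependency. The training step assumes the sown model, when called on one image, returns a pair (logits, activations) whose activations[0] is the output of cnn.layers[4], the 512-unit Linear layer selected in __main__.py; that is why make_step unpacks _, activ = jax.vmap(model)(x) and trains the SAE on activ[0]. The objective implemented by SAE.loss is the usual sparse-autoencoder loss: reconstruction error plus an L1 sparsity penalty on the code. Written per example (illustrative only, names are hypothetical):

    import jax.numpy as jnp

    def sae_objective(sae, x, lam=1.0):
        f = sae.encode(x)      # sparse code, shape (hidden_size,)
        x_hat = sae.decode(f)  # reconstruction, shape (in_size,)
        return jnp.sum((x - x_hat) ** 2) + lam * jnp.sum(jnp.abs(f))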