diff --git a/pyproject.toml b/pyproject.toml index 020b84ce5..e393ec4f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,16 @@ name = "sparse_autoencoder" version = "0.0.0" # TODO; automatically update versions by looking at git description = "" -dependencies = [] +dependencies = [ + "argtoml", + "equinox", + "jax", + "optax", + "torch", + "torchvision", + "tqdm", + "jo3mnist" +] dynamic = ["readme"] requires-python = ">=3.11" classifiers = [ diff --git a/requirements.txt b/requirements.txt index 9b90410c3..400331ae0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,6 @@ tqdm twine typeguard git+file:///mnt/nas/git/jo3util -git+https://github.com/kiyoon/jupynium.nvim@v0.2.1 git+https://github.com/JJJHolscher/jupytools -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -f https://download.pytorch.org/whl/cu118 diff --git a/src/__main__.py b/src/__main__.py index e800cb448..c38859731 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -9,8 +9,6 @@ import argtoml import equinox as eqx import jax import jax.numpy as jnp -import mnist -import optax # https://github.com/deepmind/optax from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Array, Int, PyTree from torch.utils.data import DataLoader diff --git a/src/cnn.py b/src/cnn.py index 4aa389cd5..fdf02d0e2 100644 --- a/src/cnn.py +++ b/src/cnn.py @@ -8,7 +8,7 @@ import equinox as eqx import jax import jax.numpy as jnp -import mnist +import jo3mnist import optax # https://github.com/deepmind/optax from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Array, Int, PyTree @@ -146,7 +146,7 @@ def train_cnn( print_every, model_storage="./res/cnn.eqx", ): - trainloader, testloader = mnist.load(batch_size=batch_size) + trainloader, testloader = jo3mnist.load(batch_size=batch_size) model = CNN(key) optim = optax.adamw(learning_rate) model = train_loop( diff --git a/src/sae.py b/src/sae.py index d10a3e5fd..a336d26cc 100644 --- a/src/sae.py +++ b/src/sae.py @@ -10,7 +10,7 @@ from typing import Callable, List import equinox as eqx import jax import jax.numpy as jnp -import mnist +import jo3mnist import optax from jaxtyping import Array, Float, Int, PyTree from jo3util.eqx import insert_after, sow @@ -138,7 +138,7 @@ def train_sae( print_every, sae_storage="./res/sae_layer_6.eqx", ): - trainloader, testloader = mnist.load(batch_size=batch_size) + trainloader, testloader = jo3mnist.load(batch_size=batch_size) sae = SAE(activ_size, 1000, key) optim = optax.adamw(learning_rate) sae = train_loop(