change mnist to jo3mnist

This commit is contained in:
JJJHolscher 2023-12-06 17:01:53 +01:00
parent 36de39f788
commit d8ce953d71
5 changed files with 14 additions and 8 deletions

View File

@@ -3,7 +3,16 @@
name = "sparse_autoencoder" name = "sparse_autoencoder"
version = "0.0.0" # TODO; automatically update versions by looking at git version = "0.0.0" # TODO; automatically update versions by looking at git
description = "" description = ""
dependencies = [] dependencies = [
"argtoml",
"equinox",
"jax",
"optax",
"torch",
"torchvision",
"tqdm",
"jo3mnist"
]
dynamic = ["readme"] dynamic = ["readme"]
requires-python = ">=3.11" requires-python = ">=3.11"
classifiers = [ classifiers = [

View File

@@ -19,7 +19,6 @@ tqdm
twine twine
typeguard typeguard
git+file:///mnt/nas/git/jo3util git+file:///mnt/nas/git/jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118 -f https://download.pytorch.org/whl/cu118

View File

@@ -9,8 +9,6 @@ import argtoml
import equinox as eqx import equinox as eqx
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import mnist
import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader from torch.utils.data import DataLoader

View File

@@ -8,7 +8,7 @@
import equinox as eqx import equinox as eqx
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import mnist import jo3mnist
import optax # https://github.com/deepmind/optax import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree from jaxtyping import Array, Int, PyTree
@@ -146,7 +146,7 @@ def train_cnn(
print_every, print_every,
model_storage="./res/cnn.eqx", model_storage="./res/cnn.eqx",
): ):
trainloader, testloader = mnist.load(batch_size=batch_size) trainloader, testloader = jo3mnist.load(batch_size=batch_size)
model = CNN(key) model = CNN(key)
optim = optax.adamw(learning_rate) optim = optax.adamw(learning_rate)
model = train_loop( model = train_loop(

View File

@@ -10,7 +10,7 @@ from typing import Callable, List
import equinox as eqx import equinox as eqx
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import mnist import jo3mnist
import optax import optax
from jaxtyping import Array, Float, Int, PyTree from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import insert_after, sow from jo3util.eqx import insert_after, sow
@@ -138,7 +138,7 @@ def train_sae(
print_every, print_every,
sae_storage="./res/sae_layer_6.eqx", sae_storage="./res/sae_layer_6.eqx",
): ):
trainloader, testloader = mnist.load(batch_size=batch_size) trainloader, testloader = jo3mnist.load(batch_size=batch_size)
sae = SAE(activ_size, 1000, key) sae = SAE(activ_size, 1000, key)
optim = optax.adamw(learning_rate) optim = optax.adamw(learning_rate)
sae = train_loop( sae = train_loop(