change mnist to jo3mnist

JJJHolscher 2023-12-06 17:01:53 +01:00
parent 36de39f788
commit d8ce953d71
5 changed files with 14 additions and 8 deletions


@@ -3,7 +3,16 @@
name = "sparse_autoencoder"
version = "0.0.0" # TODO; automatically update versions by looking at git
description = ""
-dependencies = []
+dependencies = [
+"argtoml",
+"equinox",
+"jax",
+"optax",
+"torch",
+"torchvision",
+"tqdm",
+"jo3mnist"
+]
dynamic = ["readme"]
requires-python = ">=3.11"
classifiers = [

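After reinstalling the project with the updated dependency list, a minimal sanity check that the rename took effect could look like the sketch below. This is only an illustration: the diff does not show where the jo3mnist distribution is hosted, so it is assumed here that the project's requirements make it resolvable.

# Sketch: confirm the renamed dependency is importable after reinstalling.
import importlib.util

if importlib.util.find_spec("jo3mnist") is None:
    raise SystemExit("jo3mnist is not installed; reinstall the project's dependencies")
print("jo3mnist resolved OK")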

@@ -19,7 +19,6 @@ tqdm
twine
typeguard
git+file:///mnt/nas/git/jo3util
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118


@@ -9,8 +9,6 @@ import argtoml
import equinox as eqx
import jax
import jax.numpy as jnp
-import mnist
import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader


@@ -8,7 +8,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
-import mnist
+import jo3mnist
import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
@@ -146,7 +146,7 @@ def train_cnn(
print_every,
model_storage="./res/cnn.eqx",
):
-trainloader, testloader = mnist.load(batch_size=batch_size)
+trainloader, testloader = jo3mnist.load(batch_size=batch_size)
model = CNN(key)
optim = optax.adamw(learning_rate)
model = train_loop(


@@ -10,7 +10,7 @@ from typing import Callable, List
import equinox as eqx
import jax
import jax.numpy as jnp
-import mnist
+import jo3mnist
import optax
from jaxtyping import Array, Float, Int, PyTree
from jo3util.eqx import insert_after, sow
@@ -138,7 +138,7 @@ def train_sae(
print_every,
sae_storage="./res/sae_layer_6.eqx",
):
-trainloader, testloader = mnist.load(batch_size=batch_size)
+trainloader, testloader = jo3mnist.load(batch_size=batch_size)
sae = SAE(activ_size, 1000, key)
optim = optax.adamw(learning_rate)
sae = train_loop(
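
For reference, a minimal sketch of how the renamed loader is consumed by train_cnn and train_sae above. It assumes jo3mnist.load keeps the call signature and return value that the old mnist.load had at these call sites, i.e. load(batch_size=...) returning a (trainloader, testloader) pair of torch-style DataLoaders; the jo3mnist package itself is not part of this diff.

# Sketch only: jo3mnist's API is inferred from the call sites in this commit.
import jax.numpy as jnp
import jo3mnist

trainloader, testloader = jo3mnist.load(batch_size=64)

for images, labels in trainloader:
    # Convert the torch batch to JAX arrays before feeding a model / optax step.
    x = jnp.asarray(images.numpy())
    y = jnp.asarray(labels.numpy())
    print(x.shape, y.shape)
    break  # one batch is enough for a smoke test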