change mnist to jo3mnist
This commit is contained in:
parent
36de39f788
commit
d8ce953d71
|
@ -3,7 +3,16 @@
|
|||
name = "sparse_autoencoder"
|
||||
version = "0.0.0" # TODO; automatically update versions by looking at git
|
||||
description = ""
|
||||
dependencies = []
|
||||
dependencies = [
|
||||
"argtoml",
|
||||
"equinox",
|
||||
"jax",
|
||||
"optax",
|
||||
"torch",
|
||||
"torchvision",
|
||||
"tqdm",
|
||||
"jo3mnist"
|
||||
]
|
||||
dynamic = ["readme"]
|
||||
requires-python = ">=3.11"
|
||||
classifiers = [
|
||||
|
|
|
@ -19,7 +19,6 @@ tqdm
|
|||
twine
|
||||
typeguard
|
||||
git+file:///mnt/nas/git/jo3util
|
||||
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
|
||||
git+https://github.com/JJJHolscher/jupytools
|
||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
-f https://download.pytorch.org/whl/cu118
|
||||
|
|
|
@ -9,8 +9,6 @@ import argtoml
|
|||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mnist
|
||||
import optax # https://github.com/deepmind/optax
|
||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||
from jaxtyping import Array, Int, PyTree
|
||||
from torch.utils.data import DataLoader
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mnist
|
||||
import jo3mnist
|
||||
import optax # https://github.com/deepmind/optax
|
||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||
from jaxtyping import Array, Int, PyTree
|
||||
|
@ -146,7 +146,7 @@ def train_cnn(
|
|||
print_every,
|
||||
model_storage="./res/cnn.eqx",
|
||||
):
|
||||
trainloader, testloader = mnist.load(batch_size=batch_size)
|
||||
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
||||
model = CNN(key)
|
||||
optim = optax.adamw(learning_rate)
|
||||
model = train_loop(
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Callable, List
|
|||
import equinox as eqx
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mnist
|
||||
import jo3mnist
|
||||
import optax
|
||||
from jaxtyping import Array, Float, Int, PyTree
|
||||
from jo3util.eqx import insert_after, sow
|
||||
|
@ -138,7 +138,7 @@ def train_sae(
|
|||
print_every,
|
||||
sae_storage="./res/sae_layer_6.eqx",
|
||||
):
|
||||
trainloader, testloader = mnist.load(batch_size=batch_size)
|
||||
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
||||
sae = SAE(activ_size, 1000, key)
|
||||
optim = optax.adamw(learning_rate)
|
||||
sae = train_loop(
|
||||
|
|
Loading…
Reference in New Issue
Block a user