change mnist to jo3mnist
This commit is contained in:
parent
36de39f788
commit
d8ce953d71
|
@ -3,7 +3,16 @@
|
||||||
name = "sparse_autoencoder"
|
name = "sparse_autoencoder"
|
||||||
version = "0.0.0" # TODO; automatically update versions by looking at git
|
version = "0.0.0" # TODO; automatically update versions by looking at git
|
||||||
description = ""
|
description = ""
|
||||||
dependencies = []
|
dependencies = [
|
||||||
|
"argtoml",
|
||||||
|
"equinox",
|
||||||
|
"jax",
|
||||||
|
"optax",
|
||||||
|
"torch",
|
||||||
|
"torchvision",
|
||||||
|
"tqdm",
|
||||||
|
"jo3mnist"
|
||||||
|
]
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
requires-python = ">=3.11"
|
requires-python = ">=3.11"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
|
|
|
@ -19,7 +19,6 @@ tqdm
|
||||||
twine
|
twine
|
||||||
typeguard
|
typeguard
|
||||||
git+file:///mnt/nas/git/jo3util
|
git+file:///mnt/nas/git/jo3util
|
||||||
git+https://github.com/kiyoon/jupynium.nvim@v0.2.1
|
|
||||||
git+https://github.com/JJJHolscher/jupytools
|
git+https://github.com/JJJHolscher/jupytools
|
||||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
-f https://download.pytorch.org/whl/cu118
|
-f https://download.pytorch.org/whl/cu118
|
||||||
|
|
|
@ -9,8 +9,6 @@ import argtoml
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import mnist
|
|
||||||
import optax # https://github.com/deepmind/optax
|
|
||||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||||
from jaxtyping import Array, Int, PyTree
|
from jaxtyping import Array, Int, PyTree
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import mnist
|
import jo3mnist
|
||||||
import optax # https://github.com/deepmind/optax
|
import optax # https://github.com/deepmind/optax
|
||||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||||
from jaxtyping import Array, Int, PyTree
|
from jaxtyping import Array, Int, PyTree
|
||||||
|
@ -146,7 +146,7 @@ def train_cnn(
|
||||||
print_every,
|
print_every,
|
||||||
model_storage="./res/cnn.eqx",
|
model_storage="./res/cnn.eqx",
|
||||||
):
|
):
|
||||||
trainloader, testloader = mnist.load(batch_size=batch_size)
|
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
||||||
model = CNN(key)
|
model = CNN(key)
|
||||||
optim = optax.adamw(learning_rate)
|
optim = optax.adamw(learning_rate)
|
||||||
model = train_loop(
|
model = train_loop(
|
||||||
|
|
|
@ -10,7 +10,7 @@ from typing import Callable, List
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import mnist
|
import jo3mnist
|
||||||
import optax
|
import optax
|
||||||
from jaxtyping import Array, Float, Int, PyTree
|
from jaxtyping import Array, Float, Int, PyTree
|
||||||
from jo3util.eqx import insert_after, sow
|
from jo3util.eqx import insert_after, sow
|
||||||
|
@ -138,7 +138,7 @@ def train_sae(
|
||||||
print_every,
|
print_every,
|
||||||
sae_storage="./res/sae_layer_6.eqx",
|
sae_storage="./res/sae_layer_6.eqx",
|
||||||
):
|
):
|
||||||
trainloader, testloader = mnist.load(batch_size=batch_size)
|
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
||||||
sae = SAE(activ_size, 1000, key)
|
sae = SAE(activ_size, 1000, key)
|
||||||
optim = optax.adamw(learning_rate)
|
optim = optax.adamw(learning_rate)
|
||||||
sae = train_loop(
|
sae = train_loop(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user