big search for good learning rate
parent d8ce953d71
commit bc7647cb43
config.toml (50 changed lines)
@@ -1,7 +1,47 @@
 batch_size = 64
 learning_rate = 3e-4
-steps = 300
-print_every = 30
-seed = 5678
+steps = 500000
+print_every = 10000
+seed = 0
 cnn_storage = "./res/cnn.eqx"
 sae_storage = "./res/sae.eqx"
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 0.1
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-4
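The seven [[sae]] tables above are the learning-rate sweep itself: when the config is parsed (the script loads it via argtoml.parse_args()), they become a list that the training loop iterates over, each entry keeping its own learning_rate while sharing layer, hidden_size and input_size. A minimal sketch of that structure, using the standard-library tomllib purely for illustration:

# Illustrative only: the project reads this file through argtoml.parse_args();
# tomllib is used here just to show the shape of the parsed [[sae]] tables.
import tomllib

with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

# cfg["sae"] is a list with one dict per [[sae]] table, i.e. one SAE run per
# learning rate: 0.1, 3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4.
for run in cfg["sae"]:
    print(run["layer"], run["input_size"], run["hidden_size"], run["learning_rate"])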
@@ -1,25 +1,21 @@
-argtoml
+argtoml[save]
build
debugpy
equinox
-jax>=0.4.14
+jax
jaxtyping
jo3mnist
jo3util
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
tensorflow
torch
torchvision
tqdm
twine
typeguard
git+file:///mnt/nas/git/jo3util
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118
git+file:///mnt/nas/git/mnist
--extra-index-url https://download.pytorch.org/whl/cpu
@@ -1,30 +1,30 @@
 #! /usr/bin/env python3
 # vim:fenc=utf-8

 """

 """
 from pathlib import Path

 import argtoml
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 from jaxtyping import Float  # https://github.com/google/jaxtyping
 from jaxtyping import Array, Int, PyTree
 from torch.utils.data import DataLoader
 import jo3mnist
 from jo3util import eqx as jo3eqx
 from jo3util.root import run_dir
 from tqdm import tqdm

 from .cnn import CNN, train_cnn
-from .sae import SAE, train_sae
+from .sae import SAE, sample_features, train_sae

 # Hyperparameters
 O = argtoml.parse_args()

 key = jax.random.PRNGKey(O.seed)
 key, subkey = jax.random.split(key)

-if O.cnn_storage.exists():
-    model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+if (Path(".") / O.cnn_storage).exists():
+    cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
 else:
-    model = train_cnn(
+    cnn = train_cnn(
         subkey,
         O.batch_size,
         O.learning_rate,
@@ -32,15 +32,59 @@ else:
         O.print_every,
         O.cnn_storage,
     )
+    eqx.tree_serialise_leaves(O.cnn_storage, cnn)

+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    if sae_dir.exists():
+        continue
+
     sae = train_sae(
         key,
-        model,
-        lambda m: m.layers[6],
-        64,
+        cnn,
+        lambda m: m.layers[sae_hyperparams.layer],
+        sae_hyperparams.input_size,
+        sae_hyperparams.hidden_size,
         O.batch_size,
-        O.learning_rate,
+        sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
         O.sae_storage,
     )

+    sae_dir.mkdir()
+    argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
+    argtoml.save(O, sae_dir / "config.toml")
+    jo3eqx.save(
+        sae_dir / f"sae.eqx",
+        sae,
+        {
+            "in_size": sae_hyperparams.input_size,
+            "hidden_size": sae_hyperparams.hidden_size,
+        },
+    )

+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
+    sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
+    trainloader, testloader = jo3mnist.load(
+        batch_size=O.batch_size, shuffle=False
+    )
+
+    train_dir = sae_dir / f"train"
+    if not train_dir.exists():
+        print("saving features from the training set")
+        train_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, trainloader), total=len(trainloader)
+        ):
+            jnp.save(train_dir / f"{i}.npy", features, allow_pickle=False)
+
+    test_dir = sae_dir / f"test"
+    if not test_dir.exists():
+        print("saving features from the test set")
+        test_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, testloader), total=len(testloader)
+        ):
+            jnp.save(test_dir / f"{i}.npy", features, allow_pickle=False)
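The two loops above write one .npy file per DataLoader batch into the run directory's train and test subfolders. A minimal sketch of reading those dumps back, assuming each file holds one batch of encoded features as produced by sample_features; the helper name and the concatenation into one array are assumptions for this sketch, not part of the repository:

# Illustrative only: reload the per-batch feature files written by the loops
# above.  The <run-dir>/train/<i>.npy layout comes from the diff; everything
# else here is just one convenient way of reading it.
from pathlib import Path
import numpy as np

def load_features(split_dir: Path) -> np.ndarray:
    # one file per DataLoader batch, named 0.npy, 1.npy, ...
    files = sorted(split_dir.glob("*.npy"), key=lambda p: int(p.stem))
    return np.concatenate([np.load(f) for f in files], axis=0)

# e.g. train_features = load_features(Path("<run-dir>") / "train")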
@@ -152,5 +152,4 @@ def train_cnn(
     model = train_loop(
         model, trainloader, testloader, optim, steps, print_every
     )
-    eqx.tree_serialise_leaves(model_storage, model)
     return model
src/sae.py (51 changed lines)
@@ -58,6 +58,13 @@ class SAE(eqx.Module):
         return out


+def sample_features(cnn, sae, loader):
+    for i, (x, _) in enumerate(loader):
+        x = x.numpy()
+        activ = jax.vmap(cnn)(x)[0]
+        yield i, sae.encode(activ)


 def evaluate(model: eqx.Module, testloader: DataLoader):
     """This function evaluates the model on the test dataset,
     computing both the average loss and the average accuracy.
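sample_features pairs each batch index with the SAE's encoding of the sown CNN activations, which is what later gets written to disk. For orientation only, here is a generic sparse-autoencoder encoder with the same shapes as config.toml (input_size=64, hidden_size=1000); this is not the repository's SAE class, just a sketch of what a batched sae.encode(activ) call is expected to return:

# Illustrative only: NOT the repository's SAE class.  A generic sparse
# autoencoder encoder with the shapes used in config.toml.
import equinox as eqx
import jax
import jax.numpy as jnp

class ToySAE(eqx.Module):
    enc: eqx.nn.Linear
    dec: eqx.nn.Linear

    def __init__(self, in_size, hidden_size, key):
        k1, k2 = jax.random.split(key)
        self.enc = eqx.nn.Linear(in_size, hidden_size, key=k1)
        self.dec = eqx.nn.Linear(hidden_size, in_size, key=k2)

    def encode(self, x):
        # x: (batch, in_size) activations -> (batch, hidden_size) feature codes
        return jax.nn.relu(jax.vmap(self.enc)(x))

sae = ToySAE(64, 1000, jax.random.PRNGKey(0))
features = sae.encode(jnp.ones((8, 64)))  # shape (8, 1000)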
@@ -68,11 +75,27 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
         y = y.numpy()
         # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
         # and both have JIT wrappers, so this is fast.
-        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        pred_y = jnp.argmax(jax.vmap(model)(x)[-1], axis=1)
         avg_acc += jnp.mean(y == pred_y)
     return avg_acc / len(testloader)


+@eqx.filter_jit
+def make_step(
+    sae: SAE,
+    model: eqx.Module,
+    optim,
+    opt_state: PyTree,
+    x: Float[Array, "batch 1 28 28"],
+    y: Int[Array, " batch"],
+):
+    activ = jax.vmap(model)(x)[0]
+    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    updates, opt_state = optim.update(grads, opt_state, sae)
+    sae = eqx.apply_updates(sae, updates)
+    return sae, opt_state, loss_value


 def train_loop(
     sae: SAE,
     model: eqx.Module,
@@ -89,22 +112,6 @@ def train_loop(
     print(f"test_accuracy={evaluate(model, testloader).item()}")

-    # Always wrap everything -- computing gradients, running the optimiser, updating
-    # the model -- into a single JIT region. This ensures things run as fast as
-    # possible.
-    @eqx.filter_jit
-    def make_step(
-        sae: SAE,
-        model: eqx.Module,
-        opt_state: PyTree,
-        x: Float[Array, "batch 1 28 28"],
-        y: Int[Array, " batch"],
-    ):
-        _, activ = jax.vmap(model)(x)
-        loss_value, grads = SAE.loss(sae, activ[0], 1)
-        updates, opt_state = optim.update(grads, opt_state, sae)
-        sae = eqx.apply_updates(sae, updates)
-        return sae, opt_state, loss_value

     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
@@ -116,7 +123,7 @@ def train_loop(
         # so convert them to NumPy arrays.
         x = x.numpy()
         y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
             model_with_sae = insert_after(sae_pos, model, sae)
             test_accuracy = evaluate(model_with_sae, testloader)
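Together with the previous two hunks, this moves make_step from a closure inside train_loop to a module-level function, so the optimiser is passed in explicitly instead of being captured from the enclosing scope. A self-contained sketch of that pattern, with a toy linear model and MSE loss standing in for the SAE and SAE.loss; only the shape of make_step mirrors the diff, everything else is illustrative:

# Illustrative only: the "pass the optimiser into the jitted step" pattern.
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

model = eqx.nn.Linear(4, 1, key=jax.random.PRNGKey(0))
optim = optax.adamw(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_value_and_grad
def loss(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

@eqx.filter_jit
def make_step(model, optim, opt_state, x, y):
    # optim arrives as an argument instead of being closed over, mirroring the
    # new module-level make_step above.
    loss_value, grads = loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

x = jnp.ones((8, 4))
y = jnp.zeros((8, 1))
model, opt_state, train_loss = make_step(model, optim, opt_state, x, y)

Because an optax optimiser is a pytree whose leaves are callables, eqx.filter_jit treats those leaves as static, so passing optim as an argument does not trigger recompilation as long as the same optimiser object is reused.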
@@ -132,16 +139,16 @@ def train_sae(
     cnn,
     sae_pos,
     activ_size,
+    hidden_size,
     batch_size,
     learning_rate,
     steps,
     print_every,
     sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
-    sae = SAE(activ_size, 1000, key)
+    sae = SAE(activ_size, hidden_size, key)
     optim = optax.adamw(learning_rate)
-    sae = train_loop(
+    return train_loop(
         sae,
         cnn,
         sae_pos,
@@ -151,5 +158,3 @@ def train_sae(
         steps,
         print_every,
     )
-    eqx.tree_serialise_leaves(sae_storage, sae)
-    return sae