big search for a good learning rate

JJJHolscher 2023-12-13 20:42:30 +01:00
parent d8ce953d71
commit bc7647cb43
5 changed files with 145 additions and 61 deletions

View File

@@ -1,7 +1,47 @@
batch_size = 64
learning_rate = 3e-4
steps = 300
print_every = 30
seed = 5678
steps = 500000
print_every = 10000
seed = 0
cnn_storage = "./res/cnn.eqx"
sae_storage = "./res/sae.eqx"
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 0.1
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 3e-2
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 1e-2
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 3e-3
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 1e-3
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 3e-4
[[sae]]
layer = 6
hidden_size = 1000
input_size = 64
learning_rate = 1e-4
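
These [[sae]] blocks sweep seven learning rates, roughly log-spaced from 1e-1 down to 1e-4, for the same layer-6 SAE (hidden_size 1000, input_size 64). For illustration only, such an array of tables parses into a list of dicts; this sketch uses the standard-library tomllib and a hypothetical file name rather than the repo's argtoml:

import tomllib  # Python 3.11+

with open("config.toml", "rb") as f:   # hypothetical path for the file above
    config = tomllib.load(f)

for sae_cfg in config["sae"]:          # one entry per [[sae]] block, in file order
    print(sae_cfg["layer"], sae_cfg["hidden_size"], sae_cfg["learning_rate"])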

View File

@@ -1,25 +1,21 @@
argtoml
argtoml[save]
build
debugpy
equinox
jax>=0.4.14
jax
jaxtyping
jo3mnist
jo3util
matplotlib
nbclassic
notebook
optax
pandas
pyright
scikit-learn
tensorboard
tensorboardX
tensorflow
torch
torchvision
tqdm
twine
typeguard
git+file:///mnt/nas/git/jo3util
git+https://github.com/JJJHolscher/jupytools
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-f https://download.pytorch.org/whl/cu118
git+file:///mnt/nas/git/mnist
--extra-index-url https://download.pytorch.org/whl/cpu

View File

@@ -1,30 +1,30 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""
from pathlib import Path
import argtoml
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader
import jo3mnist
from jo3util import eqx as jo3eqx
from jo3util.root import run_dir
from tqdm import tqdm
from .cnn import CNN, train_cnn
from .sae import SAE, train_sae
from .sae import SAE, sample_features, train_sae
# Hyperparameters
O = argtoml.parse_args()
key = jax.random.PRNGKey(O.seed)
key, subkey = jax.random.split(key)
if O.cnn_storage.exists():
model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
if (Path(".") / O.cnn_storage).exists():
cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
else:
model = train_cnn(
cnn = train_cnn(
subkey,
O.batch_size,
O.learning_rate,
@@ -32,15 +32,59 @@ else:
O.print_every,
O.cnn_storage,
)
eqx.tree_serialise_leaves(O.cnn_storage, cnn)
for sae_hyperparams in O.sae:
sae_dir = run_dir(sae_hyperparams)
if sae_dir.exists():
continue
sae = train_sae(
key,
model,
lambda m: m.layers[6],
64,
cnn,
lambda m: m.layers[sae_hyperparams.layer],
sae_hyperparams.input_size,
sae_hyperparams.hidden_size,
O.batch_size,
O.learning_rate,
sae_hyperparams.learning_rate,
O.steps,
O.print_every,
O.sae_storage,
)
sae_dir.mkdir()
argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
argtoml.save(O, sae_dir / "config.toml")
jo3eqx.save(
sae_dir / f"sae.eqx",
sae,
{
"in_size": sae_hyperparams.input_size,
"hidden_size": sae_hyperparams.hidden_size,
},
)
for sae_hyperparams in O.sae:
sae_dir = run_dir(sae_hyperparams)
sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
trainloader, testloader = jo3mnist.load(
batch_size=O.batch_size, shuffle=False
)
train_dir = sae_dir / f"train"
if not train_dir.exists():
print("saving features from the training set")
train_dir.mkdir()
for i, features in tqdm(
sample_features(sown_cnn, sae, trainloader), total=len(trainloader)
):
jnp.save(train_dir / f"{i}.npy", features, allow_pickle=False)
test_dir = sae_dir / f"test"
if not test_dir.exists():
print("saving features from the test set")
test_dir.mkdir()
for i, features in tqdm(
sample_features(sown_cnn, sae, testloader), total=len(testloader)
):
jnp.save(test_dir / f"{i}.npy", features, allow_pickle=False)
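
The two loops above dump one .npy file per batch into train/ and test/ under each run directory. A small sketch (not part of the commit) of reading those dumps back in batch order; the numpy loading and the stacking axis are assumptions:

from pathlib import Path
import numpy as np

def load_split(split_dir):
    # files are named {batch_index}.npy by the loops above
    files = sorted(Path(split_dir).glob("*.npy"), key=lambda p: int(p.stem))
    return np.concatenate([np.load(f) for f in files], axis=0)

train_features = load_split(sae_dir / "train")  # sae_dir as returned by run_dir above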

View File

@@ -152,5 +152,4 @@ def train_cnn(
model = train_loop(
model, trainloader, testloader, optim, steps, print_every
)
eqx.tree_serialise_leaves(model_storage, model)
return model
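
With this change train_cnn only trains and returns the model; serialisation is now the caller's job (the training script above writes O.cnn_storage after training). For reference, the Equinox round-trip used there looks like this, with a placeholder key and path:

import equinox as eqx
import jax

model = CNN(jax.random.PRNGKey(0))                 # any template with the right structure
eqx.tree_serialise_leaves("./res/cnn.eqx", model)  # what the caller now does after training
model = eqx.tree_deserialise_leaves("./res/cnn.eqx", CNN(jax.random.PRNGKey(0)))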

View File

@@ -58,6 +58,13 @@ class SAE(eqx.Module):
return out
def sample_features(cnn, sae, loader):
for i, (x, _) in enumerate(loader):
x = x.numpy()
activ = jax.vmap(cnn)(x)[0]
yield i, sae.encode(activ)
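
The cnn passed in here is the sown model built in the training script, so calling it returns more than just logits. Judging from the indices used in this commit ([0] here and in make_step, [-1] in evaluate), the output is a tuple whose first element is the captured layer activation and whose last element is the ordinary CNN output; a small illustration, with sown_cnn and x standing in for the script's variables:

# Inferred from the indexing in this commit, not from jo3util documentation.
sown_cnn = jo3eqx.sow(lambda m: m.layers[6], cnn)  # as in the training script
out = jax.vmap(sown_cnn)(x)
activ = out[0]    # captured activation, what the SAE encodes
logits = out[-1]  # final output, what evaluate() takes the argmax of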
def evaluate(model: eqx.Module, testloader: DataLoader):
"""This function evaluates the model on the test dataset,
computing both the average loss and the average accuracy.
@@ -68,11 +75,27 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
y = y.numpy()
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
# and both have JIT wrappers, so this is fast.
pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
pred_y = jnp.argmax(jax.vmap(model)(x)[-1], axis=1)
avg_acc += jnp.mean(y == pred_y)
return avg_acc / len(testloader)
@eqx.filter_jit
def make_step(
sae: SAE,
model: eqx.Module,
optim,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, " batch"],
):
activ = jax.vmap(model)(x)[0]
loss_value, grads = SAE.loss(sae, activ[0], 1)
updates, opt_state = optim.update(grads, opt_state, sae)
sae = eqx.apply_updates(sae, updates)
return sae, opt_state, loss_value
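
SAE.loss itself is not part of this diff, but make_step unpacks (loss_value, grads) from it, so it presumably already bundles a value-and-grad transform. Purely as a sketch of that shape, assuming an encode/decode pair and an L1 sparsity penalty (only encode appears in this file; decode and the exact penalty are guesses):

import equinox as eqx
import jax.numpy as jnp

# Hypothetical stand-in for SAE.loss, showing only the expected call shape:
# value-and-grad over reconstruction error plus a sparsity penalty.
@eqx.filter_value_and_grad
def sae_loss(sae, activation, l1_coefficient):
    hidden = sae.encode(activation)        # encode is defined in this file
    reconstruction = sae.decode(hidden)    # decode is assumed
    mse = jnp.mean((reconstruction - activation) ** 2)
    return mse + l1_coefficient * jnp.mean(jnp.abs(hidden))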
def train_loop(
sae: SAE,
model: eqx.Module,
@@ -89,22 +112,6 @@ def train_loop(
print(f"test_accuracy={evaluate(model, testloader).item()}")
# Always wrap everything -- computing gradients, running the optimiser, updating
# the model -- into a single JIT region. This ensures things run as fast as
# possible.
@eqx.filter_jit
def make_step(
sae: SAE,
model: eqx.Module,
opt_state: PyTree,
x: Float[Array, "batch 1 28 28"],
y: Int[Array, " batch"],
):
_, activ = jax.vmap(model)(x)
loss_value, grads = SAE.loss(sae, activ[0], 1)
updates, opt_state = optim.update(grads, opt_state, sae)
sae = eqx.apply_updates(sae, updates)
return sae, opt_state, loss_value
# Loop over our training dataset as many times as we need.
def infinite_trainloader():
@@ -116,7 +123,7 @@
# so convert them to NumPy arrays.
x = x.numpy()
y = y.numpy()
sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
if (step % print_every) == 0 or (step == steps - 1):
model_with_sae = insert_after(sae_pos, model, sae)
test_accuracy = evaluate(model_with_sae, testloader)
@@ -132,16 +139,16 @@ def train_sae(
cnn,
sae_pos,
activ_size,
hidden_size,
batch_size,
learning_rate,
steps,
print_every,
sae_storage="./res/sae_layer_6.eqx",
):
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
sae = SAE(activ_size, 1000, key)
sae = SAE(activ_size, hidden_size, key)
optim = optax.adamw(learning_rate)
sae = train_loop(
return train_loop(
sae,
cnn,
sae_pos,
@@ -151,5 +158,3 @@ def train_sae(
steps,
print_every,
)
eqx.tree_serialise_leaves(sae_storage, sae)
return sae
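
train_sae now takes hidden_size and a per-SAE learning_rate and simply returns the trained SAE; saving happens in the training script. As a follow-up to the learning-rate sweep, the saved runs can be reloaded and compared; a sketch reusing only calls that appear in this commit (O, cnn, testloader and insert_after are assumed to be available as in the rest of the code):

# Compare the swept learning rates by splicing each trained SAE into the CNN
# and measuring test accuracy, mirroring the check inside train_loop.
for sae_hyperparams in O.sae:
    sae_dir = run_dir(sae_hyperparams)
    sae = jo3eqx.load(sae_dir / "sae.eqx", SAE)
    model_with_sae = insert_after(lambda m: m.layers[sae_hyperparams.layer], cnn, sae)
    print(sae_hyperparams.learning_rate, evaluate(model_with_sae, testloader).item())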