big search for good learning rate

JJJHolscher 2023-12-13 20:42:30 +01:00
parent d8ce953d71
commit bc7647cb43
5 changed files with 145 additions and 61 deletions

View File

@@ -1,7 +1,47 @@
 batch_size = 64
-learning_rate = 3e-4
-steps = 300
-print_every = 30
-seed = 5678
+steps = 500000
+print_every = 10000
+seed = 0
 cnn_storage = "./res/cnn.eqx"
-sae_storage = "./res/sae.eqx"
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 0.1
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-2
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-3
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 1000
+input_size = 64
+learning_rate = 1e-4
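
The seven [[sae]] tables above differ only in learning_rate, stepping down a roughly log-spaced 1-3 grid from 1e-1 to 1e-4 while layer, hidden_size and input_size stay fixed. A minimal sketch of the equivalent hyperparameter list in Python (illustration only, not code from this repository):

# Illustration: the learning-rate grid swept by the [[sae]] tables above.
sweep = [
    {"layer": 6, "hidden_size": 1000, "input_size": 64, "learning_rate": lr}
    for lr in (1e-1, 3e-2, 1e-2, 3e-3, 1e-3, 3e-4, 1e-4)
]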

View File

@@ -1,25 +1,21 @@
-argtoml
+argtoml[save]
 build
 debugpy
 equinox
-jax>=0.4.14
+jax
 jaxtyping
+jo3mnist
+jo3util
 matplotlib
-nbclassic
-notebook
 optax
 pandas
 pyright
 scikit-learn
-tensorboard
-tensorboardX
+tensorflow
 torch
 torchvision
 tqdm
 twine
 typeguard
-git+file:///mnt/nas/git/jo3util
-git+https://github.com/JJJHolscher/jupytools
 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--f https://download.pytorch.org/whl/cu118
-git+file:///mnt/nas/git/mnist
+--extra-index-url https://download.pytorch.org/whl/cpu

View File

@@ -1,30 +1,30 @@
 #! /usr/bin/env python3
 # vim:fenc=utf-8

-"""
-"""
+from pathlib import Path

 import argtoml
 import equinox as eqx
 import jax
 import jax.numpy as jnp
-from jaxtyping import Float  # https://github.com/google/jaxtyping
-from jaxtyping import Array, Int, PyTree
-from torch.utils.data import DataLoader
+import jo3mnist
+from jo3util import eqx as jo3eqx
+from jo3util.root import run_dir
+from tqdm import tqdm

 from .cnn import CNN, train_cnn
-from .sae import SAE, train_sae
+from .sae import SAE, sample_features, train_sae

 # Hyperparameters
 O = argtoml.parse_args()

 key = jax.random.PRNGKey(O.seed)
 key, subkey = jax.random.split(key)

-if O.cnn_storage.exists():
-    model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
+if (Path(".") / O.cnn_storage).exists():
+    cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
 else:
-    model = train_cnn(
+    cnn = train_cnn(
         subkey,
         O.batch_size,
         O.learning_rate,
@@ -32,15 +32,59 @@ else:
         O.print_every,
         O.cnn_storage,
     )
+    eqx.tree_serialise_leaves(O.cnn_storage, cnn)

-sae = train_sae(
-    key,
-    model,
-    lambda m: m.layers[6],
-    64,
-    O.batch_size,
-    O.learning_rate,
-    O.steps,
-    O.print_every,
-    O.sae_storage,
-)
+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    if sae_dir.exists():
+        continue
+
+    sae = train_sae(
+        key,
+        cnn,
+        lambda m: m.layers[sae_hyperparams.layer],
+        sae_hyperparams.input_size,
+        sae_hyperparams.hidden_size,
+        O.batch_size,
+        sae_hyperparams.learning_rate,
+        O.steps,
+        O.print_every,
+    )
+    sae_dir.mkdir()
+    argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
+    argtoml.save(O, sae_dir / "config.toml")
+    jo3eqx.save(
+        sae_dir / f"sae.eqx",
+        sae,
+        {
+            "in_size": sae_hyperparams.input_size,
+            "hidden_size": sae_hyperparams.hidden_size,
+        },
+    )
+
+for sae_hyperparams in O.sae:
+    sae_dir = run_dir(sae_hyperparams)
+    sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
+    sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
+
+    trainloader, testloader = jo3mnist.load(
+        batch_size=O.batch_size, shuffle=False
+    )
+
+    train_dir = sae_dir / f"train"
+    if not train_dir.exists():
+        print("saving features from the training set")
+        train_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, trainloader), total=len(trainloader)
+        ):
+            jnp.save(train_dir / f"{i}.npy", features, allow_pickle=False)
+
+    test_dir = sae_dir / f"test"
+    if not test_dir.exists():
+        print("saving features from the test set")
+        test_dir.mkdir()
+        for i, features in tqdm(
+            sample_features(sown_cnn, sae, testloader), total=len(testloader)
+        ):
+            jnp.save(test_dir / f"{i}.npy", features, allow_pickle=False)
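
The second loop above writes the SAE features of every batch to <run dir>/train/<i>.npy and <run dir>/test/<i>.npy. A minimal sketch of reading those dumps back for later analysis; load_features is a hypothetical helper, and the batch-first layout of each file is an assumption, not something this diff guarantees:

# Sketch: stack the per-batch feature files written by the loop above.
# Assumes each <i>.npy holds one batch of SAE features with the batch axis first.
import numpy as np
from pathlib import Path

def load_features(split_dir: Path) -> np.ndarray:
    files = sorted(split_dir.glob("*.npy"), key=lambda p: int(p.stem))
    return np.concatenate([np.load(f) for f in files], axis=0)

# e.g. train_features = load_features(sae_dir / "train")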

View File

@@ -152,5 +152,4 @@ def train_cnn(
     model = train_loop(
         model, trainloader, testloader, optim, steps, print_every
     )
-    eqx.tree_serialise_leaves(model_storage, model)
     return model

View File

@@ -58,6 +58,13 @@ class SAE(eqx.Module):
         return out


+def sample_features(cnn, sae, loader):
+    for i, (x, _) in enumerate(loader):
+        x = x.numpy()
+        activ = jax.vmap(cnn)(x)[0]
+        yield i, sae.encode(activ)
+
+
 def evaluate(model: eqx.Module, testloader: DataLoader):
     """This function evaluates the model on the test dataset,
     computing both the average loss and the average accuracy.
@@ -68,11 +75,27 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
         y = y.numpy()
         # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
         # and both have JIT wrappers, so this is fast.
-        pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
+        pred_y = jnp.argmax(jax.vmap(model)(x)[-1], axis=1)
         avg_acc += jnp.mean(y == pred_y)
     return avg_acc / len(testloader)


+@eqx.filter_jit
+def make_step(
+    sae: SAE,
+    model: eqx.Module,
+    optim,
+    opt_state: PyTree,
+    x: Float[Array, "batch 1 28 28"],
+    y: Int[Array, " batch"],
+):
+    activ = jax.vmap(model)(x)[0]
+    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    updates, opt_state = optim.update(grads, opt_state, sae)
+    sae = eqx.apply_updates(sae, updates)
+    return sae, opt_state, loss_value
+
+
 def train_loop(
     sae: SAE,
     model: eqx.Module,
@@ -89,22 +112,6 @@ def train_loop(
     print(f"test_accuracy={evaluate(model, testloader).item()}")

-    # Always wrap everything -- computing gradients, running the optimiser, updating
-    # the model -- into a single JIT region. This ensures things run as fast as
-    # possible.
-    @eqx.filter_jit
-    def make_step(
-        sae: SAE,
-        model: eqx.Module,
-        opt_state: PyTree,
-        x: Float[Array, "batch 1 28 28"],
-        y: Int[Array, " batch"],
-    ):
-        _, activ = jax.vmap(model)(x)
-        loss_value, grads = SAE.loss(sae, activ[0], 1)
-        updates, opt_state = optim.update(grads, opt_state, sae)
-        sae = eqx.apply_updates(sae, updates)
-        return sae, opt_state, loss_value
-
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
@@ -116,7 +123,7 @@ def train_loop(
         # so convert them to NumPy arrays.
         x = x.numpy()
         y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
         if (step % print_every) == 0 or (step == steps - 1):
             model_with_sae = insert_after(sae_pos, model, sae)
             test_accuracy = evaluate(model_with_sae, testloader)
@@ -132,16 +139,16 @@ def train_sae(
     cnn,
     sae_pos,
     activ_size,
+    hidden_size,
     batch_size,
     learning_rate,
     steps,
     print_every,
-    sae_storage="./res/sae_layer_6.eqx",
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
-    sae = SAE(activ_size, 1000, key)
+    sae = SAE(activ_size, hidden_size, key)
     optim = optax.adamw(learning_rate)
-    sae = train_loop(
+    return train_loop(
         sae,
         cnn,
         sae_pos,
@@ -151,5 +158,3 @@ def train_sae(
         steps,
         print_every,
     )
-    eqx.tree_serialise_leaves(sae_storage, sae)
-    return sae
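
make_step and sample_features both lean on SAE.loss and SAE.encode, which this commit leaves untouched and which are therefore not shown. For orientation only, a generic sparse-autoencoder objective with the same call shape as SAE.loss(sae, activ, 1) could look like the sketch below; the decode method and the exact penalty are assumptions, not the repository's implementation:

import equinox as eqx
import jax.numpy as jnp

# Hypothetical stand-in for SAE.loss: reconstruction error plus an L1 sparsity
# penalty on the hidden code, returning (loss, grads) like the call in make_step.
@eqx.filter_value_and_grad
def sae_loss(sae, activ, l1_coef):
    hidden = sae.encode(activ)              # sparse code for the CNN activation
    recon = sae.decode(hidden)              # assumed decoder; not shown in this diff
    mse = jnp.mean((recon - activ) ** 2)    # reconstruction term
    l1 = jnp.mean(jnp.abs(hidden))          # sparsity term
    return mse + l1_coef * l1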