big search for good learning rate
This commit is contained in:
parent
d8ce953d71
commit
bc7647cb43
50
config.toml
50
config.toml
|
@ -1,7 +1,47 @@
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
learning_rate = 3e-4
|
steps = 500000
|
||||||
steps = 300
|
print_every = 10000
|
||||||
print_every = 30
|
seed = 0
|
||||||
seed = 5678
|
|
||||||
cnn_storage = "./res/cnn.eqx"
|
cnn_storage = "./res/cnn.eqx"
|
||||||
sae_storage = "./res/sae.eqx"
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 0.1
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 3e-2
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 1e-2
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 3e-3
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 1e-3
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 3e-4
|
||||||
|
|
||||||
|
[[sae]]
|
||||||
|
layer = 6
|
||||||
|
hidden_size = 1000
|
||||||
|
input_size = 64
|
||||||
|
learning_rate = 1e-4
|
||||||
|
|
|
@ -1,25 +1,21 @@
|
||||||
argtoml
|
argtoml[save]
|
||||||
build
|
build
|
||||||
debugpy
|
debugpy
|
||||||
equinox
|
equinox
|
||||||
jax>=0.4.14
|
jax
|
||||||
jaxtyping
|
jaxtyping
|
||||||
|
jo3mnist
|
||||||
|
jo3util
|
||||||
matplotlib
|
matplotlib
|
||||||
nbclassic
|
|
||||||
notebook
|
|
||||||
optax
|
optax
|
||||||
pandas
|
pandas
|
||||||
pyright
|
pyright
|
||||||
scikit-learn
|
scikit-learn
|
||||||
tensorboard
|
tensorflow
|
||||||
tensorboardX
|
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
tqdm
|
tqdm
|
||||||
twine
|
twine
|
||||||
typeguard
|
typeguard
|
||||||
git+file:///mnt/nas/git/jo3util
|
|
||||||
git+https://github.com/JJJHolscher/jupytools
|
|
||||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
-f https://download.pytorch.org/whl/cu118
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
git+file:///mnt/nas/git/mnist
|
|
||||||
|
|
|
@ -1,30 +1,30 @@
|
||||||
#! /usr/bin/env python3
|
#! /usr/bin/env python3
|
||||||
# vim:fenc=utf-8
|
# vim:fenc=utf-8
|
||||||
|
|
||||||
"""
|
from pathlib import Path
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argtoml
|
import argtoml
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
import jo3mnist
|
||||||
from jaxtyping import Array, Int, PyTree
|
from jo3util import eqx as jo3eqx
|
||||||
from torch.utils.data import DataLoader
|
from jo3util.root import run_dir
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .cnn import CNN, train_cnn
|
from .cnn import CNN, train_cnn
|
||||||
from .sae import SAE, train_sae
|
from .sae import SAE, sample_features, train_sae
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
O = argtoml.parse_args()
|
O = argtoml.parse_args()
|
||||||
|
|
||||||
key = jax.random.PRNGKey(O.seed)
|
key = jax.random.PRNGKey(O.seed)
|
||||||
key, subkey = jax.random.split(key)
|
key, subkey = jax.random.split(key)
|
||||||
|
|
||||||
if O.cnn_storage.exists():
|
if (Path(".") / O.cnn_storage).exists():
|
||||||
model = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
|
cnn = eqx.tree_deserialise_leaves(O.cnn_storage, CNN(subkey))
|
||||||
else:
|
else:
|
||||||
model = train_cnn(
|
cnn = train_cnn(
|
||||||
subkey,
|
subkey,
|
||||||
O.batch_size,
|
O.batch_size,
|
||||||
O.learning_rate,
|
O.learning_rate,
|
||||||
|
@ -32,15 +32,59 @@ else:
|
||||||
O.print_every,
|
O.print_every,
|
||||||
O.cnn_storage,
|
O.cnn_storage,
|
||||||
)
|
)
|
||||||
|
eqx.tree_serialise_leaves(O.cnn_storage, cnn)
|
||||||
|
|
||||||
sae = train_sae(
|
for sae_hyperparams in O.sae:
|
||||||
key,
|
sae_dir = run_dir(sae_hyperparams)
|
||||||
model,
|
if sae_dir.exists():
|
||||||
lambda m: m.layers[6],
|
continue
|
||||||
64,
|
|
||||||
O.batch_size,
|
sae = train_sae(
|
||||||
O.learning_rate,
|
key,
|
||||||
O.steps,
|
cnn,
|
||||||
O.print_every,
|
lambda m: m.layers[sae_hyperparams.layer],
|
||||||
O.sae_storage,
|
sae_hyperparams.input_size,
|
||||||
)
|
sae_hyperparams.hidden_size,
|
||||||
|
O.batch_size,
|
||||||
|
sae_hyperparams.learning_rate,
|
||||||
|
O.steps,
|
||||||
|
O.print_every,
|
||||||
|
)
|
||||||
|
|
||||||
|
sae_dir.mkdir()
|
||||||
|
argtoml.save(sae_hyperparams, sae_dir / "sae-hyperparams.toml")
|
||||||
|
argtoml.save(O, sae_dir / "config.toml")
|
||||||
|
jo3eqx.save(
|
||||||
|
sae_dir / f"sae.eqx",
|
||||||
|
sae,
|
||||||
|
{
|
||||||
|
"in_size": sae_hyperparams.input_size,
|
||||||
|
"hidden_size": sae_hyperparams.hidden_size,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for sae_hyperparams in O.sae:
|
||||||
|
sae_dir = run_dir(sae_hyperparams)
|
||||||
|
sae = jo3eqx.load(sae_dir / f"sae.eqx", SAE)
|
||||||
|
sown_cnn = jo3eqx.sow(lambda m: m.layers[sae_hyperparams.layer], cnn)
|
||||||
|
trainloader, testloader = jo3mnist.load(
|
||||||
|
batch_size=O.batch_size, shuffle=False
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dir = sae_dir / f"train"
|
||||||
|
if not train_dir.exists():
|
||||||
|
print("saving features from the training set")
|
||||||
|
train_dir.mkdir()
|
||||||
|
for i, features in tqdm(
|
||||||
|
sample_features(sown_cnn, sae, trainloader), total=len(trainloader)
|
||||||
|
):
|
||||||
|
jnp.save(train_dir / f"{i}.npy", features, allow_pickle=False)
|
||||||
|
|
||||||
|
test_dir = sae_dir / f"test"
|
||||||
|
if not test_dir.exists():
|
||||||
|
print("saving features from the test set")
|
||||||
|
test_dir.mkdir()
|
||||||
|
for i, features in tqdm(
|
||||||
|
sample_features(sown_cnn, sae, testloader), total=len(testloader)
|
||||||
|
):
|
||||||
|
jnp.save(test_dir / f"{i}.npy", features, allow_pickle=False)
|
||||||
|
|
|
@ -152,5 +152,4 @@ def train_cnn(
|
||||||
model = train_loop(
|
model = train_loop(
|
||||||
model, trainloader, testloader, optim, steps, print_every
|
model, trainloader, testloader, optim, steps, print_every
|
||||||
)
|
)
|
||||||
eqx.tree_serialise_leaves(model_storage, model)
|
|
||||||
return model
|
return model
|
||||||
|
|
51
src/sae.py
51
src/sae.py
|
@ -58,6 +58,13 @@ class SAE(eqx.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def sample_features(cnn, sae, loader):
|
||||||
|
for i, (x, _) in enumerate(loader):
|
||||||
|
x = x.numpy()
|
||||||
|
activ = jax.vmap(cnn)(x)[0]
|
||||||
|
yield i, sae.encode(activ)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model: eqx.Module, testloader: DataLoader):
|
def evaluate(model: eqx.Module, testloader: DataLoader):
|
||||||
"""This function evaluates the model on the test dataset,
|
"""This function evaluates the model on the test dataset,
|
||||||
computing both the average loss and the average accuracy.
|
computing both the average loss and the average accuracy.
|
||||||
|
@ -68,11 +75,27 @@ def evaluate(model: eqx.Module, testloader: DataLoader):
|
||||||
y = y.numpy()
|
y = y.numpy()
|
||||||
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
|
# Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
|
||||||
# and both have JIT wrappers, so this is fast.
|
# and both have JIT wrappers, so this is fast.
|
||||||
pred_y = jnp.argmax(jax.vmap(model)(x)[0], axis=1)
|
pred_y = jnp.argmax(jax.vmap(model)(x)[-1], axis=1)
|
||||||
avg_acc += jnp.mean(y == pred_y)
|
avg_acc += jnp.mean(y == pred_y)
|
||||||
return avg_acc / len(testloader)
|
return avg_acc / len(testloader)
|
||||||
|
|
||||||
|
|
||||||
|
@eqx.filter_jit
|
||||||
|
def make_step(
|
||||||
|
sae: SAE,
|
||||||
|
model: eqx.Module,
|
||||||
|
optim,
|
||||||
|
opt_state: PyTree,
|
||||||
|
x: Float[Array, "batch 1 28 28"],
|
||||||
|
y: Int[Array, " batch"],
|
||||||
|
):
|
||||||
|
activ = jax.vmap(model)(x)[0]
|
||||||
|
loss_value, grads = SAE.loss(sae, activ[0], 1)
|
||||||
|
updates, opt_state = optim.update(grads, opt_state, sae)
|
||||||
|
sae = eqx.apply_updates(sae, updates)
|
||||||
|
return sae, opt_state, loss_value
|
||||||
|
|
||||||
|
|
||||||
def train_loop(
|
def train_loop(
|
||||||
sae: SAE,
|
sae: SAE,
|
||||||
model: eqx.Module,
|
model: eqx.Module,
|
||||||
|
@ -89,22 +112,6 @@ def train_loop(
|
||||||
|
|
||||||
print(f"test_accuracy={evaluate(model, testloader).item()}")
|
print(f"test_accuracy={evaluate(model, testloader).item()}")
|
||||||
|
|
||||||
# Always wrap everything -- computing gradients, running the optimiser, updating
|
|
||||||
# the model -- into a single JIT region. This ensures things run as fast as
|
|
||||||
# possible.
|
|
||||||
@eqx.filter_jit
|
|
||||||
def make_step(
|
|
||||||
sae: SAE,
|
|
||||||
model: eqx.Module,
|
|
||||||
opt_state: PyTree,
|
|
||||||
x: Float[Array, "batch 1 28 28"],
|
|
||||||
y: Int[Array, " batch"],
|
|
||||||
):
|
|
||||||
_, activ = jax.vmap(model)(x)
|
|
||||||
loss_value, grads = SAE.loss(sae, activ[0], 1)
|
|
||||||
updates, opt_state = optim.update(grads, opt_state, sae)
|
|
||||||
sae = eqx.apply_updates(sae, updates)
|
|
||||||
return sae, opt_state, loss_value
|
|
||||||
|
|
||||||
# Loop over our training dataset as many times as we need.
|
# Loop over our training dataset as many times as we need.
|
||||||
def infinite_trainloader():
|
def infinite_trainloader():
|
||||||
|
@ -116,7 +123,7 @@ def train_loop(
|
||||||
# so convert them to NumPy arrays.
|
# so convert them to NumPy arrays.
|
||||||
x = x.numpy()
|
x = x.numpy()
|
||||||
y = y.numpy()
|
y = y.numpy()
|
||||||
sae, opt_state, train_loss = make_step(sae, model, opt_state, x, y)
|
sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
|
||||||
if (step % print_every) == 0 or (step == steps - 1):
|
if (step % print_every) == 0 or (step == steps - 1):
|
||||||
model_with_sae = insert_after(sae_pos, model, sae)
|
model_with_sae = insert_after(sae_pos, model, sae)
|
||||||
test_accuracy = evaluate(model_with_sae, testloader)
|
test_accuracy = evaluate(model_with_sae, testloader)
|
||||||
|
@ -132,16 +139,16 @@ def train_sae(
|
||||||
cnn,
|
cnn,
|
||||||
sae_pos,
|
sae_pos,
|
||||||
activ_size,
|
activ_size,
|
||||||
|
hidden_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
steps,
|
steps,
|
||||||
print_every,
|
print_every,
|
||||||
sae_storage="./res/sae_layer_6.eqx",
|
|
||||||
):
|
):
|
||||||
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
trainloader, testloader = jo3mnist.load(batch_size=batch_size)
|
||||||
sae = SAE(activ_size, 1000, key)
|
sae = SAE(activ_size, hidden_size, key)
|
||||||
optim = optax.adamw(learning_rate)
|
optim = optax.adamw(learning_rate)
|
||||||
sae = train_loop(
|
return train_loop(
|
||||||
sae,
|
sae,
|
||||||
cnn,
|
cnn,
|
||||||
sae_pos,
|
sae_pos,
|
||||||
|
@ -151,5 +158,3 @@ def train_sae(
|
||||||
steps,
|
steps,
|
||||||
print_every,
|
print_every,
|
||||||
)
|
)
|
||||||
eqx.tree_serialise_leaves(sae_storage, sae)
|
|
||||||
return sae
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user