tensorboard, low l1, lower lr

JJJHolscher 2023-12-14 13:44:35 +01:00
parent bc7647cb43
commit ecd7d3bd65
3 changed files with 40 additions and 42 deletions

View File

@@ -1,47 +1,33 @@
 batch_size = 64
-steps = 500000
-print_every = 10000
+steps = 1000000
+print_every = 50000
 seed = 0
 cnn_storage = "./res/cnn.eqx"
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 0.1
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 1e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-3

 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-3
+l1 = 3e-4 # from Neel Nanda's sae git

 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 3e-4
+l1 = 3e-4

 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-4
+l1 = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 300
+input_size = 64
+learning_rate = 3e-5
+l1 = 3e-4
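
As context for the sweep above: `[[sae]]` is a TOML array of tables, so each block becomes one element of a list of hyperparameter sets. A minimal sketch of reading such a config with the standard library (the project's actual loader, which builds the `O` object used in the training script below, is not part of this diff, and the file path is hypothetical):

    import tomllib  # stdlib TOML parser, Python 3.11+

    with open("conf/sae.toml", "rb") as f:  # hypothetical path
        cfg = tomllib.load(f)

    # Each [[sae]] block is one dict in the list cfg["sae"].
    for sae_cfg in cfg["sae"]:
        print(sae_cfg["learning_rate"], sae_cfg["hidden_size"], sae_cfg["l1"])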

View File

@@ -10,6 +10,7 @@ import jax.numpy as jnp
 import jo3mnist
 from jo3util import eqx as jo3eqx
 from jo3util.root import run_dir
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 from .cnn import CNN, train_cnn
@@ -49,6 +50,7 @@ for sae_hyperparams in O.sae:
         sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
+        SummaryWriter(sae_dir / "log")
     )
     sae_dir.mkdir()
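
With a SummaryWriter logging into each run's `log` subdirectory, the curves can be watched live with the stock TensorBoard CLI pointed at a common parent of the run directories (the exact location of `sae_dir` comes from `jo3util.root.run_dir` and is not shown in this diff):

    tensorboard --logdir path/to/runs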

View File

@@ -5,6 +5,7 @@
 """

+from datetime import datetime
 from typing import Callable, List

 import equinox as eqx
@@ -15,6 +16,7 @@ import optax
 from jaxtyping import Array, Float, Int, PyTree
 from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter


 class SAE(eqx.Module):
@@ -87,10 +89,10 @@ def make_step(
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
-    y: Int[Array, " batch"],
+    λ: float,
 ):
     activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    loss_value, grads = SAE.loss(sae, activ[0], λ)
     updates, opt_state = optim.update(grads, opt_state, sae)
     sae = eqx.apply_updates(sae, updates)
     return sae, opt_state, loss_value
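
The hunk above replaces the hard-coded L1 coefficient of 1 with the configurable λ. `SAE.loss` itself is outside this diff; a minimal sketch of the standard form such a loss takes (reconstruction error plus λ-weighted L1 sparsity on the codes), with `encode`/`decode` as assumed method names:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    @eqx.filter_value_and_grad
    def sae_loss(sae, activ, λ):
        # Encode activations to a sparse code, then reconstruct them.
        code = jax.vmap(sae.encode)(activ)
        recon = jax.vmap(sae.decode)(code)
        mse = jnp.mean((recon - activ) ** 2)
        # The L1 penalty on the code drives most units toward zero.
        l1 = jnp.mean(jnp.sum(jnp.abs(code), axis=-1))
        return mse + λ * l1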
@@ -98,39 +100,43 @@ def make_step(


 def train_loop(
     sae: SAE,
-    model: eqx.Module,
+    cnn: eqx.Module,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
+    tensorboard,
+    λ,
 ) -> eqx.Module:
     opt_state = optim.init(eqx.filter(sae, eqx.is_array))
-    model = sow(sae_pos, model)
-    print(f"test_accuracy={evaluate(model, testloader).item()}")
+    cnn = sow(sae_pos, cnn)
+    print(f"test_accuracy={evaluate(cnn, testloader).item()}")

     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader

-    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+    for step, (x, _) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        x = x.numpy()
-        y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(
+            sae, cnn, optim, opt_state, x.numpy(), λ
+        )
+        tensorboard.add_scalar("loss", train_loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            model_with_sae = insert_after(sae_pos, model, sae)
-            test_accuracy = evaluate(model_with_sae, testloader)
+            cnn_with_sae = insert_after(sae_pos, cnn, sae)
+            test_accuracy = evaluate(cnn_with_sae, testloader)
             print(
+                datetime.now().strftime("%H:%M"),
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}",
             )
+            tensorboard.add_scalar("accu", test_accuracy.item(), step)
     return sae
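
A note on the evaluation in the loop above: the SAE never sees labels (hence `(x, _)`); `test_accuracy` is the downstream CNN's accuracy after the trained SAE is spliced into the network at `sae_pos`, so it measures how well classification survives the encode/decode round trip. Schematically, with the `jo3util.eqx` helpers as used above:

    cnn = sow(sae_pos, cnn)                          # tap activations at sae_pos
    cnn_with_sae = insert_after(sae_pos, cnn, sae)   # route them through the SAE
    accuracy = evaluate(cnn_with_sae, testloader)    # accuracy with SAE inserted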
@@ -144,6 +150,8 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
+    tensorboard,
+    λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
     sae = SAE(activ_size, hidden_size, key)
@@ -157,4 +165,6 @@ def train_sae(
         optim,
         steps,
         print_every,
+        tensorboard,
+        λ
     )
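
Taken together, the commit threads two new trailing arguments, a SummaryWriter and the L1 coefficient λ, from train_sae through train_loop into make_step. A hypothetical call matching the new signature (the caller shown in the training-script hunk above only adds the SummaryWriter argument, so wiring λ from the config's `l1` field is an assumption here, and the leading parameters are elided):

    sae = train_sae(
        ...,  # leading parameters (batch size, sizes, PRNG key, ...) elided
        learning_rate=sae_hyperparams.learning_rate,
        steps=O.steps,
        print_every=O.print_every,
        tensorboard=SummaryWriter(sae_dir / "log"),
        λ=sae_hyperparams.l1,  # assumed mapping from the new `l1` config field
    )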