tensorboard, low l1, lower lr
parent bc7647cb43
commit ecd7d3bd65

config.toml | 44

@@ -1,47 +1,33 @@
 batch_size = 64
-steps = 500000
-print_every = 10000
+steps = 1000000
+print_every = 50000
 seed = 0
 cnn_storage = "./res/cnn.eqx"
 
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 0.1
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 1e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-3
-
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-3
+l1 = 3e-4 # from Neel Nanda's sae git
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 3e-4
+l1 = 3e-4
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-4
+l1 = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 300
+input_size = 64
+learning_rate = 3e-5
+l1 = 3e-4
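
Each [[sae]] block is one entry in a TOML array of tables, so a single run now sweeps four learning rates (1e-3 down to 3e-5) at a fixed hidden_size = 300 and l1 = 3e-4. A minimal sketch of consuming such a file, assuming Python 3.11's tomllib rather than whatever loader builds the O options object used by the driver below:

import tomllib

# Hedged sketch: iterate the [[sae]] array of tables from config.toml.
# The real driver exposes these as O.steps / O.sae; plain dicts shown here.
with open("config.toml", "rb") as f:
    cfg = tomllib.load(f)

for sae_hyperparams in cfg["sae"]:
    # each block carries its own learning_rate and l1 (the λ used below)
    print(sae_hyperparams["learning_rate"], sae_hyperparams["l1"])
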
@@ -10,6 +10,7 @@ import jax.numpy as jnp
 import jo3mnist
 from jo3util import eqx as jo3eqx
 from jo3util.root import run_dir
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from .cnn import CNN, train_cnn

@@ -49,6 +50,7 @@ for sae_hyperparams in O.sae:
         sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
+        SummaryWriter(sae_dir / "log")
     )
 
     sae_dir.mkdir()
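
The two hunks above are evidently from the training entry point that consumes config.toml: they import SummaryWriter and hand each SAE run its own writer rooted at sae_dir / "log", so the four learning rates log to separate directories. A minimal sketch of the writer calls being relied on; the path is a made-up stand-in:

from torch.utils.tensorboard import SummaryWriter

# Hedged sketch of per-run logging; "runs/lr1e-3/log" stands in for the
# sae_dir / "log" derived via run_dir for each [[sae]] block.
writer = SummaryWriter("runs/lr1e-3/log")
for step in range(3):
    writer.add_scalar("loss", 1.0 / (step + 1), step)  # tag, value, global step
writer.close()

Pointing tensorboard --logdir at the common parent directory then overlays the curves of all four runs.
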
src/sae.py | 36

@@ -5,6 +5,7 @@
 
 """
 
+from datetime import datetime
 from typing import Callable, List
 
 import equinox as eqx
@@ -15,6 +16,7 @@ import optax
 from jaxtyping import Array, Float, Int, PyTree
 from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
 
 
 class SAE(eqx.Module):
@@ -87,10 +89,10 @@ def make_step(
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
-    y: Int[Array, " batch"],
+    λ: float,
 ):
     activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    loss_value, grads = SAE.loss(sae, activ[0], λ)
     updates, opt_state = optim.update(grads, opt_state, sae)
     sae = eqx.apply_updates(sae, updates)
     return sae, opt_state, loss_value
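
make_step now takes the sparsity coefficient λ in place of the labels, which an autoencoder never needed, and forwards it to SAE.loss instead of the hard-coded 1. The loss body sits outside this hunk; the sketch below shows the usual SAE objective consistent with the (loss_value, grads) call shape, with encode/decode as assumed method names not taken from this diff:

import equinox as eqx
import jax
import jax.numpy as jnp

# Hedged sketch: reconstruction error plus λ-weighted L1 on the hidden code.
# eqx.filter_value_and_grad returns (loss, grads w.r.t. sae), matching the
# loss_value, grads = SAE.loss(sae, activ, λ) call in make_step.
@eqx.filter_value_and_grad
def sae_loss(sae, activ, λ):
    hidden = jax.vmap(sae.encode)(activ)         # batched hidden code
    recon = jax.vmap(sae.decode)(hidden)         # batched reconstruction
    mse = jnp.mean((recon - activ) ** 2)         # reconstruction term
    l1 = jnp.mean(jnp.sum(jnp.abs(hidden), -1))  # sparsity term
    return mse + λ * l1
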
@@ -98,39 +100,43 @@ def make_step(
 
 def train_loop(
     sae: SAE,
-    model: eqx.Module,
+    cnn: eqx.Module,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
+    tensorboard,
+    λ,
 ) -> eqx.Module:
     opt_state = optim.init(eqx.filter(sae, eqx.is_array))
 
-    model = sow(sae_pos, model)
+    cnn = sow(sae_pos, cnn)
 
-    print(f"test_accuracy={evaluate(model, testloader).item()}")
+    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
 
-
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+    for step, (x, _) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        x = x.numpy()
-        y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(
+            sae, cnn, optim, opt_state, x.numpy(), λ
+        )
+        tensorboard.add_scalar("loss", train_loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            model_with_sae = insert_after(sae_pos, model, sae)
-            test_accuracy = evaluate(model_with_sae, testloader)
+            cnn_with_sae = insert_after(sae_pos, cnn, sae)
+            test_accuracy = evaluate(cnn_with_sae, testloader)
             print(
+                datetime.now().strftime("%H:%M"),
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}",
             )
+            tensorboard.add_scalar("accu", test_accuracy.item(), step)
     return sae
 
 
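
train_loop drops the labels at unpacking time, converts the batch inline, and logs train_loss on every step instead of only every print_every steps. Its epoch-free iteration idiom, isolated below with a list standing in for the DataLoader:

# Sketch of the loop structure: restart the loader whenever it is exhausted,
# and cap the total number of batches with zip(range(steps), ...).
loader = [("batch-a", 0), ("batch-b", 1)]  # stand-in for a DataLoader

def infinite(loader):
    while True:
        yield from loader

for step, (x, _) in zip(range(5), infinite(loader)):
    print(step, x)  # cycles batch-a, batch-b, batch-a, ...
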
@@ -144,6 +150,8 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
+    tensorboard,
+    λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
     sae = SAE(activ_size, hidden_size, key)

@@ -157,4 +165,6 @@ def train_sae(
         optim,
         steps,
         print_every,
+        tensorboard,
+        λ
     )
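
These last two hunks finish the plumbing: l1 from config.toml becomes λ and travels driver → train_sae → train_loop → make_step → SAE.loss, with the SummaryWriter alongside. A hypothetical end-to-end call; train_sae here is a stub exposing only the tail of the real signature, since the leading parameters are not visible in this diff:

from torch.utils.tensorboard import SummaryWriter

def train_sae(*, learning_rate, steps, print_every, tensorboard, λ):
    ...  # real version builds the SAE and optimiser, then calls train_loop

train_sae(
    learning_rate=1e-3,
    steps=1_000_000,
    print_every=50_000,
    tensorboard=SummaryWriter("runs/example/log"),  # hypothetical path
    λ=3e-4,
)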