tensorboard, low l1, lower lr

parent bc7647cb43
commit ecd7d3bd65

config.toml (44 lines changed)

@@ -1,47 +1,33 @@
 batch_size = 64
-steps = 500000
-print_every = 10000
+steps = 1000000
+print_every = 50000
 seed = 0
 cnn_storage = "./res/cnn.eqx"
 
 [[sae]]
 layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 0.1
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 1e-2
-
-[[sae]]
-layer = 6
-hidden_size = 1000
-input_size = 64
-learning_rate = 3e-3
-
-[[sae]]
-layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-3
+l1 = 3e-4 # from Neel Nanda's sae git
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 3e-4
+l1 = 3e-4
 
 [[sae]]
 layer = 6
-hidden_size = 1000
+hidden_size = 300
 input_size = 64
 learning_rate = 1e-4
+l1 = 3e-4
+
+[[sae]]
+layer = 6
+hidden_size = 300
+input_size = 64
+learning_rate = 3e-5
+l1 = 3e-4
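The surviving four [[sae]] blocks sweep the learning rate from 1e-3 down to 3e-5 at a fixed l1 weight of 3e-4, with the narrower hidden_size = 300. For reference, a minimal sketch of reading such an array-of-tables config (assuming Python 3.11+'s tomllib; the repo's own loader, which builds the O object used in the hunks below, is not part of this diff):

    import tomllib

    # Each [[sae]] block parses to one dict in config["sae"], i.e. one
    # hyperparameter combination per SAE training run.
    with open("config.toml", "rb") as f:
        config = tomllib.load(f)

    for sae_cfg in config["sae"]:
        print(sae_cfg["layer"], sae_cfg["hidden_size"],
              sae_cfg["learning_rate"], sae_cfg["l1"])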
@@ -10,6 +10,7 @@ import jax.numpy as jnp
 import jo3mnist
 from jo3util import eqx as jo3eqx
 from jo3util.root import run_dir
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from .cnn import CNN, train_cnn
@@ -49,6 +50,7 @@ for sae_hyperparams in O.sae:
         sae_hyperparams.learning_rate,
         O.steps,
         O.print_every,
+        SummaryWriter(sae_dir / "log")
     )
 
     sae_dir.mkdir()
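These two hunks construct one SummaryWriter per run directory and pass it into the training call. A minimal sketch of the logging pattern being introduced (the path is illustrative):

    from torch.utils.tensorboard import SummaryWriter

    # Scalars written here land under the run directory and are picked
    # up by: tensorboard --logdir res/
    writer = SummaryWriter("res/example_run/log")
    for step in range(100):
        writer.add_scalar("loss", 1.0 / (step + 1), step)  # tag, value, global_step
    writer.close()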
src/sae.py (36 lines changed)

@@ -5,6 +5,7 @@
 
 """
 
+from datetime import datetime
 from typing import Callable, List
 
 import equinox as eqx
@@ -15,6 +16,7 @@ import optax
 from jaxtyping import Array, Float, Int, PyTree
 from jo3util.eqx import insert_after, sow
 from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
 
 
 class SAE(eqx.Module):
@@ -87,10 +89,10 @@ def make_step(
     optim,
     opt_state: PyTree,
     x: Float[Array, "batch 1 28 28"],
-    y: Int[Array, " batch"],
+    λ: float,
 ):
     activ = jax.vmap(model)(x)[0]
-    loss_value, grads = SAE.loss(sae, activ[0], 1)
+    loss_value, grads = SAE.loss(sae, activ[0], λ)
     updates, opt_state = optim.update(grads, opt_state, sae)
     sae = eqx.apply_updates(sae, updates)
     return sae, opt_state, loss_value
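make_step now threads the sparsity weight λ into SAE.loss instead of hard-coding 1, and drops the labels y, which an autoencoder never needed. SAE.loss itself is not shown in this diff; a minimal sketch of the reconstruction-plus-L1 objective that the l1 config key and λ suggest (function and argument names here are assumptions):

    import jax.numpy as jnp

    def sae_l1_loss(encode, decode, activ, λ):
        # Reconstruction error plus a λ-weighted L1 penalty on the hidden
        # code: the standard sparse-autoencoder objective.
        hidden = encode(activ)   # (hidden_size,)
        recon = decode(hidden)   # (input_size,)
        mse = jnp.mean((recon - activ) ** 2)
        sparsity = jnp.sum(jnp.abs(hidden))
        return mse + λ * sparsity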
@@ -98,39 +100,43 @@ def make_step(
 
 def train_loop(
     sae: SAE,
-    model: eqx.Module,
+    cnn: eqx.Module,
     sae_pos: Callable,
     trainloader: DataLoader,
     testloader: DataLoader,
     optim: optax.GradientTransformation,
     steps: int,
     print_every: int,
+    tensorboard,
+    λ,
 ) -> eqx.Module:
     opt_state = optim.init(eqx.filter(sae, eqx.is_array))
 
-    model = sow(sae_pos, model)
+    cnn = sow(sae_pos, cnn)
 
-    print(f"test_accuracy={evaluate(model, testloader).item()}")
+    print(f"test_accuracy={evaluate(cnn, testloader).item()}")
 
 
     # Loop over our training dataset as many times as we need.
     def infinite_trainloader():
         while True:
             yield from trainloader
 
-    for step, (x, y) in zip(range(steps), infinite_trainloader()):
+    for step, (x, _) in zip(range(steps), infinite_trainloader()):
         # PyTorch dataloaders give PyTorch tensors by default,
         # so convert them to NumPy arrays.
-        x = x.numpy()
-        y = y.numpy()
-        sae, opt_state, train_loss = make_step(sae, model, optim, opt_state, x, y)
+        sae, opt_state, train_loss = make_step(
+            sae, cnn, optim, opt_state, x.numpy(), λ
+        )
+        tensorboard.add_scalar("loss", train_loss.item(), step)
         if (step % print_every) == 0 or (step == steps - 1):
-            model_with_sae = insert_after(sae_pos, model, sae)
-            test_accuracy = evaluate(model_with_sae, testloader)
+            cnn_with_sae = insert_after(sae_pos, cnn, sae)
+            test_accuracy = evaluate(cnn_with_sae, testloader)
             print(
+                datetime.now().strftime("%H:%M"),
                 f"{step=}, train_loss={train_loss.item()}, "
-                f"test_accuracy={test_accuracy.item()}"
+                f"test_accuracy={test_accuracy.item()}",
             )
+            tensorboard.add_scalar("accu", test_accuracy.item(), step)
     return sae
 
 
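With "loss" logged every step and "accu" at each print_every checkpoint, the curves can also be read back outside the TensorBoard UI; a small sketch using TensorBoard's event-file reader (path illustrative):

    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    acc = EventAccumulator("res/example_run/log")
    acc.Reload()  # parse the event files on disk
    for event in acc.Scalars("loss"):
        print(event.step, event.value)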
@@ -144,6 +150,8 @@ def train_sae(
     learning_rate,
     steps,
     print_every,
+    tensorboard,
+    λ,
 ):
     trainloader, testloader = jo3mnist.load(batch_size=batch_size)
     sae = SAE(activ_size, hidden_size, key)
@@ -157,4 +165,6 @@ def train_sae(
         optim,
         steps,
         print_every,
+        tensorboard,
+        λ
     )