progress bar

This commit is contained in:
JJJHolscher 2024-11-15 12:21:29 +01:00
parent ff6775f2a1
commit 6c74e1d903

View File

@ -13,6 +13,7 @@ import optax # https://github.com/deepmind/optax
from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Float # https://github.com/google/jaxtyping
from jaxtyping import Array, Int, PyTree from jaxtyping import Array, Int, PyTree
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm
class CNN(eqx.Module): class CNN(eqx.Module):
@ -123,18 +124,29 @@ def train_loop(
while True: while True:
yield from trainloader yield from trainloader
for step, (x, y) in zip(range(steps), infinite_trainloader()): test_loss, test_accuracy = 0, 0
progress_bar = tqdm(
zip(range(steps), infinite_trainloader()),
desc=f"training a cnn",
unit="step",
total=steps
)
for step, (x, y) in progress_bar:
# PyTorch dataloaders give PyTorch tensors by default, # PyTorch dataloaders give PyTorch tensors by default,
# so convert them to NumPy arrays. # so convert them to NumPy arrays.
x = x.numpy() x = x.numpy()
y = y.numpy() y = y.numpy()
model, opt_state, train_loss = make_step(model, opt_state, x, y) model, opt_state, train_loss = make_step(model, opt_state, x, y)
if (step % print_every) == 0 or (step == steps - 1): if (step % print_every) == 0 or (step == steps - 1):
test_loss, test_accuracy = evaluate(model, testloader) test_loss, test_accuracy = evaluate(model, testloader)
print(
f"{step=}, train_loss={train_loss.item()}, " progress_bar.set_postfix({
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}" "trn-loss": f"{train_loss.item()}%",
) "tst-loss": f"{test_loss}%",
"tst-acc": f"{test_accuracy}%"
})
return model return model