diff --git a/src/cnn.py b/src/cnn.py index 4e3dbe118..fa0bbfe1a 100644 --- a/src/cnn.py +++ b/src/cnn.py @@ -13,6 +13,7 @@ import optax # https://github.com/deepmind/optax from jaxtyping import Float # https://github.com/google/jaxtyping from jaxtyping import Array, Int, PyTree from torch.utils.data import DataLoader +from tqdm import tqdm class CNN(eqx.Module): @@ -123,18 +124,29 @@ def train_loop( while True: yield from trainloader - for step, (x, y) in zip(range(steps), infinite_trainloader()): + test_loss, test_accuracy = 0, 0 + progress_bar = tqdm( + zip(range(steps), infinite_trainloader()), + desc=f"training a cnn", + unit="step", + total=steps + ) + + for step, (x, y) in progress_bar: # PyTorch dataloaders give PyTorch tensors by default, # so convert them to NumPy arrays. x = x.numpy() y = y.numpy() model, opt_state, train_loss = make_step(model, opt_state, x, y) + if (step % print_every) == 0 or (step == steps - 1): test_loss, test_accuracy = evaluate(model, testloader) - print( - f"{step=}, train_loss={train_loss.item()}, " - f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}" - ) + + progress_bar.set_postfix({ + "trn-loss": f"{train_loss.item()}%", + "tst-loss": f"{test_loss}%", + "tst-acc": f"{test_accuracy}%" + }) return model