progress bar
This commit is contained in:
parent
ff6775f2a1
commit
6c74e1d903
22
src/cnn.py
22
src/cnn.py
|
@ -13,6 +13,7 @@ import optax # https://github.com/deepmind/optax
|
|||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||
from jaxtyping import Array, Int, PyTree
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class CNN(eqx.Module):
|
||||
|
@ -123,18 +124,29 @@ def train_loop(
|
|||
while True:
|
||||
yield from trainloader
|
||||
|
||||
for step, (x, y) in zip(range(steps), infinite_trainloader()):
|
||||
test_loss, test_accuracy = 0, 0
|
||||
progress_bar = tqdm(
|
||||
zip(range(steps), infinite_trainloader()),
|
||||
desc=f"training a cnn",
|
||||
unit="step",
|
||||
total=steps
|
||||
)
|
||||
|
||||
for step, (x, y) in progress_bar:
|
||||
# PyTorch dataloaders give PyTorch tensors by default,
|
||||
# so convert them to NumPy arrays.
|
||||
x = x.numpy()
|
||||
y = y.numpy()
|
||||
model, opt_state, train_loss = make_step(model, opt_state, x, y)
|
||||
|
||||
if (step % print_every) == 0 or (step == steps - 1):
|
||||
test_loss, test_accuracy = evaluate(model, testloader)
|
||||
print(
|
||||
f"{step=}, train_loss={train_loss.item()}, "
|
||||
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
|
||||
)
|
||||
|
||||
progress_bar.set_postfix({
|
||||
"trn-loss": f"{train_loss.item()}%",
|
||||
"tst-loss": f"{test_loss}%",
|
||||
"tst-acc": f"{test_accuracy}%"
|
||||
})
|
||||
return model
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user