progress bar
This commit is contained in:
parent
ff6775f2a1
commit
6c74e1d903
22
src/cnn.py
22
src/cnn.py
|
@ -13,6 +13,7 @@ import optax # https://github.com/deepmind/optax
|
||||||
from jaxtyping import Float # https://github.com/google/jaxtyping
|
from jaxtyping import Float # https://github.com/google/jaxtyping
|
||||||
from jaxtyping import Array, Int, PyTree
|
from jaxtyping import Array, Int, PyTree
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class CNN(eqx.Module):
|
class CNN(eqx.Module):
|
||||||
|
@ -123,18 +124,29 @@ def train_loop(
|
||||||
while True:
|
while True:
|
||||||
yield from trainloader
|
yield from trainloader
|
||||||
|
|
||||||
for step, (x, y) in zip(range(steps), infinite_trainloader()):
|
test_loss, test_accuracy = 0, 0
|
||||||
|
progress_bar = tqdm(
|
||||||
|
zip(range(steps), infinite_trainloader()),
|
||||||
|
desc=f"training a cnn",
|
||||||
|
unit="step",
|
||||||
|
total=steps
|
||||||
|
)
|
||||||
|
|
||||||
|
for step, (x, y) in progress_bar:
|
||||||
# PyTorch dataloaders give PyTorch tensors by default,
|
# PyTorch dataloaders give PyTorch tensors by default,
|
||||||
# so convert them to NumPy arrays.
|
# so convert them to NumPy arrays.
|
||||||
x = x.numpy()
|
x = x.numpy()
|
||||||
y = y.numpy()
|
y = y.numpy()
|
||||||
model, opt_state, train_loss = make_step(model, opt_state, x, y)
|
model, opt_state, train_loss = make_step(model, opt_state, x, y)
|
||||||
|
|
||||||
if (step % print_every) == 0 or (step == steps - 1):
|
if (step % print_every) == 0 or (step == steps - 1):
|
||||||
test_loss, test_accuracy = evaluate(model, testloader)
|
test_loss, test_accuracy = evaluate(model, testloader)
|
||||||
print(
|
|
||||||
f"{step=}, train_loss={train_loss.item()}, "
|
progress_bar.set_postfix({
|
||||||
f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
|
"trn-loss": f"{train_loss.item()}%",
|
||||||
)
|
"tst-loss": f"{test_loss}%",
|
||||||
|
"tst-acc": f"{test_accuracy}%"
|
||||||
|
})
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user