root file and some debug tools

This commit is contained in:
JJJHolscher 2023-09-28 15:04:29 +02:00
parent abbb64d8a1
commit 1609b3da89
4 changed files with 128 additions and 0 deletions

View File

@ -1,6 +1,7 @@
build
debugpy
ipython
jax
jupytext
matplotlib-backend-kitty
numpy

46
src/root.py Normal file
View File

@ -0,0 +1,46 @@
#! /usr/bin/env python3
import os
from pathlib import Path
from importlib.resources import files
import __main__
def root_dir():
if "FILE_NAME" in os.environ:
return Path(os.environ["FILE_NAME"]).parent()
if __main__.__package__:
return files(__main__.__package__)
# Use the folder of the main file as toml_dir.
if "__file__" in dir(__main__):
toml_dir = Path(__main__.__file__).parent
return toml_dir
# Find the path of the ipython notebook.
try:
import ipynbname
return ipynbname.path().parent
except IndexError:
pass
return Path(os.path.abspath("."))
def root_file():
if "FILE_NAME" in os.environ:
return Path(os.environ["FILE_NAME"])
if "__file__" in dir(__main__):
return Path(__main__.__file__)
try:
import ipynbname
return ipynbname.path()
except IndexError:
pass
raise NotImplementedError

35
src/timing.py Normal file
View File

@ -0,0 +1,35 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
from time import time
class ReportTime:
def __init__(self):
self.i = 0
self.last_t = time()
self.t = {}
def __call__(self, name=None):
t = time()
name = self.i if name is None else name
self.t[name] = [t - self.last_t]
self.i += 1
self.last_t = time()
def next(self):
self.i = 0
self.__call__ = self.call_next
def call_next(self, name=None):
t = time()
name = self.i if name is None else name
self.t[name].append(t - self.last_t)
self.i += 1
self.last_t = time()
def __repr__(self):
out = []
for i, t in self.t.items():
t = sum(t) / len(t)
out.append(f"{i}={t:.4f}")
return "\n".join(out)

46
src/vis.py Normal file
View File

@ -0,0 +1,46 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
import jax
import numpy as np
from PIL import Image
def to_rgb(x):
x = x / x.max() * 255
x = np.broadcast_to(x, (3, x.shape[1], x.shape[2]))
return x.astype("uint8")
def to_img(x, scale_up=10):
x = to_rgb(x)
img = np.empty(
(x.shape[1] * scale_up, x.shape[2] * scale_up, x.shape[0]), dtype=np.uint8
)
for color in range(x.shape[0]):
for row in range(x.shape[1]):
for col in range(x.shape[2]):
pixel = x[color, row, col]
new_row = row * scale_up
new_col = col * scale_up
img[
new_row : new_row + scale_up,
new_col : new_col + scale_up,
color,
] = pixel
if x.shape[0] == 1:
return Image.fromarray(img.squeeze())
else:
return Image.fromarray(img, "RGB")
def color_mask(x, mask, color=np.array([255, 0, 0])):
if len(mask.shape) == 2:
mask = np.broadcast_to(mask, (3, mask.shape[0], mask.shape[1]))
x = to_rgb(x)
x = x * (1 - mask)
color = jax.vmap(lambda x, y: x * y, (0, 0))(color, mask)
x = x + color
return x