root file and some debug tools
This commit is contained in:
parent
abbb64d8a1
commit
1609b3da89
|
@ -1,6 +1,7 @@
|
||||||
build
|
build
|
||||||
debugpy
|
debugpy
|
||||||
ipython
|
ipython
|
||||||
|
jax
|
||||||
jupytext
|
jupytext
|
||||||
matplotlib-backend-kitty
|
matplotlib-backend-kitty
|
||||||
numpy
|
numpy
|
||||||
|
|
46
src/root.py
Normal file
46
src/root.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from importlib.resources import files
|
||||||
|
import __main__
|
||||||
|
|
||||||
|
|
||||||
|
def root_dir():
|
||||||
|
if "FILE_NAME" in os.environ:
|
||||||
|
return Path(os.environ["FILE_NAME"]).parent()
|
||||||
|
|
||||||
|
if __main__.__package__:
|
||||||
|
return files(__main__.__package__)
|
||||||
|
|
||||||
|
# Use the folder of the main file as toml_dir.
|
||||||
|
if "__file__" in dir(__main__):
|
||||||
|
toml_dir = Path(__main__.__file__).parent
|
||||||
|
return toml_dir
|
||||||
|
|
||||||
|
# Find the path of the ipython notebook.
|
||||||
|
try:
|
||||||
|
import ipynbname
|
||||||
|
|
||||||
|
return ipynbname.path().parent
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return Path(os.path.abspath("."))
|
||||||
|
|
||||||
|
|
||||||
|
def root_file():
|
||||||
|
if "FILE_NAME" in os.environ:
|
||||||
|
return Path(os.environ["FILE_NAME"])
|
||||||
|
|
||||||
|
if "__file__" in dir(__main__):
|
||||||
|
return Path(__main__.__file__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import ipynbname
|
||||||
|
|
||||||
|
return ipynbname.path()
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise NotImplementedError
|
35
src/timing.py
Normal file
35
src/timing.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
class ReportTime:
|
||||||
|
def __init__(self):
|
||||||
|
self.i = 0
|
||||||
|
self.last_t = time()
|
||||||
|
self.t = {}
|
||||||
|
|
||||||
|
def __call__(self, name=None):
|
||||||
|
t = time()
|
||||||
|
name = self.i if name is None else name
|
||||||
|
self.t[name] = [t - self.last_t]
|
||||||
|
self.i += 1
|
||||||
|
self.last_t = time()
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
self.i = 0
|
||||||
|
self.__call__ = self.call_next
|
||||||
|
|
||||||
|
def call_next(self, name=None):
|
||||||
|
t = time()
|
||||||
|
name = self.i if name is None else name
|
||||||
|
self.t[name].append(t - self.last_t)
|
||||||
|
self.i += 1
|
||||||
|
self.last_t = time()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
out = []
|
||||||
|
for i, t in self.t.items():
|
||||||
|
t = sum(t) / len(t)
|
||||||
|
out.append(f"{i}={t:.4f}")
|
||||||
|
return "\n".join(out)
|
||||||
|
|
46
src/vis.py
Normal file
46
src/vis.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
import jax
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def to_rgb(x):
|
||||||
|
x = x / x.max() * 255
|
||||||
|
x = np.broadcast_to(x, (3, x.shape[1], x.shape[2]))
|
||||||
|
return x.astype("uint8")
|
||||||
|
|
||||||
|
|
||||||
|
def to_img(x, scale_up=10):
|
||||||
|
x = to_rgb(x)
|
||||||
|
|
||||||
|
img = np.empty(
|
||||||
|
(x.shape[1] * scale_up, x.shape[2] * scale_up, x.shape[0]), dtype=np.uint8
|
||||||
|
)
|
||||||
|
|
||||||
|
for color in range(x.shape[0]):
|
||||||
|
for row in range(x.shape[1]):
|
||||||
|
for col in range(x.shape[2]):
|
||||||
|
pixel = x[color, row, col]
|
||||||
|
new_row = row * scale_up
|
||||||
|
new_col = col * scale_up
|
||||||
|
img[
|
||||||
|
new_row : new_row + scale_up,
|
||||||
|
new_col : new_col + scale_up,
|
||||||
|
color,
|
||||||
|
] = pixel
|
||||||
|
if x.shape[0] == 1:
|
||||||
|
return Image.fromarray(img.squeeze())
|
||||||
|
else:
|
||||||
|
return Image.fromarray(img, "RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def color_mask(x, mask, color=np.array([255, 0, 0])):
|
||||||
|
if len(mask.shape) == 2:
|
||||||
|
mask = np.broadcast_to(mask, (3, mask.shape[0], mask.shape[1]))
|
||||||
|
|
||||||
|
x = to_rgb(x)
|
||||||
|
x = x * (1 - mask)
|
||||||
|
color = jax.vmap(lambda x, y: x * y, (0, 0))(color, mask)
|
||||||
|
x = x + color
|
||||||
|
|
||||||
|
return x
|
Loading…
Reference in New Issue
Block a user