root file and some debug tools
This commit is contained in:
parent
abbb64d8a1
commit
1609b3da89
|
@ -1,6 +1,7 @@
|
|||
build
|
||||
debugpy
|
||||
ipython
|
||||
jax
|
||||
jupytext
|
||||
matplotlib-backend-kitty
|
||||
numpy
|
||||
|
|
46
src/root.py
Normal file
46
src/root.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from importlib.resources import files
|
||||
import __main__
|
||||
|
||||
|
||||
def root_dir():
|
||||
if "FILE_NAME" in os.environ:
|
||||
return Path(os.environ["FILE_NAME"]).parent()
|
||||
|
||||
if __main__.__package__:
|
||||
return files(__main__.__package__)
|
||||
|
||||
# Use the folder of the main file as toml_dir.
|
||||
if "__file__" in dir(__main__):
|
||||
toml_dir = Path(__main__.__file__).parent
|
||||
return toml_dir
|
||||
|
||||
# Find the path of the ipython notebook.
|
||||
try:
|
||||
import ipynbname
|
||||
|
||||
return ipynbname.path().parent
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
return Path(os.path.abspath("."))
|
||||
|
||||
|
||||
def root_file():
|
||||
if "FILE_NAME" in os.environ:
|
||||
return Path(os.environ["FILE_NAME"])
|
||||
|
||||
if "__file__" in dir(__main__):
|
||||
return Path(__main__.__file__)
|
||||
|
||||
try:
|
||||
import ipynbname
|
||||
|
||||
return ipynbname.path()
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
raise NotImplementedError
|
35
src/timing.py
Normal file
35
src/timing.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
#! /usr/bin/env python3
|
||||
# vim:fenc=utf-8
|
||||
from time import time
|
||||
|
||||
class ReportTime:
|
||||
def __init__(self):
|
||||
self.i = 0
|
||||
self.last_t = time()
|
||||
self.t = {}
|
||||
|
||||
def __call__(self, name=None):
|
||||
t = time()
|
||||
name = self.i if name is None else name
|
||||
self.t[name] = [t - self.last_t]
|
||||
self.i += 1
|
||||
self.last_t = time()
|
||||
|
||||
def next(self):
|
||||
self.i = 0
|
||||
self.__call__ = self.call_next
|
||||
|
||||
def call_next(self, name=None):
|
||||
t = time()
|
||||
name = self.i if name is None else name
|
||||
self.t[name].append(t - self.last_t)
|
||||
self.i += 1
|
||||
self.last_t = time()
|
||||
|
||||
def __repr__(self):
|
||||
out = []
|
||||
for i, t in self.t.items():
|
||||
t = sum(t) / len(t)
|
||||
out.append(f"{i}={t:.4f}")
|
||||
return "\n".join(out)
|
||||
|
46
src/vis.py
Normal file
46
src/vis.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
#! /usr/bin/env python3
|
||||
# vim:fenc=utf-8
|
||||
import jax
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def to_rgb(x):
|
||||
x = x / x.max() * 255
|
||||
x = np.broadcast_to(x, (3, x.shape[1], x.shape[2]))
|
||||
return x.astype("uint8")
|
||||
|
||||
|
||||
def to_img(x, scale_up=10):
|
||||
x = to_rgb(x)
|
||||
|
||||
img = np.empty(
|
||||
(x.shape[1] * scale_up, x.shape[2] * scale_up, x.shape[0]), dtype=np.uint8
|
||||
)
|
||||
|
||||
for color in range(x.shape[0]):
|
||||
for row in range(x.shape[1]):
|
||||
for col in range(x.shape[2]):
|
||||
pixel = x[color, row, col]
|
||||
new_row = row * scale_up
|
||||
new_col = col * scale_up
|
||||
img[
|
||||
new_row : new_row + scale_up,
|
||||
new_col : new_col + scale_up,
|
||||
color,
|
||||
] = pixel
|
||||
if x.shape[0] == 1:
|
||||
return Image.fromarray(img.squeeze())
|
||||
else:
|
||||
return Image.fromarray(img, "RGB")
|
||||
|
||||
|
||||
def color_mask(x, mask, color=np.array([255, 0, 0])):
|
||||
if len(mask.shape) == 2:
|
||||
mask = np.broadcast_to(mask, (3, mask.shape[0], mask.shape[1]))
|
||||
|
||||
x = to_rgb(x)
|
||||
x = x * (1 - mask)
|
||||
color = jax.vmap(lambda x, y: x * y, (0, 0))(color, mask)
|
||||
x = x + color
|
||||
|
||||
return x
|
Loading…
Reference in New Issue
Block a user