diff --git a/requirements.txt b/requirements.txt index 44c9804..62d66ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ build debugpy ipython +jax jupytext matplotlib-backend-kitty numpy diff --git a/src/root.py b/src/root.py new file mode 100644 index 0000000..204a5c0 --- /dev/null +++ b/src/root.py @@ -0,0 +1,46 @@ +#! /usr/bin/env python3 + +import os +from pathlib import Path +from importlib.resources import files +import __main__ + + +def root_dir(): + if "FILE_NAME" in os.environ: + return Path(os.environ["FILE_NAME"]).parent() + + if __main__.__package__: + return files(__main__.__package__) + + # Use the folder of the main file as toml_dir. + if "__file__" in dir(__main__): + toml_dir = Path(__main__.__file__).parent + return toml_dir + + # Find the path of the ipython notebook. + try: + import ipynbname + + return ipynbname.path().parent + except IndexError: + pass + + return Path(os.path.abspath(".")) + + +def root_file(): + if "FILE_NAME" in os.environ: + return Path(os.environ["FILE_NAME"]) + + if "__file__" in dir(__main__): + return Path(__main__.__file__) + + try: + import ipynbname + + return ipynbname.path() + except IndexError: + pass + + raise NotImplementedError diff --git a/src/timing.py b/src/timing.py new file mode 100644 index 0000000..b7fac76 --- /dev/null +++ b/src/timing.py @@ -0,0 +1,35 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +from time import time + +class ReportTime: + def __init__(self): + self.i = 0 + self.last_t = time() + self.t = {} + + def __call__(self, name=None): + t = time() + name = self.i if name is None else name + self.t[name] = [t - self.last_t] + self.i += 1 + self.last_t = time() + + def next(self): + self.i = 0 + self.__call__ = self.call_next + + def call_next(self, name=None): + t = time() + name = self.i if name is None else name + self.t[name].append(t - self.last_t) + self.i += 1 + self.last_t = time() + + def __repr__(self): + out = [] + for i, t in self.t.items(): + t = sum(t) / len(t) + out.append(f"{i}={t:.4f}") + return "\n".join(out) + diff --git a/src/vis.py b/src/vis.py new file mode 100644 index 0000000..8328ecd --- /dev/null +++ b/src/vis.py @@ -0,0 +1,46 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 +import jax +import numpy as np +from PIL import Image + +def to_rgb(x): + x = x / x.max() * 255 + x = np.broadcast_to(x, (3, x.shape[1], x.shape[2])) + return x.astype("uint8") + + +def to_img(x, scale_up=10): + x = to_rgb(x) + + img = np.empty( + (x.shape[1] * scale_up, x.shape[2] * scale_up, x.shape[0]), dtype=np.uint8 + ) + + for color in range(x.shape[0]): + for row in range(x.shape[1]): + for col in range(x.shape[2]): + pixel = x[color, row, col] + new_row = row * scale_up + new_col = col * scale_up + img[ + new_row : new_row + scale_up, + new_col : new_col + scale_up, + color, + ] = pixel + if x.shape[0] == 1: + return Image.fromarray(img.squeeze()) + else: + return Image.fromarray(img, "RGB") + + +def color_mask(x, mask, color=np.array([255, 0, 0])): + if len(mask.shape) == 2: + mask = np.broadcast_to(mask, (3, mask.shape[0], mask.shape[1])) + + x = to_rgb(x) + x = x * (1 - mask) + color = jax.vmap(lambda x, y: x * y, (0, 0))(color, mask) + x = x + color + + return x