Load the features and the generations.

```{python}
from pathlib import Path

features2image_diffusion_dir = Path("../features2image_diffusion")
run_dir = features2image_diffusion_dir / "xyz/run/4409b6282a7d05f0b08880228d6d6564011fa40be412073ff05aff8bf2dc49fa"
batch_size = 64
shuffle = True
```

```{python}
from collections import OrderedDict

import jax.numpy as jnp
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class EditedFeaturesAndGenerations(Dataset):
    def __init__(self, gen_path, mnist_path="./res/MNIST", progress_bar=True, transform=None):
        self.dir = gen_path
        self.subdirs = OrderedDict()

        mnist_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        self.mnist = torchvision.datasets.MNIST(
            str(mnist_path),
            train=True,
            download=True,
            transform=mnist_transform,
        )

        # Peek at one generations file to learn how many generations each file holds.
        a_generations_file = next(iter(filter(
            lambda p: p.name.split("-")[-1] == "generations.npy",
            Path(gen_path).rglob("*")
        )))
        self.generations_per_file = self.load_numpy(a_generations_file).shape[0]

        self.len = 0
        subdir_iter = filter(lambda d: d.is_dir(), Path(gen_path).iterdir())
        if progress_bar:
            subdir_iter = tqdm(
                list(subdir_iter),
                desc="loading features and generations",
                postfix={"len": self.len},
            )
        for d in subdir_iter:
            # convert /dir/parent/id-type.suffix to ./parent/id
            files = sorted(map(
                lambda f: (f.parent / f.stem.split("-")[0]).relative_to(self.dir),
                # only iterate over generated images
                filter(lambda f: f.suffix == ".png", d.rglob("*"))
            ))
            self.subdirs[int(d.name)] = files
            self.len += len(files) * self.generations_per_file
            if progress_bar:
                subdir_iter.set_postfix({"len": self.len})  # pyright: ignore[reportAttributeAccessIssue]

        self.transform = torchvision.transforms.Compose([
            # torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ]) if transform is None else transform

    @staticmethod
    def load_numpy(path):
        return np.load(path, allow_pickle=False)

    def __getitem__(self, idx):
        # Find which directory to sample from. Advance while the file index
        # points past this subdir's files (>=, since valid indices stop at
        # len(files) - 1).
        subdirs = iter(self.subdirs.items())
        d, files = next(subdirs)
        while idx // self.generations_per_file >= len(files):
            idx -= len(files) * self.generations_per_file
            d, files = next(subdirs)

        original, label = self.mnist[d]
        file = files[idx // self.generations_per_file]
        features = self.load_numpy(self.dir / file.parent / (file.stem + "-features.npy"))
        generation = self.load_numpy(self.dir / file.parent / (file.stem + "-generations.npy"))
        # We only take a single generation from the file, which contains multiple generations.
        generation = generation[idx % self.generations_per_file]
        # Temporarily convert the array to a torch tensor to apply the transformations.
        generation = self.transform(torch.Tensor(generation)).numpy()
        return str(file), generation, original, features, label

    def iter_unedited(self):
        for d in self.subdirs:
            original, label = self.mnist[d]
            d = str(d)
            features = self.load_numpy(self.dir / d / "unedited-features.npy")
            for generation in self.load_numpy(self.dir / d / "unedited-generations.npy"):
                generation = self.transform(torch.Tensor(generation)).numpy()
                yield d + "/unedited", generation, original, features, label

    def __len__(self):
        return self.len


dataset = EditedFeaturesAndGenerations(run_dir)
data_loader = DataLoader(
    dataset,
    collate_fn=lambda batch: tuple(map(np.array, zip(*batch))),
    batch_size=batch_size,
    shuffle=shuffle,
)
path, generation, original, features, label = next(iter(data_loader))
print(path.shape, generation.shape, original.shape, features.shape, label.shape)
```
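The flat indexing in `__getitem__` is easy to get subtly wrong, so a quick sanity check is cheap insurance. A minimal sketch that only uses the objects defined above (the variable names `n_files`, `path_a`, etc. are ours, not part of the pipeline):

```{python}
# Each file contributes `generations_per_file` consecutive indices, so the
# dataset length should equal the total file count times that factor.
n_files = sum(len(files) for files in dataset.subdirs.values())
assert len(dataset) == n_files * dataset.generations_per_file

# Indices that share the same `idx // generations_per_file` resolve to the same
# file; only the generation sampled from that file differs. So the paths below
# should match, while the generations should differ (when generations_per_file > 1).
path_a, gen_a, *_ = dataset[0]
path_b, gen_b, *_ = dataset[dataset.generations_per_file - 1]
print(path_a == path_b, np.array_equal(gen_a, gen_b))
```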
```{python}
_, gen_img, orig_img, _, _ = next(dataset.iter_unedited())
print(orig_img.max(), gen_img.max())
print(orig_img.min(), gen_img.min())
```

```{python}
import jo3mnist
import matplotlib.pyplot as plt

_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(jo3mnist.to_img(orig_img))
axes[1].imshow(jo3mnist.to_img(gen_img))
for ax in axes:
    ax.axis("off")
plt.show()
```

Load the CNN and the SAE, and sow the CNN so it exposes intermediate activations.

```{python}
import tomllib

# The config.toml in the features2image_diffusion run_dir records which SAE run's results it used.
sae_run_dir = features2image_diffusion_dir / tomllib.loads(
    (run_dir / "config.toml").read_text()
)["eval"][0]["feature_dir"]

# The sae_run_dir contains configuration information. We need it to know which
# CNN and SAE to load, and where to intercept the CNN's activations.
sae_config = tomllib.loads((sae_run_dir / "config.toml").read_text())
cnn_path = sae_config["cnn_storage"]
sae_path = sae_run_dir / "sae.eqx"
sow_layer = sae_config["sae"][0]["layer"]
print(cnn_path, sae_path, sow_layer, sep="\n")
```

```{python}
import jax
import equinox as eqx
import jo3util.eqx as jo3eqx

from src.sae import SAE
from src.cnn import CNN

sae = jo3eqx.load(sae_path, SAE)

# equinox wants to see an example cnn before loading the weights into it
cnn = CNN(jax.random.PRNGKey(0))
# load the weights into the example cnn
cnn = eqx.tree_deserialise_leaves(cnn_path, cnn)
# sow the cnn, so we can intercept intermediate activations on every forward pass
cnn = jo3eqx.sow(lambda m: m.layers[sow_layer], cnn)
```

Now reconstruct features from the generated images and compare the reconstructed features with the originals.

```{python}
path, gens, orig, orig_features, label = tuple(map(np.array, zip(*dataset.iter_unedited())))
print(path.shape, gens.shape, orig_features.shape)

# Pass the images generated by the diffusion model through the CNN and fetch
# the intermediate activations.
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
print(activ.shape, pred.shape)

# Reconstruct the features by encoding the intermediate activations with the SAE.
recon_features = jax.vmap(sae.encode)(activ)
print(recon_features.shape)

# Compare the original features with the reconstruction: squared error per
# image, summed over the feature dimension.
mse = jnp.sum((orig_features - recon_features) ** 2, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```
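Before switching to edited features, it is worth keeping this unedited error around as a baseline. A minimal sketch; the name `mse_unedited` is our addition, not part of the run artifacts:

```{python}
# Keep the unedited reconstruction error as a baseline, so the edited case
# below can be compared against it.
mse_unedited = mse
print(f"unedited baseline: mean={mse_unedited.mean():.4f}, std={mse_unedited.std():.4f}")
```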
If we use generated images from edited features instead, the reconstruction error is larger.

```{python}
path, gens, orig, orig_features, label = next(iter(data_loader))
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
recon_features = jax.vmap(sae.encode)(activ)
mse = jnp.sum((orig_features - recon_features) ** 2, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```
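To make "larger" concrete rather than eyeballing the raw arrays, one can compare the two error distributions directly. A minimal sketch, assuming `mse_unedited` from the baseline cell above is still in scope:

```{python}
import matplotlib.pyplot as plt

# Put a number on the gap between unedited and edited reconstructions.
print(f"unedited: mean={mse_unedited.mean():.4f}")
print(f"edited:   mean={mse.mean():.4f}")

# Overlay the two distributions of per-image reconstruction error.
_, ax = plt.subplots(figsize=(6, 4))
ax.hist(np.asarray(mse_unedited), bins=30, alpha=0.6, label="unedited")
ax.hist(np.asarray(mse), bins=30, alpha=0.6, label="edited")
ax.set_xlabel("feature reconstruction error (summed squared error)")
ax.set_ylabel("count")
ax.legend()
plt.show()
```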