sparse_autoencoder/doc/reconstruction-error.qmd
2024-08-17 15:49:39 +02:00

246 lines
7.1 KiB
Plaintext

Load the features and the generations.
```{python}
from pathlib import Path
# Root of the features2image_diffusion project; its runs contain the
# "*-features.npy" / "*-generations.npy" files loaded below.
features2image_diffusion_dir = Path("../features2image_diffusion")
# The particular diffusion run whose outputs are analysed in this document.
run_dir = (
    features2image_diffusion_dir
    / "xyz/run/4409b6282a7d05f0b08880228d6d6564011fa40be412073ff05aff8bf2dc49fa"
)
# DataLoader settings for sampling edited generations.
batch_size, shuffle = 64, True
```
```{python}
from collections import OrderedDict
from itertools import islice
import jax.numpy as jnp
import numpy as np
import torchvision
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
class EditedFeaturesAndGenerations(Dataset):
    """Map-style dataset over the features and diffusion generations of one run.

    ``gen_path`` must contain one sub-directory per MNIST training index (the
    directory name is parsed with ``int`` and used to look up the MNIST
    sample). Every ``.png`` anywhere below such a sub-directory marks an id:
    the part of its stem before the first ``-``. For each id, the files
    ``<id>-features.npy`` and ``<id>-generations.npy`` in the same directory
    are loaded. A generations file stacks ``generations_per_file``
    generations along its first axis, and every single generation is one
    dataset item.

    Items are ``(path, generation, original, features, label)`` tuples, where
    ``path`` is the id's path relative to ``gen_path`` and ``original`` /
    ``label`` come from the MNIST training set.
    """

    def __init__(self, gen_path, mnist_path="./res/MNIST", progress_bar=True, transform=None):
        """Index every features/generations pair below ``gen_path``.

        Args:
            gen_path: Root directory of a diffusion run (``str`` or ``Path``).
            mnist_path: Where the MNIST training set is stored; it is
                downloaded there if missing.
            progress_bar: Show a tqdm progress bar while indexing.
            transform: Applied to each generation after converting it to a
                ``torch.Tensor``. Defaults to ``Normalize((0.5,), (0.5,))``.

        Raises:
            FileNotFoundError: If no ``*-generations.npy`` file exists below
                ``gen_path``.
        """
        # Normalize to Path so that ``self.dir / "<name>"`` also works when a
        # plain string is passed (iter_unedited joins str components).
        self.dir = Path(gen_path)
        self.subdirs = OrderedDict()
        mnist_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        self.mnist = torchvision.datasets.MNIST(
            str(mnist_path),
            train=True,
            download=True,
            transform=mnist_transform,
        )
        # Read the number of generations per file from an arbitrary generations
        # file; every generations file is assumed to hold the same number.
        a_generations_file = next(
            (p for p in self.dir.rglob("*") if p.name.split("-")[-1] == "generations.npy"),
            None,
        )
        if a_generations_file is None:
            raise FileNotFoundError(f"no '*-generations.npy' file found below {self.dir}")
        self.generations_per_file = self.load_numpy(a_generations_file).shape[0]
        self.len = 0
        # iterdir() order is arbitrary; sort numerically so that the
        # index -> item mapping is reproducible across runs and machines.
        subdirs = sorted(
            (d for d in self.dir.iterdir() if d.is_dir()),
            key=lambda d: int(d.name),
        )
        progress = (
            tqdm(subdirs, desc="loading features and generations", postfix={"len": self.len})
            if progress_bar
            else None
        )
        for d in (subdirs if progress is None else progress):
            # Convert <dir>/<subdir>/.../<id>-<type>.png to <subdir>/.../<id>;
            # only the generated images are used to discover ids.
            files = sorted(
                (f.parent / f.stem.split("-")[0]).relative_to(self.dir)
                for f in d.rglob("*")
                if f.suffix == ".png"
            )
            self.subdirs[int(d.name)] = files
            self.len += len(files) * self.generations_per_file
            if progress is not None:
                progress.set_postfix({"len": self.len})
        # Generations are numpy arrays converted with torch.Tensor(...) before
        # the transform is applied, so no ToTensor step is used here.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ]) if transform is None else transform

    @staticmethod
    def load_numpy(path):
        """Load a ``.npy`` file; pickled objects are refused for safety."""
        return np.load(path, allow_pickle=False)

    def __getitem__(self, idx):
        """Return item ``idx`` as ``(path, generation, original, features, label)``.

        Negative indices count from the end, as for a sequence.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        if idx < 0:
            idx += len(self)
        # Walk the sub-directories, subtracting the number of items each one
        # holds, until ``idx`` falls inside one of them.
        for d, files in self.subdirs.items():
            n_items = len(files) * self.generations_per_file
            if 0 <= idx < n_items:
                break
            idx -= n_items
        else:
            raise IndexError(f"index out of range for dataset of length {len(self)}")
        original, label = self.mnist[d]
        file = files[idx // self.generations_per_file]
        features = self.load_numpy(self.dir / file.parent / (file.stem + "-features.npy"))
        generations = self.load_numpy(self.dir / file.parent / (file.stem + "-generations.npy"))
        # A generations file holds several generations; this item is one of them.
        generation = generations[idx % self.generations_per_file]
        # Temporarily convert the array to a torch tensor to apply the transform.
        generation = self.transform(torch.Tensor(generation)).numpy()
        return str(file), generation, original, features, label

    def iter_unedited(self):
        """Yield one ``(path, generation, original, features, label)`` tuple per
        generation made from the unedited features of each sub-directory.

        ``path`` is ``"<subdir>/unedited"``, and ``features`` is the whole
        ``unedited-features.npy`` array, shared by every generation of that
        sub-directory.
        """
        for d in self.subdirs:
            original, label = self.mnist[d]
            name = str(d)
            features = self.load_numpy(self.dir / name / "unedited-features.npy")
            for generation in self.load_numpy(self.dir / name / "unedited-generations.npy"):
                generation = self.transform(torch.Tensor(generation)).numpy()
                yield name + "/unedited", generation, original, features, label

    def __len__(self):
        """Total number of generations (items) across all sub-directories."""
        return self.len
dataset = EditedFeaturesAndGenerations(run_dir)


def stack_fields(batch):
    """Turn a list of item tuples into a tuple of numpy arrays, one per field."""
    return tuple(np.array(field) for field in zip(*batch))


data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=stack_fields,
)
# Pull a single batch to sanity-check the shape of every field.
path, generation, original, features, label = next(iter(data_loader))
print(path.shape, generation.shape, original.shape, features.shape, label.shape)
```
```{python}
# Take the first generation made from unedited features together with its
# MNIST original, and compare the value ranges of the two images.
unedited_sample = next(dataset.iter_unedited())
gen_img = unedited_sample[1]
orig_img = unedited_sample[2]
print(orig_img.max(), gen_img.max())
print(orig_img.min(), gen_img.min())
```
```{python}
import jo3mnist
import matplotlib.pyplot as plt
# Show the MNIST original (left) next to the generation made from its
# unedited features (right).
_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(jo3mnist.to_img(orig_img))
axes[1].imshow(jo3mnist.to_img(gen_img))
# Use an explicit loop: a bare generator expression is never consumed, so
# ax.axis("off") would never actually run.
for ax in axes:
    ax.axis("off")
plt.show()
```
Load the CNN and the SAE, then sow the CNN so that its forward pass also returns the intermediate activations that the SAE encodes.
```{python}
import tomllib
# The config.toml in the features2image_diffusion run_dir records which SAE
# run results it used (only the first [[eval]] entry is consulted here).
# TOML files are UTF-8 by definition, so decode explicitly instead of relying
# on the platform's locale encoding.
run_config = tomllib.loads((run_dir / "config.toml").read_text(encoding="utf-8"))
sae_run_dir = features2image_diffusion_dir / run_config["eval"][0]["feature_dir"]
# The sae_run_dir contains configuration information: which CNN and SAE to
# load, and at which layer to intercept the CNN's activations.
sae_config = tomllib.loads((sae_run_dir / "config.toml").read_text(encoding="utf-8"))
cnn_path = sae_config["cnn_storage"]
sae_path = sae_run_dir / "sae.eqx"
sow_layer = sae_config["sae"][0]["layer"]
print(cnn_path, sae_path, sow_layer, sep="\n")
```
```{python}
import jax
import equinox as eqx
import jo3util.eqx as jo3eqx
from src.sae import SAE
from src.cnn import CNN
# Restore the trained sparse autoencoder.
sae = jo3eqx.load(sae_path, SAE)
# Equinox deserialises weights into an existing model of the right structure,
# so build a template CNN first and fill its leaves from cnn_path.
cnn = eqx.tree_deserialise_leaves(cnn_path, CNN(jax.random.PRNGKey(0)))
# Sow the CNN at layer `sow_layer`, so that calling it also returns that
# layer's intermediate activations.
cnn = jo3eqx.sow(lambda m: m.layers[sow_layer], cnn)
```
Now reconstruct features from the generated images and compare the reconstructed features with the original features.
```{python}
# Gather every unedited sample and stack each field into one numpy array.
path, gens, orig, orig_features, label = (
    np.array(field) for field in zip(*dataset.iter_unedited())
)
print(path.shape, gens.shape, orig_features.shape)
# The sown CNN returns the intermediate activations alongside its predictions.
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
print(activ.shape, pred.shape)
# Encoding those activations with the SAE reconstructs the features.
recon_features = jax.vmap(sae.encode)(activ)
print(recon_features.shape)
# Per-sample reconstruction error. Despite the name, this is the SUM of the
# squared differences over the last axis, not their mean.
squared_diff = (orig_features - recon_features) ** 2
mse = jnp.sum(squared_diff, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```
If we instead use images generated from edited features, the reconstruction difference is larger.
```{python}
# Same pipeline on one (shuffled) batch of generations made from edited features.
path, gens, orig, orig_features, label = next(iter(data_loader))
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
recon_features = jax.vmap(sae.encode)(activ)
# Per-sample sum (not mean) of squared differences, matching the unedited case.
squared_diff = (orig_features - recon_features) ** 2
mse = jnp.sum(squared_diff, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```