Load the features and the generations.

```{python}
from pathlib import Path

features2image_diffusion_dir = Path("../features2image_diffusion")
run_dir = features2image_diffusion_dir / "xyz/run/4409b6282a7d05f0b08880228d6d6564011fa40be412073ff05aff8bf2dc49fa"
batch_size = 64
shuffle = True
```

```{python}
from collections import OrderedDict

import jax.numpy as jnp
import numpy as np
import torchvision
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class EditedFeaturesAndGenerations(Dataset):
    """Pairs every saved feature file with the images generated from it.

    Each item is (relative file stem, generated image, original MNIST image,
    features, label). `iter_unedited` yields the generations made from the
    unedited features.
    """

    def __init__(self, gen_path, mnist_path="./res/MNIST", progress_bar=True, transform=None):
        self.dir = gen_path
        self.subdirs = OrderedDict()

        mnist_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        self.mnist = torchvision.datasets.MNIST(
            str(mnist_path),
            train=True,
            download=True,
            transform=mnist_transform,
        )

        # Every generations file stores the same number of generations,
        # so read the count from the first one we find.
        a_generations_file = next(iter(filter(
            lambda p: p.name.split("-")[-1] == "generations.npy",
            Path(gen_path).rglob("*")
        )))
        self.generations_per_file = self.load_numpy(a_generations_file).shape[0]

        self.len = 0
        subdir_iter = filter(lambda d: d.is_dir(), Path(gen_path).iterdir())
        if progress_bar:
            subdir_iter = tqdm(
                list(subdir_iter),
                desc="loading features and generations",
                postfix={"len": self.len},
            )

        for d in subdir_iter:
            # convert /dir/parent/id-type.suffix to ./parent/id
            files = sorted(map(
                lambda f: (f.parent / f.stem.split("-")[0]).relative_to(self.dir),
                # only iterate over generated images
                filter(lambda f: f.suffix == ".png", d.rglob("*"))
            ))
            self.subdirs[int(d.name)] = files
            self.len += len(files) * self.generations_per_file
            if progress_bar:
                subdir_iter.set_postfix({"len": self.len})  # pyright: ignore[reportAttributeAccessIssue]

        self.transform = torchvision.transforms.Compose(
            [
                # torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5,), (0.5,)),
            ]
        ) if transform is None else transform

    @staticmethod
    def load_numpy(path):
        return np.load(path, allow_pickle=False)

    def __getitem__(self, idx):
        # find which directory to sample from
        subdirs = iter(self.subdirs.items())
        d, files = next(subdirs)
        while idx // self.generations_per_file >= len(files):
            idx -= len(files) * self.generations_per_file
            d, files = next(subdirs)

        original, label = self.mnist[d]
        file = files[idx // self.generations_per_file]
        features = self.load_numpy(self.dir / file.parent / (file.stem + "-features.npy"))
        generation = self.load_numpy(self.dir / file.parent / (file.stem + "-generations.npy"))
        # We only take a single generation from the file that contains multiple generations.
        generation = generation[idx % self.generations_per_file]
        # Temporarily convert the array to a torch tensor to apply the transformations.
        generation = self.transform(torch.Tensor(generation)).numpy()

        return str(file), generation, original, features, label

    def iter_unedited(self):
        for d in self.subdirs:
            original, label = self.mnist[d]
            d = str(d)
            features = self.load_numpy(self.dir / d / "unedited-features.npy")
            for generation in self.load_numpy(self.dir / d / "unedited-generations.npy"):
                generation = self.transform(torch.Tensor(generation)).numpy()
                yield d + "/unedited", generation, original, features, label

    def __len__(self):
        return self.len


dataset = EditedFeaturesAndGenerations(run_dir)
data_loader = DataLoader(
    dataset,
    collate_fn=lambda batch: tuple(map(np.array, zip(*batch))),
    batch_size=batch_size,
    shuffle=shuffle,
)
path, generation, original, features, label = next(iter(data_loader))
print(path.shape, generation.shape, original.shape, features.shape, label.shape)
```
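
The flat index arithmetic in `__getitem__` is easier to see with concrete numbers. A minimal sketch of how a dataset index decomposes into a subdirectory, a file, and a generation; the counts and the `locate` helper are made up for illustration, they are not part of the dataset:

```{python}
# Hypothetical layout: two edit subdirectories with 3 and 2 feature files,
# and 4 generations stored per file (made-up numbers).
subdir_file_counts = {0: 3, 1: 2}
generations_per_file = 4

def locate(idx):
    # Mirrors EditedFeaturesAndGenerations.__getitem__: walk the subdirs,
    # subtracting each one's share of indices until idx falls inside one.
    for d, n_files in subdir_file_counts.items():
        if idx // generations_per_file >= n_files:
            idx -= n_files * generations_per_file
            continue
        return d, idx // generations_per_file, idx % generations_per_file
    raise IndexError(idx)

# Index 13 lands in subdir 1 (subdir 0 covers 3 * 4 = 12 indices),
# file 0, generation 1.
print(locate(13))
```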

```{python}
# Sanity check that the original and the generated image live in a comparable value range.
_, gen_img, orig_img, _, _ = next(dataset.iter_unedited())
print(orig_img.max(), gen_img.max())
print(orig_img.min(), gen_img.min())
```

```{python}
import jo3mnist
import matplotlib.pyplot as plt

_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(jo3mnist.to_img(orig_img))
axes[1].imshow(jo3mnist.to_img(gen_img))
for ax in axes:
    ax.axis('off')
plt.show()
```

Load the CNN and the SAE and sow the CNN.

```{python}
import tomllib

# The config.toml in the features2image_diffusion run_dir records which SAE run's results it used.
sae_run_dir = features2image_diffusion_dir / tomllib.loads((run_dir / "config.toml").read_text())["eval"][0]["feature_dir"]
# The sae_run_dir contains its own configuration information. We need it to know
# which CNN and SAE to load, and at which layer to intercept the CNN activations.
sae_config = tomllib.loads((sae_run_dir / "config.toml").read_text())
cnn_path = sae_config["cnn_storage"]
sae_path = sae_run_dir / "sae.eqx"
sow_layer = sae_config["sae"][0]["layer"]

print(cnn_path, sae_path, sow_layer, sep="\n")
```
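
For reference, the cell above only touches a handful of keys in the two config files. A hypothetical sketch of the parsed structures, with placeholder values rather than the actual run settings:

```{python}
# Hypothetical shapes of the parsed configs, inferred from the keys accessed above.
example_run_config = {
    "eval": [{"feature_dir": "<relative/path/to/sae/run>"}],  # placeholder path
}
example_sae_config = {
    "cnn_storage": "<path/to/cnn/weights>",  # placeholder path
    "sae": [{"layer": 4}],                   # placeholder layer index
}
```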

```{python}
import jax
import equinox as eqx
import jo3util.eqx as jo3eqx

from src.sae import SAE
from src.cnn import CNN

sae = jo3eqx.load(sae_path, SAE)

# equinox wants to see an example cnn before loading the weights into it
cnn = CNN(jax.random.PRNGKey(0))
# load the weights into the example cnn
cnn = eqx.tree_deserialise_leaves(cnn_path, cnn)
# sow the cnn, so we can intercept intermediate activations every forward pass
cnn = jo3eqx.sow(lambda m: m.layers[sow_layer], cnn)
```
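
As a quick check before the full batch below, the sown CNN can be applied to the single unedited generation from earlier; it should return both the intercepted activations and the class logits. This assumes `gen_img` from the earlier cell is a single generated image in the same format as the entries of `gens` used below:

```{python}
# Sanity check: the sown CNN returns (intercepted activations, logits) per forward pass.
single_activ, single_pred = cnn(jnp.asarray(gen_img))
print(single_activ.shape, single_pred.shape)
```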

Now reconstruct features from the generated images and compare the reconstructed features with the originals.

```{python}
path, gens, orig, orig_features, label = tuple(map(np.array, zip(*dataset.iter_unedited())))
print(path.shape, gens.shape, orig_features.shape)
# Pass the images generated by the diffusion model through the cnn and fetch the intermediate activations.
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
print(activ.shape, pred.shape)
# Reconstruct the features by encoding the intermediate activations with the SAE.
recon_features = jax.vmap(sae.encode)(activ)
print(recon_features.shape)
# Look at the difference between the original and the reconstructed features
# (the squared error summed over the feature dimension, per sample).
mse = jnp.sum((orig_features - recon_features)**2, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```

If we use generated images from edited features instead, the reconstruction difference is larger.

```{python}
path, gens, orig, orig_features, label = next(iter(data_loader))
activ, pred = jax.vmap(cnn)(gens)  # pyright: ignore[reportArgumentType]
recon_features = jax.vmap(sae.encode)(activ)
mse = jnp.sum((orig_features - recon_features)**2, axis=-1)  # pyright: ignore[reportAssignmentType]
print(mse)
```
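
To put a number on "larger", the two error vectors can be compared directly. A minimal sketch, assuming the unedited errors from the previous cell were kept under a hypothetical name `mse_unedited` before this cell overwrote `mse`:

```{python}
# Hypothetical summary: mse_unedited is assumed to hold the per-sample errors
# from the unedited comparison above, mse those from the edited one.
print("mean error, unedited generations:", float(jnp.mean(mse_unedited)))
print("mean error, edited generations:  ", float(jnp.mean(mse)))
```

If the feature edits really steer the generations away from the original features, the second number should be noticeably higher than the first.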