Collect Example

This example shows how to collect an intermediate output from a PyTorch model and build a dataset from it, which can then be used for analysis or for training probes and other techniques.

[1]:
import os # Path operations
import glob # List files
import shutil # Remove generated files

import torch # PyTorch
from torch.utils.data import Dataset, DataLoader # Creating the example dataset and dataloader

from pytorch_probing import collect, CollectedDataset # Collect dataset and load it

For this example, we create an example model with two linear layers and a ReLU activation:

[2]:
class ExampleModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        return x
[3]:
input_size = 2
hidden_size = 3
output_size = 1

model = ExampleModel(input_size, hidden_size, output_size)
model.eval()
[3]:
ExampleModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (relu): ReLU()
  (linear2): Linear(in_features=3, out_features=1, bias=True)
)

And an example dataset whose inputs and targets are tensors filled with the sample index:

[4]:
class ExampleDataset(Dataset):
    def __init__(self, x_size, y_size, length) -> None:
        super().__init__()

        self._x_size = x_size
        self._y_size = y_size
        self._len = length

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, idx:int):
        return torch.empty(self._x_size).fill_(idx), torch.empty(self._y_size).fill_(idx)
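To make the index-filled samples concrete, here is a quick standalone sanity check (re-declaring the class so the snippet is self-contained): sample i is a pair of tensors filled with the value i.

```python
import torch
from torch.utils.data import Dataset

class ExampleDataset(Dataset):
    def __init__(self, x_size, y_size, length) -> None:
        super().__init__()
        self._x_size = x_size
        self._y_size = y_size
        self._len = length

    def __len__(self) -> int:
        return self._len

    def __getitem__(self, idx: int):
        # Both input and target are constant tensors holding the index.
        return torch.empty(self._x_size).fill_(idx), torch.empty(self._y_size).fill_(idx)

dataset = ExampleDataset(2, 1, 32)
x, y = dataset[5]
print(x)  # tensor([5., 5.])
print(y)  # tensor([5.])
```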

We create the dataset and a dataloader from it:

[5]:
dataset_size = 32
batch_size = 4

dataset = ExampleDataset(input_size, output_size, dataset_size)
dataloader = DataLoader(dataset, batch_size, shuffle=False)

And we can finally collect the dataset. We pass the model, the “linear1” path as the intermediate output to collect, the dataloader, a name for the dataset, and flags to save the model inputs, targets, and predictions. The function runs the model over the dataloader batches and stores the requested values alongside the intercepted outputs:

[6]:
paths = ["linear1"]

dataset_path = collect(model, paths, dataloader, dataset_name="CollectExample",
                       save_input=True, save_target=True, save_prediction=True)
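The interception that collect performs can be understood in terms of PyTorch forward hooks. The following is a sketch of the general mechanism (not necessarily pytorch_probing's actual implementation), using a stand-in Sequential model:

```python
import torch

torch.manual_seed(0)

# Stand-in model with the same shape as the example above.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 3),
    torch.nn.ReLU(),
    torch.nn.Linear(3, 1),
)

captured = {}

def save_output(module, inputs, output):
    # Forward hooks receive the module, its inputs, and its output.
    captured["linear1"] = output.detach()

# Register the hook on the submodule whose output we want to intercept.
handle = model[0].register_forward_hook(save_output)

with torch.no_grad():
    prediction = model(torch.zeros(2))

handle.remove()  # Remove the hook once collection is done.

print(captured["linear1"].shape)  # torch.Size([3])
```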

After collecting, the dataset is stored under dataset_path in chunks, one per batch (32 samples with a batch size of 4 gives 8 files):

[7]:
pattern = os.path.join(dataset_path, "*.pt")
glob.glob(pattern)
[7]:
['.\\CollectExample\\0.pt',
 '.\\CollectExample\\1.pt',
 '.\\CollectExample\\2.pt',
 '.\\CollectExample\\3.pt',
 '.\\CollectExample\\4.pt',
 '.\\CollectExample\\5.pt',
 '.\\CollectExample\\6.pt',
 '.\\CollectExample\\7.pt']

We can use the CollectedDataset to load the dataset:

[8]:
collected_dataset = CollectedDataset(dataset_path, get_input=True, get_prediction=True, get_target=True)
collected_dataset.name
[8]:
'CollectExample'

It contains all 32 saved samples:

[9]:
len(collected_dataset)
[9]:
32

We can get the first sample of the dataset:

[10]:
intercepted_output, target, prediction, saved_input = collected_dataset[0]

print("Intercepted Output")
print(intercepted_output, "\n")

print("Target")
print(target, "\n")

print("Prediction")
print(prediction, "\n")

print("Input")
print(saved_input, "\n")

Intercepted Output
{'linear1': tensor([0.5083, 0.6371, 0.4391])}

Target
tensor([0.])

Prediction
tensor([0.5722])

Input
tensor([0., 0.])

And compare them with the original values:

[11]:
x, y = dataset[0]

with torch.no_grad():
    pred = model(x)

    linear1_output = model.linear1(x)

print("Original linear1 output")
print(linear1_output, "\n")


print("Original target")
print(y, "\n")

print("Original prediction")
print(pred, "\n")

print("Original input")
print(x, "\n")


Original linear1 output
tensor([0.5083, 0.6371, 0.4391])

Original target
tensor([0.])

Original prediction
tensor([0.5722])

Original input
tensor([0., 0.])
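For larger models, comparing printed tensors by eye does not scale; torch.allclose gives a programmatic check up to floating-point tolerance. A small illustration using the values printed above:

```python
import torch

# Values as printed above: the collected "linear1" activation and the
# recomputed one should agree within floating-point tolerance.
collected = torch.tensor([0.5083, 0.6371, 0.4391])
recomputed = torch.tensor([0.5083, 0.6371, 0.4391])

assert torch.allclose(collected, recomputed)
print("match")
```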

Finally, we delete the generated dataset:

[12]:
shutil.rmtree(dataset_path)