Prober example
This example shows how to use Prober to couple already-trained probes to a model.
[1]:
import torch # PyTorch
from pytorch_probing import Prober, ParallelModuleDict # Prober and dictionary of modules
We start by creating an example model, a simple MLP:
[2]:
class ExampleModel(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
[3]:
input_size = 2
hidden_size = 3
output_size = 1
model = ExampleModel(input_size, hidden_size, output_size)
model.eval()
[3]:
ExampleModel(
(linear1): Linear(in_features=2, out_features=3, bias=True)
(relu): ReLU()
(linear2): Linear(in_features=3, out_features=1, bias=True)
)
And a probe. Any torch module can be used as a probe; in this example we are going to use a simple Linear layer:
[4]:
probe_size = 2
probe = torch.nn.Linear(hidden_size, probe_size)
probe.eval()
[4]:
Linear(in_features=3, out_features=2, bias=True)
Then we create the Prober, passing it the model and a dictionary mapping module paths to the probes that should be coupled to their outputs. When a None value is passed, an Identity module is created, which simply passes its input through unchanged.
[5]:
probes = {"linear1":probe, "relu":None}
probed_model = Prober(model, probes)
probed_model.eval()
[5]:
Prober(
(_module): ExampleModel(
(linear1): InterceptorLayer(
(_module): Linear(in_features=2, out_features=3, bias=True)
)
(relu): InterceptorLayer(
(_module): ReLU()
)
(linear2): Linear(in_features=3, out_features=1, bias=True)
)
(_probes): ModuleDict(
(linear1): Linear(in_features=3, out_features=2, bias=True)
(relu): Identity()
)
)
Observe that the Prober modifies the original model in-place:
[6]:
model
[6]:
ExampleModel(
(linear1): InterceptorLayer(
(_module): Linear(in_features=2, out_features=3, bias=True)
)
(relu): InterceptorLayer(
(_module): ReLU()
)
(linear2): Linear(in_features=3, out_features=1, bias=True)
)
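Conceptually, each InterceptorLayer records the output of the module it wraps during the forward pass. A rough sketch of the same idea using a plain PyTorch forward hook (this illustrates the mechanism, not the library's actual implementation):

```python
import torch

# The same small MLP as above, built as a Sequential for brevity.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 3),
    torch.nn.ReLU(),
    torch.nn.Linear(3, 1),
)
model.eval()

captured = {}

def make_hook(name):
    # A forward hook receives (module, inputs, output) and can record the output.
    def hook(module, inputs, output):
        captured[name] = output
    return hook

# Capture the first Linear layer's output, as Prober does for "linear1".
handle = model[0].register_forward_hook(make_hook("linear1"))

with torch.no_grad():
    out = model(torch.randn(10, 2))

handle.remove()
```

After the forward pass, `captured["linear1"]` holds the hidden activations of shape `[10, 3]`, which a probe could then consume.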
We pass a sample input to the probed model:
[7]:
inputs = torch.randn([10, 2])
with torch.no_grad():
    outputs = probed_model(inputs)
The output is a tuple with the model output as the first element and the probe outputs as the second:
[8]:
outputs[0]
[8]:
tensor([[-0.3165],
[-0.3262],
[-0.3362],
[-0.4985],
[-0.2987],
[-0.3520],
[-0.3182],
[-0.3269],
[-0.3418],
[-0.3352]])
[9]:
outputs[1]
[9]:
{'linear1': tensor([[ 0.1408, 0.8607],
[ 0.1683, 0.9637],
[-0.7003, 0.6100],
[ 0.6062, 1.2991],
[-0.1432, 0.7140],
[-1.0194, 0.2385],
[-0.4221, 0.6943],
[-0.6234, 0.5710],
[-0.9021, 0.4412],
[ 0.1946, 1.0828]]),
'relu': tensor([[0.0000, 0.0346, 0.0000],
[0.0000, 0.0534, 0.0000],
[1.3431, 0.0000, 0.0000],
[0.0000, 0.3868, 0.0000],
[0.0000, 0.0000, 0.0000],
[1.3982, 0.0000, 0.3337],
[0.7008, 0.0000, 0.0000],
[1.0109, 0.0000, 0.0000],
[1.5421, 0.0000, 0.0000],
[0.0000, 0.0708, 0.0000]])}
Multiple probes in the same place
We can also attach more than one probe to the same place. To show this, we create a second probe and reduce the probed model back to the original model:
[10]:
probe2_size = 1
probe2 = torch.nn.Linear(hidden_size, probe2_size)
probe2.eval()
[10]:
Linear(in_features=3, out_features=1, bias=True)
[11]:
model = probed_model.reduce()
We can then create a ParallelModuleDict containing the two probes. When called with some input, the ParallelModuleDict passes the input to all of its modules and returns a dictionary with each module's output.
[12]:
linear1_probes = ParallelModuleDict({"probe1":probe, "probe2":probe2})
probes = {"linear1":linear1_probes}
probed_model = Prober(model, probes)
[13]:
inputs = torch.randn([10, 2])
with torch.no_grad():
    outputs2 = probed_model(inputs)
[14]:
outputs2[1]
[14]:
{'linear1': {'probe1': tensor([[-0.2682, 0.8801],
[ 0.0351, 0.7865],
[-1.2988, 0.1559],
[-0.3947, 0.7546],
[-0.3776, 0.7478],
[-0.4901, 0.4542],
[ 0.1075, 0.9543],
[-0.5112, 0.7165],
[-0.5990, 0.6198],
[-0.1853, 0.7252]]),
'probe2': tensor([[-0.4589],
[-0.3497],
[-0.3711],
[-0.4320],
[-0.4252],
[-0.3211],
[-0.4073],
[-0.4413],
[-0.4184],
[-0.3721]])}}
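Since probes sharing a location return a nested dictionary, an individual probe output is accessed by two keys, e.g. `outputs2[1]["linear1"]["probe1"]`. The fan-out behavior itself is easy to picture with a minimal plain-Python sketch (an illustration of the idea, not the library implementation):

```python
# One input fanned out to several callables, with the results
# gathered into a dict keyed by probe name.
def parallel_apply(modules, x):
    return {name: fn(x) for name, fn in modules.items()}

probes = {"probe1": lambda x: x * 2, "probe2": lambda x: x + 1}
result = parallel_apply(probes, 3)
print(result)  # {'probe1': 6, 'probe2': 4}
```

ParallelModuleDict does the same for torch modules, which is why each probed location maps to a dictionary of per-probe tensors.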