Diving inside Neural Networks

This tutorial provides a short practical overview of the Recorder class, which is designed to ease interactions with the internal states of PyTorch Neural Network objects (torch.nn.Module) during or after inference.

from scio.recorder import Recorder

Wrapping and visualizing your Neural Network

Let us first load an arbitrary Neural Network. We use a lightweight Tiniest architecture trained on CIFAR10 and hosted on ThalesGroup’s hub. We also prepare the input data used throughout this tutorial.

import torch

inputs = torch.rand(5, 3, 32, 32)  # Random inputs with 5 samples
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(inputs)

To wrap it into a Recorder Net, rnet, one simply needs to specify an input_size (including the batch dimension) or provide input_data. Upon instantiation, the Recorder analyzes and stores the control flow of the model, using the torchinfo library.

rnet = Recorder(net, input_data=inputs[[0]])
rnet  # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Tiniest (Tiniest)                        [1, 3, 32, 32]            [1, 10]                   --                             --
├─Conv2d (conv1): 1-1                    [1, 3, 32, 32]            [1, 48, 32, 32]           1,344                       1.38%
├─LayerNorm2d (ln1): 1-2                 [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    └─LayerNorm (ln): 2-1               [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
├─Block (l1): 1-3                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-2             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-3             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-4             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-5             [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-1          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-6                 [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-7                 [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l2): 1-4                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-8             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-9             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-10            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-11            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-2          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-12                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-13                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l3): 1-5                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-14            [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-15            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-16            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-17            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-3          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-18                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-19                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Conv2d (dsconv): 1-6                   [1, 48, 32, 32]           [1, 80, 16, 16]           3,920                       4.02%
├─LayerNorm2d (ln2): 1-7                 [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    └─LayerNorm (ln): 2-20              [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
├─Block (l4): 1-8                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-21            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-22            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-23            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-24            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-4          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-25                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-26                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l5): 1-9                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-27            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-28            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-29            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-30            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-5          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-31                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-32                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l6): 1-10                       [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-33            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-34            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-35            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-36            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-6          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-37                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-38                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11      [1, 80, 16, 16]           [1, 80, 1, 1]             --                             --
├─Linear (fc): 1-12                      [1, 80]                   [1, 10]                   810                         0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================
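
As a sanity check on the Param # and Param % columns, the figures for the final Linear layer can be recomputed by hand from the shapes in the summary. The sketch below is plain Python and independent of scio or torchinfo. The helper name linear_param_count is hypothetical, and the formula assumes the layer has a bias term, which is consistent with the reported count of 810.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters of a fully connected layer (weights plus optional bias)."""
    return in_features * out_features + (out_features if bias else 0)


total_params = 97_530  # "Total params" line of the summary

# Linear (fc): 1-12 maps [1, 80] to [1, 10]
fc_params = linear_param_count(80, 10)
fc_share = 100 * fc_params / total_params

print(fc_params)           # 810
print(f"{fc_share:.2f}%")  # 0.83%
```

Both values match the row reported for Linear (fc): 1-12.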

Tip

For summary customization, refer to torchinfo.summary options. For example, it is possible to bound the depth of the representation tree with depth=2.

Note

In the case of dynamic control flow, refer to the force_static_flow argument of Recorder.

In many ways, this wrapper is transparent to the user. For example, one can naturally process data with rnet(inputs).

Selecting layers of interest

The penultimate line of the above summary reports the layers that are currently set to be recorded (stored in rnet.recording). By default after instantiation, there are none.

print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: None

One can set this arbitrarily using rnet.record() with the depth-idx identifiers from the summary (e.g. 1-9). For example, the following specifies that the outputs of the first Block (1-3) and of the penultimate layer (1-11) should be recorded.

rnet.record((1, 3), (1, 11))
print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: 1-3, 1-11
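
The summary prints identifiers as depth-idx strings, while rnet.record() is called above with (depth, idx) tuples. When copying identifiers from the summary, a small conversion helper can be handy. The function parse_layer_id below is a hypothetical convenience written in plain Python, not part of scio.

```python
def parse_layer_id(identifier: str) -> tuple[int, int]:
    """Convert a summary identifier such as "1-11" into a (depth, idx) tuple."""
    depth, idx = identifier.strip().split("-")
    return int(depth), int(idx)


layers = [parse_layer_id(s) for s in ("1-3", "1-11")]
print(layers)  # [(1, 3), (1, 11)]
```

The resulting tuples could then be unpacked into the call shown above, as in rnet.record(*layers).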

Note

Though not shown in the summary, it is possible to select 0-1 to refer to the entire model.

Warning

torchinfo.summary can only detect torch.nn.Module calls. As such, if a Neural Network uses activation functions, it should call their Module implementation (instead of their functional counterpart) and declare them as an attribute, for them to be visible as a layer in the summary. It is not necessary to declare a different attribute for every activation call (e.g. one self.relu = nn.ReLU() can be used multiple times in forward()).

Capturing internal states

Once the recording layers are set, every forward pass will automatically store the corresponding internal states in the rnet.activations mapping. Its keys are the (depth, idx) 2-tuples.
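
In the output further down, rnet.activations is displayed as a mappingproxy, which is Python's read-only view over a dictionary. The sketch below illustrates that behaviour with the standard library only. The _store dictionary and the nested lists standing in for tensors are assumptions made for illustration, not scio internals.

```python
from types import MappingProxyType

# Stand-in for recorded activations: keys are (depth, idx), values mimic tensors.
_store = {(1, 3): [[0.1, 0.2]], (1, 11): [[0.3]]}
activations = MappingProxyType(_store)

print(activations[(1, 11)])  # [[0.3]]
print(sorted(activations))   # [(1, 3), (1, 11)]

try:
    activations[(1, 5)] = [[0.0]]  # the view itself rejects item assignment
except TypeError:
    print("read-only view")
```

Lookups by (depth, idx) key work as with a regular dict, while assignments through the view are rejected.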

out = rnet(inputs)  # Forward pass, records the activations
rnet.activations
mappingproxy({(1, 3): tensor([[[[ 2.8420e+00,  4.4273e+00, -2.0810e+00,  ...,  2.0311e+00,
           -1.0973e+01,  8.4303e-01],
          [ 6.3156e+00, -6.7697e-01,  1.6519e+00,  ...,  4.9773e+00,
           -8.6386e-01, -3.9291e+00],
          [-1.6540e+00,  5.6147e+00,  2.2276e+00,  ..., -2.1683e+00,
            5.2803e+00, -6.5720e+00],
          ...,
          [ 1.3550e+00,  2.1577e+00,  5.5742e+00,  ...,  2.2021e+00,
           -1.0316e+00,  1.7413e+00],
          [-1.9731e+00,  2.0812e+00,  7.1377e-02,  ...,  4.3805e-01,
            1.9455e+00, -1.4778e-01],
          [ 4.9764e+00,  1.0693e+00, -7.4499e-01,  ...,  1.4489e+00,
            3.6643e+00, -4.1437e+00]],

         [[ 1.3507e+00,  8.6619e+00,  5.7422e+00,  ...,  2.9645e+00,
            1.5294e+01, -2.5170e+00],
          [-2.2380e+00,  3.6264e+00,  1.1200e+00,  ..., -1.5105e+00,
           -2.7446e+00, -1.2021e+00],
          [ 3.3216e+00,  1.0051e+00,  5.1448e+00,  ...,  5.6158e+00,
            1.9937e-02,  1.0467e+01],
          ...,
          [-1.6429e+00, -9.5434e-01, -9.7184e-01,  ...,  3.0675e+00,
           -1.6406e+00, -4.1313e+00],
          [ 1.2051e+01,  6.1480e+00,  6.0691e+00,  ...,  3.2583e+00,
           -2.6474e-01,  1.4200e+00],
          [-2.7879e-01,  8.2557e-03,  8.2720e+00,  ...,  1.0185e+01,
            2.3029e+00,  5.8195e+00]],

         [[-1.1300e+00,  3.5811e+00,  2.4237e+00,  ...,  1.4918e+00,
            3.2682e+00,  7.5876e+00],
          [ 1.4604e+00, -3.3043e+00,  1.6434e+00,  ..., -7.6549e-01,
            8.3824e-01,  3.3841e-03],
          [-1.1598e+00,  2.9703e+00, -8.6725e-01,  ..., -6.0408e-01,
           -9.6863e-01, -3.2279e+00],
          ...,
          [-3.3875e+00,  5.9453e-01,  1.5919e-01,  ..., -3.0239e+00,
            2.7633e+00,  6.2630e+00],
          [ 1.2493e+00, -2.8350e+00,  1.9953e-01,  ..., -1.8353e+00,
            1.0246e-01,  5.1745e-01],
          [-2.4358e+00, -5.6508e+00, -7.6016e+00,  ..., -6.3176e+00,
           -6.6115e+00, -6.0625e+00]],

         ...,

         [[-2.7766e+00, -6.3150e+00, -3.7029e+00,  ..., -4.7805e+00,
           -9.6718e+00, -3.6660e+00],
          [ 1.6638e+00, -1.2483e-01,  1.0744e+00,  ..., -6.0021e-01,
           -2.4963e-01,  6.8361e-02],
          [-2.4005e+00, -3.5704e+00, -2.8819e+00,  ..., -3.4709e+00,
           -8.4034e+00, -7.1512e+00],
          ...,
          [-3.2388e+00, -1.1784e+00, -9.9880e+00,  ..., -1.1189e+00,
            1.2202e-01, -5.3838e+00],
          [-1.1199e+01, -1.3061e+00, -4.8230e+00,  ...,  3.2339e+00,
           -4.2897e-01, -2.6841e+00],
          [-8.6726e+00, -2.5014e-01, -5.8868e+00,  ..., -7.3977e+00,
           -3.9331e+00, -4.2529e+00]],

         [[-4.4641e+00, -6.0848e+00, -3.1120e+00,  ...,  6.1810e-02,
            4.0646e-01, -3.4574e+00],
          [-4.9035e-02, -6.7427e+00, -4.7861e+00,  ...,  1.0503e+00,
           -1.0854e+00, -2.0908e+00],
          [-4.8489e+00, -2.2788e+00, -3.7109e+00,  ...,  2.4285e-01,
           -5.1419e+00,  1.0124e-01],
          ...,
          [-2.0639e+00,  1.6324e+00, -9.0275e+00,  ...,  7.6924e-01,
            5.4993e-01, -6.2602e+00],
          [-9.0649e-01, -1.3775e+00, -1.4171e-01,  ..., -3.8969e+00,
           -2.5080e-01,  4.0028e-02],
          [-5.3131e+00, -2.0747e+00, -4.4897e+00,  ..., -5.2434e+00,
           -4.4784e+00, -3.9117e+00]],

         [[ 3.4734e-01, -9.6062e+00, -8.2036e+00,  ..., -2.6756e-01,
           -9.6271e+00,  2.9496e+00],
          [ 5.1060e+00, -1.5110e+00, -1.0069e-01,  ...,  1.7573e+00,
            9.0696e-01,  6.3989e-01],
          [-1.6898e+00, -8.8562e-02, -1.9490e+00,  ..., -2.8644e+00,
           -5.4682e-01, -8.7605e+00],
          ...,
          [ 1.0619e+00,  4.7499e+00, -3.6191e+00,  ..., -1.0141e+00,
            1.5665e+00,  4.0694e+00],
          [-8.6281e+00, -3.1539e+00, -2.2756e+00,  ..., -4.8996e+00,
            1.2239e-01,  4.5641e-02],
          [-1.1596e+00, -2.0775e+00, -1.1806e+01,  ..., -1.3374e+01,
           -3.7764e+00, -5.7993e+00]]],


        [[[ 5.9621e+00,  1.4455e+00, -2.2340e+00,  ..., -3.4084e-02,
           -3.3006e+00, -4.3624e+00],
          [ 3.6286e+00, -5.8562e-01, -2.2610e+00,  ...,  1.3377e+00,
           -5.3778e+00, -1.2468e+00],
          [-4.2832e+00,  8.7533e-01,  4.8273e+00,  ..., -8.2866e-01,
           -3.3952e+00, -1.6456e+00],
          ...,
          [-1.5933e-01,  6.4941e-01,  4.5391e-01,  ..., -4.5244e+00,
            1.1234e-01, -2.6209e+00],
          [-4.5611e+00,  2.1963e+00,  2.7939e+00,  ...,  4.1223e+00,
           -6.6522e+00, -3.3354e-01],
          [ 2.3822e-01, -8.9368e+00,  2.7662e+00,  ...,  3.2473e+00,
           -6.8607e+00, -6.4876e-01]],

         [[-1.7381e+00, -3.3442e+00,  1.0841e-01,  ...,  1.3510e+00,
            8.4710e+00,  1.6372e+00],
          [ 2.1760e+00,  2.1689e+00,  3.1505e+00,  ...,  1.4733e+00,
            8.7623e+00,  9.4841e+00],
          [ 1.0223e+01,  4.4697e+00,  7.9485e+00,  ...,  1.1135e+00,
            7.0852e+00,  1.1237e+00],
          ...,
          [ 3.0566e+00,  1.0388e+00, -7.1946e-01,  ...,  8.3603e+00,
           -3.5966e-01,  5.7848e+00],
          [ 1.4422e+01, -1.2163e+00,  4.2834e-01,  ..., -1.4521e+00,
            1.2370e+01, -8.2423e-01],
          [ 1.9149e+00,  1.1713e+01, -1.4724e+00,  ..., -3.8575e-01,
            1.2822e+01,  4.5188e+00]],

         [[ 4.9238e+00,  9.1158e+00,  5.8365e+00,  ...,  3.5029e+00,
           -3.4892e-01,  3.7643e+00],
          [ 1.3117e-01, -8.1194e-01,  3.7075e+00,  ...,  9.1845e-01,
           -3.0742e+00,  1.7938e+00],
          [-2.7412e+00,  1.7637e-01,  1.5172e+00,  ...,  3.8453e+00,
           -6.1423e-01,  1.0143e+00],
          ...,
          [-2.3274e-01, -1.5521e+00, -3.3791e+00,  ..., -3.7106e+00,
            2.9397e+00, -1.1357e+00],
          [-5.1460e+00, -2.0376e-01, -1.2215e+00,  ..., -2.0066e-01,
           -6.1494e+00,  9.6931e-01],
          [-6.3288e+00, -7.6544e+00, -5.8582e+00,  ..., -6.7047e+00,
           -1.1267e+01, -5.6204e+00]],

         ...,

         [[-9.7749e+00, -1.8637e+00, -1.3972e+00,  ...,  1.3722e+00,
           -1.2530e+00,  8.0717e-01],
          [-2.2244e+00, -1.7044e+00, -1.6113e+00,  ..., -1.1461e+00,
           -6.3236e+00, -4.1299e+00],
          [-1.0323e+01, -4.1420e+00, -6.2381e+00,  ..., -4.0253e+00,
           -4.5029e+00, -2.8431e+00],
          ...,
          [-3.4897e+00, -7.5771e-01,  3.4172e+00,  ..., -5.1617e+00,
           -5.7993e+00, -2.3412e+00],
          [-6.5054e+00, -4.2612e+00, -4.1600e+00,  ..., -5.3065e+00,
           -7.4251e+00, -7.3673e+00],
          [-1.2217e+00, -8.6800e+00,  1.7027e-01,  ..., -8.7240e-01,
           -4.4210e+00, -4.1965e-01]],

         [[-5.0527e+00, -2.3448e+00, -6.8790e-01,  ..., -2.3415e+00,
           -5.5439e+00, -4.6923e+00],
          [-9.9265e-01, -2.1647e+00, -1.5033e+00,  ..., -6.2226e+00,
           -2.0686e+00, -6.1654e+00],
          [-1.5460e+00, -3.6793e+00, -6.1195e+00,  ..., -8.2441e+00,
           -3.7021e-01, -7.1743e-01],
          ...,
          [-2.1983e+00,  1.0882e+00, -4.5601e-01,  ..., -1.3695e+00,
           -6.7608e+00, -4.3423e+00],
          [-2.3327e+00, -2.5034e+00, -2.5741e+00,  ..., -1.8593e+00,
           -2.7594e+00, -1.1868e+01],
          [-3.4959e+00, -2.7478e+00,  9.0655e-01,  ..., -1.7377e+00,
           -3.0957e+00, -2.6584e+00]],

         [[ 4.4210e-01,  4.5917e+00, -1.0057e-01,  ..., -1.4424e+00,
           -1.5338e+00, -2.7505e+00],
          [ 3.7487e-01,  6.7281e-01,  7.5896e-01,  ...,  9.0369e-01,
           -7.5366e+00, -2.5926e+00],
          [-1.0950e+01, -4.6743e+00, -8.4272e+00,  ..., -2.4001e+00,
           -2.8578e+00,  7.9456e-01],
          ...,
          [-9.4300e-01,  1.9821e-01,  1.7910e+00,  ..., -7.0246e+00,
           -1.1222e-01, -5.2846e+00],
          [-5.3871e+00,  1.1314e-01, -1.8710e+00,  ...,  2.0158e+00,
           -1.0908e+01, -2.9692e+00],
          [ 2.6707e-01, -6.9155e+00, -1.1453e+00,  ..., -3.0510e+00,
           -8.3460e+00, -4.0057e+00]]],


        [[[ 7.1748e+00, -4.5118e+00,  2.5004e+00,  ..., -7.3491e+00,
           -2.9264e+00, -3.4760e+00],
          [ 7.1586e+00, -7.8375e+00,  2.3567e+00,  ...,  3.1427e+00,
            7.2580e+00, -1.1591e+01],
          [-2.4562e+00,  3.5660e+00,  4.0418e+00,  ...,  3.5772e+00,
           -6.3613e+00,  1.4020e+00],
          ...,
          [ 6.5881e+00, -1.9919e+00, -2.8121e+00,  ..., -3.4671e+00,
            2.2774e+00,  1.0924e+00],
          [ 6.6684e+00, -5.8324e+00, -1.8164e+00,  ...,  4.1258e+00,
           -1.1686e-01, -1.9188e+00],
          [ 2.0094e+00,  5.0667e+00,  4.8137e+00,  ..., -5.9762e+00,
            2.5489e+00,  2.0629e+00]],

         [[-1.1305e+00,  8.7750e+00, -1.1253e-01,  ...,  1.1135e+01,
            4.5165e+00,  1.6569e+00],
          [-1.1873e+00,  1.2485e+01, -1.1897e+00,  ..., -2.0340e+00,
           -1.9602e+00,  1.5549e+01],
          [ 8.0517e+00, -2.1965e+00, -3.6510e+00,  ...,  7.5174e-01,
            1.7097e+01,  6.6879e+00],
          ...,
          [-8.4877e-01,  4.7988e+00,  2.6078e+00,  ...,  2.6035e+00,
           -2.3894e+00,  6.0107e-02],
          [-2.5432e-01,  1.1361e+01,  4.6063e+00,  ...,  1.5907e-02,
           -1.4070e+00, -2.1952e+00],
          [-1.7280e+00, -1.3441e+00, -1.2083e+00,  ...,  1.5456e+01,
            1.5200e+00,  8.7036e-01]],

         [[ 4.0726e+00,  4.5839e-01,  4.7880e+00,  ...,  2.3804e+00,
            4.5400e+00,  3.2143e+00],
          [ 5.2156e-01, -6.6741e+00, -1.9597e-02,  ...,  1.8422e+00,
            2.8243e+00, -5.7023e+00],
          [-3.1766e+00,  3.8173e+00,  8.2692e-01,  ..., -6.7872e+00,
           -7.6065e+00, -2.0153e+00],
          ...,
          [ 3.1436e+00, -2.5336e+00, -3.2835e+00,  ..., -2.0447e+00,
            5.4782e-01, -1.5026e+00],
          [ 8.4596e-01, -6.1273e+00, -4.6527e+00,  ...,  2.3528e+00,
            1.2852e+00,  1.6685e+00],
          [-6.3066e+00, -7.0497e+00, -5.5292e+00,  ..., -1.1050e+01,
           -7.4969e+00, -5.4729e+00]],

         ...,

         [[-8.3947e+00, -4.8062e+00, -5.2923e+00,  ..., -8.7104e+00,
           -4.3753e+00, -2.2311e+00],
          [-5.1493e+00, -9.4229e+00,  8.1030e-01,  ..., -2.0103e+00,
           -9.8018e+00, -6.8862e+00],
          [-5.5428e+00, -7.3654e+00, -2.1399e-01,  ..., -8.6805e+00,
           -1.3322e+01, -3.8899e+00],
          ...,
          [-9.8075e+00, -2.5405e+00, -1.6708e+00,  ..., -6.9319e-01,
            5.1245e-01, -3.3572e+00],
          [-4.3616e+00, -5.6247e+00, -7.3513e-01,  ..., -6.3263e-01,
           -1.4678e+00,  7.4891e-01],
          [-1.7370e+00, -4.6718e+00, -4.7634e+00,  ..., -7.6442e+00,
           -1.7703e+00, -2.9149e+00]],

         [[-3.0375e+00, -2.3360e+00, -2.5423e+00,  ..., -1.9420e+00,
           -1.4413e+00, -3.5158e-01],
          [-3.7775e+00, -1.1276e+00, -5.9864e-01,  ...,  8.8618e-01,
           -7.4435e+00, -1.9713e+00],
          [-2.7243e+00, -6.4788e+00,  2.4655e-01,  ..., -9.2257e+00,
           -1.2329e+00, -6.3128e+00],
          ...,
          [-1.0587e+01, -1.5232e+00,  4.3046e-01,  ..., -5.0218e+00,
           -7.1009e-01, -1.5211e+00],
          [-4.1101e+00, -2.9652e+00, -4.6391e+00,  ..., -7.2309e+00,
           -4.9523e-01,  1.3519e+00],
          [-3.4404e+00, -8.7925e+00, -7.3745e+00,  ..., -2.5850e+00,
           -1.3480e+00, -2.6436e+00]],

         [[ 3.2614e+00, -6.1063e+00,  3.0308e-01,  ..., -1.2505e+01,
           -9.0478e-01,  6.9069e-01],
          [ 6.7505e-01, -1.4260e+01,  2.6364e+00,  ...,  2.0032e+00,
            1.7798e+00, -9.7120e+00],
          [-7.5726e+00,  8.3288e-01,  1.4997e+00,  ..., -5.3520e+00,
           -1.5614e+01, -6.3646e+00],
          ...,
          [-3.0049e+00, -3.5255e+00,  8.9272e-03,  ..., -2.7497e+00,
           -6.3980e-01, -1.2784e+00],
          [-2.3251e+00, -8.3791e+00, -8.3439e-01,  ..., -1.5853e+00,
            5.4992e-01,  8.0183e-01],
          [ 2.7707e-01, -3.2186e+00, -3.1080e+00,  ..., -9.9593e+00,
           -2.1098e+00, -3.1099e+00]]],


        [[[ 4.5305e+00,  5.4767e+00,  1.9562e+00,  ...,  3.0760e+00,
            1.5317e+00, -9.4330e+00],
          [-7.5206e+00,  1.3269e+00,  4.4376e+00,  ..., -8.9439e-01,
           -5.5123e+00,  6.0969e-02],
          [-3.8505e+00,  5.3131e-01,  7.0657e+00,  ...,  1.9390e+00,
            2.2721e+00, -3.0328e+00],
          ...,
          [ 3.8092e+00,  7.3650e+00,  3.6250e+00,  ...,  4.9643e+00,
            6.1939e-01, -2.0685e+00],
          [ 1.7308e+00, -9.8710e+00,  6.2731e+00,  ..., -2.5102e+00,
            6.1463e+00, -6.5744e-01],
          [ 2.6003e+00,  1.9007e+00,  5.1751e+00,  ...,  1.9466e+00,
           -1.3727e+00, -2.9885e+00]],

         [[ 2.1407e+00, -5.0797e-01, -1.9937e+00,  ...,  1.1988e+00,
            8.0608e-01,  7.7564e+00],
          [ 1.7385e+01, -2.0188e+00, -2.7685e+00,  ..., -6.0461e-01,
            1.2209e+01, -1.4851e+00],
          [ 8.5835e+00,  4.8589e+00,  4.4506e+00,  ..., -3.4656e+00,
            3.6447e-01,  1.8612e+00],
          ...,
          [-1.0488e+00,  4.5731e-01,  4.7558e+00,  ...,  6.4012e+00,
            1.2696e+01,  8.0009e-01],
          [ 4.7644e+00,  1.0588e+01,  9.7654e-01,  ..., -1.1973e+00,
            8.1417e+00, -6.9475e-01],
          [ 2.5049e+00,  2.1639e+00, -5.6579e-01,  ...,  2.4361e-01,
            2.4934e+00,  1.2214e+00]],

         [[ 4.1603e-01,  2.7324e+00,  4.7885e+00,  ...,  3.1784e+00,
            3.7596e+00, -2.1068e+00],
          [-3.0626e+00, -2.1526e+00, -1.3592e+00,  ..., -6.6196e+00,
           -6.1460e+00,  9.0252e-01],
          [-3.8419e+00,  5.1649e-01, -4.0003e-01,  ...,  2.4868e+00,
            1.2281e+00,  1.8677e+00],
          ...,
          [ 1.7838e-01, -3.7554e-01, -5.8808e+00,  ..., -6.9580e-01,
            1.9065e+00,  2.7954e+00],
          [-2.7934e-01, -4.1898e+00,  3.4449e-01,  ...,  6.7025e-01,
            2.9693e+00,  1.6735e+00],
          [-5.1979e+00, -6.9894e+00, -5.7696e+00,  ..., -9.8218e+00,
           -8.2856e+00, -6.7621e+00]],

         ...,

         [[-3.8402e+00, -3.8054e+00,  8.2086e-01,  ..., -3.8146e+00,
           -6.5891e+00, -3.7835e+00],
          [-1.0327e+01, -1.8294e+00,  2.2201e-01,  ..., -8.3145e-03,
           -1.0295e+01, -2.8039e+00],
          [-5.2530e+00, -3.7232e+00, -2.0356e+00,  ...,  1.9055e+00,
           -1.2827e+00, -1.6872e+00],
          ...,
          [-2.9233e+00, -3.3778e+00, -3.1521e+00,  ..., -6.3673e+00,
           -1.0393e+01, -1.2780e+00],
          [-5.5505e+00, -7.7596e+00, -2.9009e+00,  ..., -2.3937e+00,
           -8.2770e+00, -1.7952e+00],
          [-2.5277e+00,  3.1908e+00, -4.8716e+00,  ...,  4.5527e-01,
           -1.4389e+00, -5.0932e-01]],

         [[ 1.6565e+00,  4.6899e-01,  5.5947e-01,  ..., -2.9588e+00,
           -4.8498e-01, -2.4243e+00],
          [-6.8780e-01, -6.7681e+00, -1.2987e+00,  ..., -4.9600e+00,
           -2.2053e+00, -7.8756e+00],
          [-3.5408e+00, -4.1152e+00, -4.2555e+00,  ...,  2.1590e+00,
            1.8863e-02, -7.9127e-01],
          ...,
          [-1.8278e+00, -3.8451e+00,  1.4401e-01,  ..., -6.7518e+00,
           -4.4214e+00, -1.8849e+00],
          [-6.1052e-01, -2.1633e+00, -3.9621e+00,  ..., -1.9885e+00,
           -8.9501e+00,  1.5313e+00],
          [-1.8303e-01, -1.4884e+00, -2.7497e+00,  ..., -3.1298e+00,
           -2.3116e+00, -1.2665e+00]],

         [[ 2.2970e-02,  1.9033e+00,  7.7147e-01,  ...,  4.2411e-01,
            1.9232e+00, -7.0313e+00],
          [-7.8668e+00, -4.9783e+00, -2.3168e+00,  ...,  1.2854e-01,
           -1.0944e+01,  2.3174e-01],
          [-4.5128e+00, -6.7707e+00, -6.5001e+00,  ...,  4.0267e+00,
            1.3554e+00, -5.7253e-01],
          ...,
          [ 2.4720e+00, -2.7066e+00, -2.3215e+00,  ..., -7.7985e+00,
           -1.3994e+01,  1.1515e-01],
          [-1.3280e+00, -8.1062e+00, -2.2597e+00,  ...,  1.0752e+00,
           -8.7901e+00,  1.8237e+00],
          [-1.3959e+00, -2.7491e+00, -2.1824e+00,  ..., -3.9142e+00,
           -3.3950e+00, -2.7110e+00]]],


        [[[ 2.1758e+00,  5.9327e-01,  1.2321e+00,  ..., -8.6349e+00,
           -7.0824e+00, -5.5872e-01],
          [ 4.1945e+00,  3.3575e+00, -2.0221e-01,  ...,  4.5126e+00,
            4.5874e+00, -7.0484e+00],
          [-2.7822e+00,  2.2900e+00,  2.5726e+00,  ...,  3.5163e+00,
           -1.1748e+00, -3.0288e-01],
          ...,
          [-1.4339e+00,  5.5685e+00,  6.0679e-02,  ...,  1.8854e+00,
            3.3062e+00,  9.3671e-01],
          [ 2.5136e+00, -1.7346e-01,  1.2026e+00,  ...,  1.6699e+00,
           -1.9329e+00,  1.8256e+00],
          [ 4.9369e+00, -1.2356e+00,  3.7085e+00,  ...,  2.8774e+00,
           -1.4861e+00, -8.9487e-01]],

         [[-1.7941e+00,  2.3797e+00, -3.5294e-01,  ...,  1.2222e+01,
            1.0855e+01,  4.4987e+00],
          [ 7.2944e-01, -1.7572e+00,  2.5690e+00,  ..., -5.6604e-01,
           -1.8293e+00,  1.2401e+01],
          [ 1.1328e+01,  7.2125e+00, -1.8892e+00,  ...,  3.0096e-01,
           -1.1639e+00,  9.8209e-01],
          ...,
          [ 4.8918e+00, -1.2627e-01,  1.9149e-01,  ...,  1.2965e+00,
           -8.3881e-01, -1.5170e+00],
          [-1.5264e+00,  3.9689e+00,  5.7733e-01,  ..., -4.2879e-01,
            1.4066e+01,  3.3544e+00],
          [-2.1904e+00,  4.5277e+00,  2.3828e-01,  ...,  9.7892e-01,
            4.3096e+00, -3.1338e-01]],

         [[ 1.6200e+00,  5.8688e+00,  3.7786e+00,  ..., -3.2042e-02,
           -1.0947e+00, -1.0978e+00],
          [-1.6709e+00,  2.4116e+00, -5.9302e-01,  ..., -4.3238e+00,
           -1.7370e+00, -6.4911e+00],
          [-5.8312e+00, -4.3879e+00,  2.5381e+00,  ...,  2.8964e+00,
            4.8751e+00,  2.1989e+00],
          ...,
          [-1.0717e+00, -3.7832e+00, -2.4481e+00,  ..., -1.4853e+00,
           -3.9329e-01, -1.4572e-01],
          [ 7.2290e-01,  1.4523e+00, -2.5195e-01,  ...,  4.6703e+00,
           -1.2500e+00, -2.9886e-01],
          [-1.6390e+00, -6.2312e+00, -6.3457e+00,  ..., -6.5320e+00,
           -9.0617e+00, -4.1691e+00]],

         ...,

         [[-3.1741e+00, -3.5961e+00,  1.1599e+00,  ..., -1.0708e+01,
           -7.2371e+00, -3.4183e+00],
          [-2.9512e+00, -2.4396e+00, -3.3965e+00,  ...,  4.5411e-02,
           -2.1751e+00, -6.6114e+00],
          [-8.7887e+00, -3.6944e+00,  7.1111e-01,  ..., -2.1729e+00,
           -1.0560e+00, -9.2749e-01],
          ...,
          [-4.4438e+00, -7.4573e+00,  2.4914e-01,  ..., -2.6611e+00,
           -1.8291e+00, -3.5990e+00],
          [-2.4200e+00, -5.6907e+00, -2.2241e+00,  ..., -5.2605e+00,
           -1.1076e+01, -1.9971e+00],
          [-7.7697e+00, -4.0830e+00, -8.5536e+00,  ...,  1.1202e+00,
           -1.5295e+00,  2.0985e+00]],

         [[-1.9831e+00, -5.3594e+00, -2.6770e+00,  ..., -1.5307e+00,
           -3.1098e+00, -2.9751e+00],
          [-6.0538e-01, -2.8479e-01, -4.0700e+00,  ...,  1.3797e+00,
           -9.2275e+00, -1.1580e+00],
          [-4.0497e-01, -2.5699e+00,  2.5631e-01,  ...,  9.4149e-01,
           -6.2141e-01, -3.2685e+00],
          ...,
          [-8.5673e-01, -9.9745e+00, -2.4676e+00,  ..., -8.1825e-01,
           -1.5833e+00, -1.1985e+00],
          [-3.0810e+00,  1.2629e+00,  6.4224e-01,  ..., -5.9781e+00,
           -4.0432e+00, -4.3197e+00],
          [-4.3249e+00, -4.5425e+00, -7.7394e+00,  ..., -4.5886e+00,
           -2.2106e+00, -6.3148e-01]],

         [[ 2.7977e+00, -2.0868e+00, -2.2436e-01,  ..., -7.5427e+00,
           -6.5260e+00, -3.1116e+00],
          [ 9.0047e-01,  5.0054e+00, -3.8223e+00,  ...,  1.2566e+00,
           -1.3626e+00, -7.5790e+00],
          [-9.1484e+00, -3.5936e+00,  2.1066e+00,  ...,  1.7353e+00,
           -7.2539e-01, -2.5258e+00],
          ...,
          [-2.6909e+00, -5.9078e+00, -4.3391e-02,  ..., -2.0053e+00,
           -1.7248e+00, -8.4934e-01],
          [ 5.1667e-01, -4.7550e+00,  8.4912e-02,  ...,  5.8482e-01,
           -1.3108e+01, -5.4371e+00],
          [ 8.9254e-01, -5.8154e+00, -3.5535e+00,  ..., -3.0744e+00,
           -3.3707e+00, -4.1698e-01]]]], grad_fn=<AddBackward0>), (1, 11): tensor([[[[-7.4822e-01]],

         [[-4.0474e-02]],

         [[ 2.6209e-02]],

         [[ 6.0048e-03]],

         [[ 1.3856e-02]],

         [[ 1.3929e-02]],

         [[ 4.3777e-02]],

         [[ 5.0945e-02]],

         [[ 6.4966e-02]],

         [[-2.6290e-02]],

         [[-4.3105e-02]],

         [[ 2.2479e-02]],

         [[-1.0847e-01]],

         [[-2.6062e-02]],

         [[-6.6349e-02]],

         [[-4.8269e-02]],

         [[-1.1307e-02]],

         [[ 3.2397e-02]],

         [[ 7.8529e-02]],

         [[-8.4663e-02]],

         [[ 6.6399e-02]],

         [[ 5.7088e-03]],

         [[-2.7619e-02]],

         [[ 1.5104e+00]],

         [[ 8.8617e-02]],

         [[-6.4786e-03]],

         [[ 3.7181e-02]],

         [[ 5.4356e-01]],

         [[ 1.4037e-01]],

         [[-1.4094e-01]],

         [[-1.2397e+00]],

         [[-2.3682e-01]],

         [[-1.1380e-01]],

         [[-3.9854e-02]],

         [[ 1.0157e+00]],

         [[ 1.2265e-01]],

         [[ 1.0587e-01]],

         [[-2.8815e-02]],

         [[ 7.9557e-02]],

         [[ 4.3261e-02]],

         [[-7.1797e-02]],

         [[-1.6713e-02]],

         [[-2.4861e-02]],

         [[-2.4980e-02]],

         [[-4.6707e-02]],

         [[ 4.5929e-02]],

         [[-9.8940e-03]],

         [[ 1.1187e-01]],

         [[ 8.1387e-03]],

         [[-1.3140e-02]],

         [[-5.7129e-03]],

         [[-6.8272e-02]],

         [[ 7.9376e-02]],

         [[-9.5144e-02]],

         [[-3.2180e-02]],

         [[-7.2118e-03]],

         [[-8.5752e-03]],

         [[ 1.5660e-02]],

         [[ 5.8124e-02]],

         [[-1.2669e-01]],

         [[ 1.1428e-01]],

         [[ 5.6726e-02]],

         [[ 1.0857e-03]],

         [[-5.9308e-02]],

         [[-1.2102e-02]],

         [[ 3.7316e-02]],

         [[-4.6033e-02]],

         [[-1.8660e-02]],

         [[-9.4908e-02]],

         [[-8.9407e-03]],

         [[ 7.2383e-04]],

         [[ 4.0374e-02]],

         [[-1.1561e+00]],

         [[ 4.0183e-02]],

         [[ 1.0052e-01]],

         [[ 7.1533e-03]],

         [[ 7.0659e-02]],

         [[-1.7062e-02]],

         [[ 3.1400e-02]],

         [[-1.1585e-01]]],


        [[[-9.1093e-01]],

         [[-7.2730e-02]],

         [[ 4.6665e-03]],

         [[-1.2649e-02]],

         [[ 4.3804e-02]],

         [[ 4.9088e-02]],

         [[ 1.9215e-02]],

         [[ 1.1735e-01]],

         [[ 4.4435e-02]],

         [[-7.1439e-02]],

         [[ 1.2512e-03]],

         [[ 7.9875e-02]],

         [[-1.1604e-01]],

         [[ 3.7616e-02]],

         [[-7.2624e-02]],

         [[-4.2952e-02]],

         [[ 5.0094e-03]],

         [[ 1.3401e-02]],

         [[ 8.2043e-02]],

         [[-1.8213e-02]],

         [[ 5.1837e-02]],

         [[-4.8619e-02]],

         [[-6.7728e-02]],

         [[ 2.0457e+00]],

         [[ 8.5822e-02]],

         [[ 1.3747e-02]],

         [[ 3.5398e-02]],

         [[ 1.1464e+00]],

         [[ 9.2941e-02]],

         [[ 4.2882e-01]],

         [[-1.3445e+00]],

         [[-4.7511e-01]],

         [[-1.1156e-01]],

         [[-9.9971e-03]],

         [[ 1.5930e+00]],

         [[ 1.2014e-01]],

         [[ 6.8763e-02]],

         [[ 2.8476e-02]],

         [[ 1.1299e-01]],

         [[ 1.8088e-02]],

         [[-5.8403e-02]],

         [[ 1.2137e-02]],

         [[-1.6390e-02]],

         [[-1.3066e-02]],

         [[-5.7772e-02]],

         [[ 3.7589e-02]],

         [[-1.6163e-02]],

         [[ 1.4688e-01]],

         [[ 1.5389e-02]],

         [[ 2.5528e-02]],

         [[-1.8853e-02]],

         [[-6.3268e-02]],

         [[ 2.5449e-02]],

         [[-9.2687e-02]],

         [[-4.2817e-02]],

         [[ 3.4885e-03]],

         [[-1.6553e-02]],

         [[ 2.2975e-02]],

         [[ 2.1178e-02]],

         [[-1.3913e-01]],

         [[ 1.0426e-01]],

         [[-2.2789e-03]],

         [[ 6.2132e-03]],

         [[-8.2899e-02]],

         [[ 1.8497e-03]],

         [[ 1.6863e-02]],

         [[-6.2842e-02]],

         [[ 3.9875e-03]],

         [[-5.0250e-02]],

         [[-6.6747e-02]],

         [[-2.4827e-02]],

         [[ 2.3686e-02]],

         [[-1.3104e+00]],

         [[-3.3496e-02]],

         [[ 8.1682e-02]],

         [[-9.3074e-03]],

         [[ 5.5705e-02]],

         [[-1.1069e-01]],

         [[-2.0514e-02]],

         [[-1.1039e-01]]],


        [[[-9.2203e-01]],

         [[-4.5644e-03]],

         [[ 3.5233e-02]],

         [[ 2.4470e-02]],

         [[ 2.5927e-02]],

         [[ 4.2409e-03]],

         [[ 3.6164e-02]],

         [[ 6.6779e-02]],

         [[ 4.3562e-02]],

         [[-7.4008e-02]],

         [[-6.3954e-02]],

         [[ 6.3522e-02]],

         [[-1.1478e-01]],

         [[-3.2373e-02]],

         [[-1.5851e-02]],

         [[-1.0988e-02]],

         [[ 4.8525e-03]],

         [[ 7.9375e-03]],

         [[ 7.9662e-02]],

         [[-8.0588e-02]],

         [[ 2.6536e-02]],

         [[-6.8135e-03]],

         [[-6.6539e-02]],

         [[ 1.5932e+00]],

         [[ 6.9282e-02]],

         [[ 8.5928e-03]],

         [[-6.2198e-03]],

         [[ 3.6067e-01]],

         [[ 1.0573e-01]],

         [[ 1.8450e-01]],

         [[-1.8667e+00]],

         [[-5.6458e-01]],

         [[-1.5097e-01]],

         [[-3.5903e-02]],

         [[ 1.0627e+00]],

         [[ 1.4094e-01]],

         [[ 1.2245e-01]],

         [[-3.1380e-03]],

         [[ 1.0186e-01]],

         [[ 8.9269e-02]],

         [[-1.0392e-01]],

         [[ 4.7297e-02]],

         [[-3.6342e-02]],

         [[-2.0297e-02]],

         [[-4.6954e-02]],

         [[ 2.1278e-02]],

         [[ 3.6076e-02]],

         [[ 1.0626e-01]],

         [[-5.0688e-02]],

         [[ 2.4543e-02]],

         [[ 5.4098e-02]],

         [[-3.1049e-02]],

         [[ 1.1260e-02]],

         [[-1.3134e-01]],

         [[ 1.8662e-02]],

         [[-1.6954e-02]],

         [[-7.8215e-03]],

         [[ 5.2880e-02]],

         [[ 5.6676e-02]],

         [[-1.5296e-01]],

         [[ 1.2461e-01]],

         [[ 2.7213e-02]],

         [[-9.7989e-03]],

         [[-9.9688e-02]],

         [[ 3.3424e-02]],

         [[ 6.2275e-02]],

         [[-9.2602e-02]],

         [[ 3.4079e-03]],

         [[-1.4077e-01]],

         [[-4.1511e-02]],

         [[ 4.5301e-02]],

         [[ 4.2790e-02]],

         [[-1.6796e+00]],

         [[ 4.3438e-01]],

         [[ 7.5423e-02]],

         [[ 2.6293e-02]],

         [[ 4.6556e-02]],

         [[-5.3939e-02]],

         [[-3.9494e-04]],

         [[-1.0188e-01]]],


        [[[-8.0181e-01]],

         [[-6.0872e-02]],

         [[ 2.6518e-02]],

         [[ 1.0763e-02]],

         [[ 2.8162e-02]],

         [[ 3.6767e-02]],

         [[ 3.4474e-02]],

         [[ 7.3407e-02]],

         [[ 5.4176e-02]],

         [[-1.9530e-02]],

         [[ 1.6093e-02]],

         [[ 3.6236e-02]],

         [[-1.1459e-01]],

         [[ 3.2049e-03]],

         [[-9.9346e-02]],

         [[-8.5286e-02]],

         [[-2.5211e-02]],

         [[ 4.3432e-02]],

         [[ 9.3292e-02]],

         [[-5.0433e-02]],

         [[ 3.3465e-02]],

         [[-2.0842e-02]],

         [[-1.0976e-02]],

         [[ 2.4325e+00]],

         [[ 7.1651e-02]],

         [[-1.2168e-02]],

         [[ 2.3863e-02]],

         [[ 6.9919e-01]],

         [[ 1.1570e-01]],

         [[ 8.8412e-02]],

         [[-1.7923e+00]],

         [[-1.9857e-01]],

         [[-8.8584e-02]],

         [[ 7.4464e-03]],

         [[ 9.4857e-01]],

         [[ 1.0591e-01]],

         [[ 6.0264e-02]],

         [[ 1.6191e-02]],

         [[ 1.3199e-01]],

         [[ 3.3626e-02]],

         [[-3.8089e-02]],

         [[-1.2028e-02]],

         [[-3.0799e-02]],

         [[-3.8370e-03]],

         [[-4.1296e-02]],

         [[ 4.4644e-02]],

         [[-4.3898e-02]],

         [[ 8.9085e-02]],

         [[ 1.6809e-02]],

         [[-6.0743e-03]],

         [[-1.4915e-02]],

         [[-7.7690e-02]],

         [[ 2.2562e-02]],

         [[-5.5939e-02]],

         [[-4.1382e-02]],

         [[ 1.0779e-02]],

         [[ 4.2708e-03]],

         [[ 2.3684e-02]],

         [[ 1.6530e-02]],

         [[-8.9085e-02]],

         [[ 9.1743e-02]],

         [[-1.2811e-02]],

         [[ 5.2506e-03]],

         [[-8.9019e-02]],

         [[-1.3742e-02]],

         [[ 2.5177e-02]],

         [[-7.2293e-02]],

         [[-5.6415e-03]],

         [[-3.4890e-02]],

         [[-4.8723e-02]],

         [[-9.0879e-03]],

         [[ 2.9355e-02]],

         [[-5.5992e-01]],

         [[-4.4588e-02]],

         [[ 8.8262e-02]],

         [[-2.2690e-02]],

         [[ 1.1073e-01]],

         [[-5.9639e-02]],

         [[ 2.3763e-02]],

         [[-9.6556e-02]]],


        [[[-1.0881e+00]],

         [[-6.7351e-02]],

         [[ 2.5252e-02]],

         [[ 1.4765e-02]],

         [[ 6.7364e-03]],

         [[ 2.1351e-02]],

         [[ 1.8997e-02]],

         [[ 7.5013e-02]],

         [[ 6.5668e-02]],

         [[-3.9716e-02]],

         [[-2.2699e-02]],

         [[ 2.9843e-02]],

         [[-1.2453e-01]],

         [[ 1.1432e-02]],

         [[-9.7087e-02]],

         [[-7.2302e-02]],

         [[-2.7199e-02]],

         [[ 3.4174e-02]],

         [[ 8.9232e-02]],

         [[-3.5342e-02]],

         [[ 5.5133e-02]],

         [[-2.2613e-02]],

         [[-4.2604e-02]],

         [[ 2.2180e+00]],

         [[ 7.7475e-02]],

         [[ 1.0985e-03]],

         [[ 3.5922e-02]],

         [[ 8.2894e-01]],

         [[ 1.2703e-01]],

         [[ 7.3120e-01]],

         [[-1.7901e+00]],

         [[-2.8575e-01]],

         [[-1.0035e-01]],

         [[-3.5372e-03]],

         [[ 1.0665e+00]],

         [[ 1.3781e-01]],

         [[ 8.0308e-02]],

         [[-7.7634e-03]],

         [[ 1.0792e-01]],

         [[ 3.6673e-02]],

         [[-5.8687e-02]],

         [[-2.4987e-02]],

         [[-2.5769e-02]],

         [[-1.4609e-02]],

         [[-6.1633e-02]],

         [[ 5.6258e-02]],

         [[-4.7689e-02]],

         [[ 1.0748e-01]],

         [[ 3.4335e-02]],

         [[ 1.1853e-03]],

         [[-1.4728e-02]],

         [[-6.9002e-02]],

         [[ 3.6814e-02]],

         [[-7.9727e-02]],

         [[-3.9619e-02]],

         [[-3.1928e-03]],

         [[-1.9051e-04]],

         [[ 4.0728e-03]],

         [[ 3.3484e-02]],

         [[-1.2974e-01]],

         [[ 1.2706e-01]],

         [[ 1.9114e-02]],

         [[ 5.9182e-03]],

         [[-8.1433e-02]],

         [[-1.9542e-02]],

         [[ 2.8159e-02]],

         [[-6.2844e-02]],

         [[-2.3550e-02]],

         [[-7.5329e-02]],

         [[-3.4127e-02]],

         [[-2.5698e-03]],

         [[ 3.6313e-02]],

         [[-1.1798e+00]],

         [[ 3.4117e-01]],

         [[ 8.3288e-02]],

         [[-7.1232e-03]],

         [[ 9.2505e-02]],

         [[-6.9953e-02]],

         [[ 1.0684e-02]],

         [[-1.0230e-01]]]], grad_fn=<AsStridedBackward1>)})

Other functionalities#

In addition to recording, the following functionalities are available and documented in Recorder:

  • On-the-fly postprocessing of activations during inference (e.g. clipping).

  • Local disabling of the recording during forward pass.

  • Direct access to the recorded layers’ Module objects.

  • Retrieval of the named parameters of the recorded layers.
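
To give an idea of what on-the-fly postprocessing means at the PyTorch level, here is a minimal sketch that uses a plain forward hook rather than the Recorder API (whose actual arguments are documented in Recorder). The toy model and the helper name clip_hook are assumptions made for illustration only.

```python
import torch
from torch import nn

# Toy model standing in for a real network (hypothetical, for illustration)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()


def clip_hook(module, args, output):
    """Clip activations; a non-None return value replaces the layer output."""
    return output.clamp(max=0.5)


x = torch.randn(3, 4)
with torch.no_grad():
    reference = model(x)  # No postprocessing

# Postprocess the ReLU output during inference
handle = model[1].register_forward_hook(clip_hook)
with torch.no_grad():
    clipped = model(x)  # Last layer now sees activations in [0, 0.5]
handle.remove()  # Detach the hook to restore the original behavior

with torch.no_grad():
    restored = model(x)
```

Once the handle is removed, the model behaves exactly as before, which is why hook-based postprocessing does not require modifying the network itself.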

Implementation details#

This utility works by attaching hooks to every layer of net with Module.register_forward_hook. Since layers are not aware of the context in which they are called, these hooks carry references to rnet, which gives them enough context to know when to trigger. As a result, two different Recorder Nets can wrap the same net without any conflict. As an implementation detail, note that these references are weak, so that rnet can be properly cleaned up upon deletion.
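
The mechanism described above can be sketched as follows. This is a simplified illustration and not scio's actual implementation: the class TinyRecorder, its active flag and its activations dictionary are hypothetical names. Only Module.register_forward_hook and weakref.ref are real APIs.

```python
import weakref

import torch
from torch import nn


class TinyRecorder:
    """Hypothetical, simplified stand-in for Recorder (illustration only)."""

    def __init__(self, net):
        self.net = net
        self.active = False  # True only while this recorder runs the net
        self.activations = {}
        self_ref = weakref.ref(self)  # Weak: hooks do not keep ``self`` alive

        def make_hook(idx):
            def hook(module, args, output):
                recorder = self_ref()
                # Trigger only if the recorder still exists and made the call
                if recorder is not None and recorder.active:
                    recorder.activations[idx] = output.detach()

            return hook

        for idx, layer in enumerate(net.children()):
            layer.register_forward_hook(make_hook(idx))

    def __call__(self, x):
        self.active = True
        try:
            return self.net(x)
        finally:
            self.active = False


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
rec_a, rec_b = TinyRecorder(net), TinyRecorder(net)  # Same net, two wrappers

out = rec_a(torch.randn(3, 4))
n_a, n_b = len(rec_a.activations), len(rec_b.activations)  # Only rec_a recorded

ref_b = weakref.ref(rec_b)
del rec_b  # Hooks hold only a weak reference, so nothing keeps rec_b alive
```

In this sketch, the hooks left behind on net after a recorder is deleted simply become no-ops, since their weak reference resolves to None.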

Total running time of the script: (0 minutes 0.222 seconds)
