Diving inside Neural Networks

This tutorial provides a short practical overview of the Recorder class, which is designed to ease interactions with the internal states of PyTorch Neural Network objects (torch.nn.Module) during or after inference.

from scio.recorder import Recorder

Wrapping and visualizing your Neural Network

Let us first load an arbitrary Neural Network: here, a lightweight Tiniest architecture trained on CIFAR10 and hosted on our hub. We also generate random input data for use throughout this tutorial.

import torch

inputs = torch.rand(5, 3, 32, 32)  # Random inputs with 5 samples
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(inputs)  # Match the device and dtype of the inputs

To wrap it into a Recorder net, rnet, one only needs to specify an input_size (including the batch dimension) or to provide input_data. The model's control flow is then immediately analyzed and stored using the torchinfo library.

rnet = Recorder(net, input_data=inputs[[0]])
rnet  # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Tiniest (Tiniest)                        [1, 3, 32, 32]            [1, 10]                   --                             --
├─Conv2d (conv1): 1-1                    [1, 3, 32, 32]            [1, 48, 32, 32]           1,344                       1.38%
├─LayerNorm2d (ln1): 1-2                 [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    └─LayerNorm (ln): 2-1               [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
├─Block (l1): 1-3                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-2             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-3             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-4             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-5             [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-1          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-6                 [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-7                 [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l2): 1-4                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-8             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-9             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-10            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-11            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-2          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-12                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-13                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l3): 1-5                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-14            [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-15            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-16            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-17            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-3          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-18                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-19                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Conv2d (dsconv): 1-6                   [1, 48, 32, 32]           [1, 80, 16, 16]           3,920                       4.02%
├─LayerNorm2d (ln2): 1-7                 [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    └─LayerNorm (ln): 2-20              [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
├─Block (l4): 1-8                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-21            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-22            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-23            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-24            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-4          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-25                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-26                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l5): 1-9                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-27            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-28            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-29            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-30            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-5          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-31                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-32                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l6): 1-10                       [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-33            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-34            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-35            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-36            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-6          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-37                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-38                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11      [1, 80, 16, 16]           [1, 80, 1, 1]             --                             --
├─Linear (fc): 1-12                      [1, 80]                   [1, 10]                   810                         0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================
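The depth-idx column above follows torchinfo's numbering: depth is the nesting level of a module call (the root model sits at depth 0), and idx counts calls at that depth in execution order across the whole forward pass. This is why the LayerNorm inside ln1 is 2-1 while l1.dwconv1 is 2-2. The sketch below reproduces that numbering with plain PyTorch forward hooks on a toy model; the helper depth_idx_table is hypothetical and not taken from torchinfo or scio.

```python
import torch
from torch import nn


def depth_idx_table(net: nn.Module, x: torch.Tensor) -> list[tuple[int, int, str]]:
    """Hypothetical helper: list (depth, idx, class name) for every module call."""
    rows: list[tuple[int, int, str]] = []
    counters: dict[int, int] = {}  # last idx used at each depth
    depth = 0  # number of module calls currently in progress

    def pre_hook(module, args):
        nonlocal depth
        counters[depth] = counters.get(depth, 0) + 1
        rows.append((depth, counters[depth], type(module).__name__))
        depth += 1  # children of this call live one level deeper

    def post_hook(module, args, output):
        nonlocal depth
        depth -= 1  # leaving this call

    handles = []
    for module in net.modules():
        handles.append(module.register_forward_pre_hook(pre_hook))
        handles.append(module.register_forward_hook(post_hook))
    try:
        with torch.no_grad():
            net(x)
    finally:
        for handle in handles:  # leave the model unchanged
            handle.remove()
    return rows


toy = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Linear(8, 2),
)
for depth, idx, name in depth_idx_table(toy, torch.rand(1, 4)):
    print(f"{'    ' * depth}{name}: {depth}-{idx}")
# Sequential: 0-1
#     Linear: 1-1
#     Sequential: 1-2
#         Linear: 2-1
#         ReLU: 2-2
#     Linear: 1-3
```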

Tip

For summary customization, refer to the torchinfo.summary options. For example, depth=2 limits the depth of the displayed layer tree.

Note

In the case of dynamic control flow, refer to the force_static_flow argument of Recorder.
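Dynamic control flow means that which submodules run depends on the input, so a (depth, idx) numbering traced on one example need not match the calls made on another. The toy model below is a hypothetical illustration, unrelated to Tiniest: it applies its extra layer only when the batch mean is positive, and listing the modules actually called shows that the two inputs follow different paths. What force_static_flow does about such models is documented in the Recorder API and is not reproduced here.

```python
import torch
from torch import nn


class Branchy(nn.Module):
    """Toy model with data-dependent control flow (hypothetical example)."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)

    def forward(self, x):
        x = self.base(x)
        if x.mean() > 0:  # Python-level branch that depends on the data
            x = self.extra(x)
        return x


def called_modules(net: nn.Module, x: torch.Tensor) -> list[str]:
    """Names of the modules called during one forward pass ("" is the root)."""
    names = {module: name for name, module in net.named_modules()}
    called: list[str] = []
    handles = [
        m.register_forward_hook(lambda mod, args, out: called.append(names[mod]))
        for m in net.modules()
    ]
    try:
        with torch.no_grad():
            net(x)
    finally:
        for handle in handles:
            handle.remove()
    return called


branchy = Branchy()
with torch.no_grad():  # make base the identity so the branch is predictable
    branchy.base.weight.copy_(torch.eye(4))
    branchy.base.bias.zero_()
print(called_modules(branchy, torch.ones(1, 4)))   # ['base', 'extra', '']
print(called_modules(branchy, -torch.ones(1, 4)))  # ['base', '']
```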

In most respects, this wrapper is transparent to the user. For example, data can be processed with rnet(inputs), just as with net(inputs).

Selecting layers of interest

The penultimate line of the above summary lists the layers currently set to be recorded (stored in rnet.recording). Right after instantiation, none are selected by default.

print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: None

This selection can be set freely with rnet.record(), passing the depth-idx identifiers from the summary (e.g. 1-9) as (depth, idx) tuples. For example, the following records the outputs of the first Block (1-3) and of the penultimate layer (1-11).

rnet.record((1, 3), (1, 11))
print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: 1-3, 1-11

Note

Though not shown in the summary, 0-1, i.e. (0, 1), can be selected to refer to the entire model.

Warning

torchinfo.summary can only detect torch.nn.Module calls. Consequently, for activation functions to be visible as layers in the summary, a Neural Network must call their Module implementation (e.g. nn.ReLU) rather than their functional counterpart (e.g. torch.nn.functional.relu), and declare them as attributes. A separate attribute is not needed for every activation call: a single self.relu = nn.ReLU() can be used several times in forward().
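The difference can be observed with plain PyTorch forward hooks, which, like torchinfo, only fire on module calls. The two toy models below (hypothetical names) compute the same function, but only the one calling a declared nn.ReLU exposes its activations as module calls, and reusing one self.relu still yields one call per use. Counting hook calls is an illustration of what such tools can see, not a reproduction of torchinfo's internals.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FunctionalAct(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)

    def forward(self, x):
        # F.relu is a plain function call: no module, hence no hook fires
        return F.relu(self.fc2(F.relu(self.fc1(x))))


class ModuleAct(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)
        self.relu = nn.ReLU()  # declared once as an attribute...

    def forward(self, x):
        # ...and reused: each use is a separate module call
        return self.relu(self.fc2(self.relu(self.fc1(x))))


def count_calls(net: nn.Module, x: torch.Tensor) -> dict[str, int]:
    """Count calls to the direct children of net, grouped by class name."""
    counts: dict[str, int] = {}

    def hook(module, args, output):
        name = type(module).__name__
        counts[name] = counts.get(name, 0) + 1

    handles = [m.register_forward_hook(hook) for m in net.children()]
    try:
        with torch.no_grad():
            net(x)
    finally:
        for handle in handles:
            handle.remove()
    return counts


x = torch.rand(3, 4)
functional, modular = FunctionalAct(), ModuleAct()
modular.load_state_dict(functional.state_dict())  # same weights in both models
print(count_calls(functional, x))  # {'Linear': 2}
print(count_calls(modular, x))     # {'Linear': 2, 'ReLU': 2}
print(torch.allclose(functional(x), modular(x)))  # True
```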

Capturing internal states

Once the layers to record are set, every forward pass automatically stores the corresponding internal states in the read-only rnet.activations mapping, whose keys are the (depth, idx) tuples.
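To make the mechanism concrete, here is a minimal sketch of such a recorder built on PyTorch forward hooks. It is not scio's implementation: the class MiniRecorder is hypothetical, it assumes every recorded module returns a single tensor, and it numbers calls like the torchinfo summary (root call (0, 1), then idx counted per depth in execution order). Stored outputs are not detached, so they stay attached to the autograd graph, and activations is exposed as a read-only types.MappingProxyType view.

```python
from types import MappingProxyType

import torch
from torch import nn


class MiniRecorder:
    """Hypothetical, simplified recorder (not scio's code).

    Hooks stay registered for the lifetime of the wrapped model.
    """

    def __init__(self, net: nn.Module):
        self.net = net
        self.recording: set[tuple[int, int]] = set()
        self._activations: dict[tuple[int, int], torch.Tensor] = {}
        self._stack: list[tuple[int, int]] = []  # keys of calls in progress
        self._counters: dict[int, int] = {}  # last idx used at each depth
        for module in net.modules():
            module.register_forward_pre_hook(self._enter)
            module.register_forward_hook(self._exit)

    def record(self, *keys: tuple[int, int]) -> None:
        self.recording = set(keys)

    @property
    def activations(self):
        # Read-only view: callers can read the mapping but not modify it
        return MappingProxyType(self._activations)

    def _enter(self, module, args):
        depth = len(self._stack)  # number of enclosing calls
        self._counters[depth] = self._counters.get(depth, 0) + 1
        self._stack.append((depth, self._counters[depth]))

    def _exit(self, module, args, output):
        key = self._stack.pop()
        if key in self.recording:
            self._activations[key] = output  # not detached

    def __call__(self, *args, **kwargs):
        # Reset the numbering and previous activations before each pass
        self._stack.clear()
        self._counters.clear()
        self._activations.clear()
        return self.net(*args, **kwargs)


net = nn.Sequential(
    nn.Linear(4, 8), nn.Sequential(nn.Linear(8, 8), nn.ReLU()), nn.Linear(8, 2)
)
rec = MiniRecorder(net)
rec.record((1, 2), (0, 1))  # inner Sequential and the whole model
out = rec(torch.rand(5, 4))
print(sorted(rec.activations))        # [(0, 1), (1, 2)]
print(rec.activations[(1, 2)].shape)  # torch.Size([5, 8])
```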

out = rnet(inputs)  # Forward pass, records the activations
rnet.activations
mappingproxy({(1, 3): tensor([[[[ 1.1840e+00, -3.9768e+00, -6.1408e+00,  ...,  1.0248e+00,
            1.3194e+00, -8.6776e+00],
          [ 2.8668e+00,  5.5052e+00,  2.9481e+00,  ...,  3.1327e-01,
            4.5262e+00,  7.0912e-01],
          [ 3.3240e+00,  6.3997e+00, -9.9907e-01,  ...,  1.0634e+00,
            2.5642e+00, -2.6569e+00],
          ...,
          [-4.3052e+00, -1.0219e+00, -1.2242e+00,  ...,  4.2656e+00,
           -8.4259e-01, -6.4476e+00],
          [ 8.6488e+00, -5.8616e+00, -4.9792e-01,  ...,  1.0739e+00,
           -1.4000e+00, -1.4348e+00],
          [-3.2551e+00,  2.0203e+00,  1.6859e+00,  ...,  3.9990e-01,
            2.8529e+00, -7.7367e-01]],

         [[ 5.6371e+00,  4.7450e+00,  7.8421e+00,  ..., -1.8605e+00,
            3.9773e+00,  8.7987e+00],
          [-2.6222e-01, -1.9252e+00, -2.8565e-01,  ...,  2.6793e-01,
           -8.8865e-01, -2.5676e-01],
          [ 6.9759e-01, -2.0214e+00,  8.4964e+00,  ...,  8.6856e+00,
           -9.7730e-01,  8.0766e+00],
          ...,
          [ 7.2839e+00,  2.9292e-01,  3.9100e+00,  ..., -4.2314e-01,
            4.9931e-01,  7.9439e+00],
          [-3.9626e+00,  9.1079e+00,  4.5778e-01,  ...,  5.9899e+00,
            3.6035e+00, -1.7889e+00],
          [ 4.1193e+00,  1.7301e+00,  1.3910e+00,  ..., -1.1367e+00,
           -9.3407e-01, -1.9953e+00]],

         [[-8.9496e-01,  1.4572e+00,  1.9626e+00,  ...,  5.3902e+00,
            2.4456e+00, -3.4632e-01],
          [ 2.9627e-01,  1.4951e+00,  1.3362e+00,  ..., -2.6111e+00,
           -1.8569e+00, -1.2107e+00],
          [ 2.5190e-01, -2.7497e-01, -5.7544e+00,  ...,  2.5789e+00,
            1.9760e+00, -3.7083e+00],
          ...,
          [-4.2791e+00, -1.9051e+00, -3.1659e+00,  ...,  2.5021e+00,
            2.0395e+00,  6.8477e-01],
          [ 1.3172e+00, -4.9972e+00, -2.6876e+00,  ...,  4.4032e-01,
           -3.4425e+00,  1.2055e+00],
          [-6.8624e+00, -6.5304e+00, -7.3056e+00,  ..., -7.3569e+00,
           -6.4987e+00, -4.9576e+00]],

         ...,

         [[-3.1299e+00, -4.4222e+00, -6.2174e+00,  ...,  4.1773e-02,
           -4.1154e+00, -3.4646e+00],
          [-1.7081e+00, -6.8900e+00, -3.4841e+00,  ..., -2.6103e+00,
           -3.5371e+00, -3.6157e-02],
          [-3.8578e+00, -8.9715e+00, -6.1737e+00,  ..., -7.3860e+00,
           -1.2085e+00, -2.4203e+00],
          ...,
          [-6.6153e+00,  4.6556e+00, -1.4511e+00,  ..., -3.9467e+00,
            1.0293e+00, -4.8110e+00],
          [-1.9970e+00, -5.0816e+00,  1.4338e+00,  ..., -8.2089e+00,
           -2.7101e+00, -4.1119e+00],
          [-3.6448e+00, -2.0770e+00, -7.1846e-01,  ..., -1.6314e+00,
            1.0381e+00,  1.6649e-02]],

         [[-6.2612e+00, -3.3117e-02, -5.6506e-01,  ..., -1.6037e+00,
           -4.6247e+00, -3.1226e+00],
          [-7.1616e-01, -8.9988e+00, -3.6300e+00,  ..., -3.7637e+00,
           -2.3856e+00,  5.7279e-01],
          [ 1.1187e+00, -8.4097e+00, -3.2850e+00,  ..., -4.1676e+00,
           -8.2766e+00, -6.6962e+00],
          ...,
          [-3.5546e+00,  2.1283e-01, -1.0575e+00,  ..., -1.1217e-01,
            8.9076e-02,  5.4303e-01],
          [-3.0165e+00, -4.5547e+00, -1.7245e+00,  ..., -1.8988e+00,
           -4.6866e+00, -7.2491e+00],
          [-4.9434e+00, -3.6602e+00, -4.2957e+00,  ..., -3.9212e+00,
            3.8156e-01, -7.9680e-01]],

         [[-3.3626e+00, -3.8030e+00, -4.9477e+00,  ...,  3.8481e-01,
           -1.7839e+00, -2.1101e+00],
          [ 3.7442e+00, -8.6049e-01, -1.0720e+00,  ..., -1.6703e+00,
           -5.9080e+00,  4.2176e-01],
          [ 1.4793e+00, -8.2004e-01, -9.8738e+00,  ..., -8.2815e+00,
           -1.3527e+00, -5.6160e+00],
          ...,
          [-3.2134e+00, -1.0749e+00, -2.5337e+00,  ...,  2.2854e+00,
            6.6428e-01, -5.3389e+00],
          [ 3.2489e+00, -5.8236e+00, -2.2799e+00,  ..., -7.1258e+00,
           -3.1669e+00, -1.8513e+00],
          [-2.4286e+00, -4.5986e+00, -3.9689e+00,  ..., -1.3083e+00,
            2.5574e-01, -9.4972e-01]]],


        [[[ 6.1135e-01, -4.0790e-01,  1.7306e+00,  ..., -7.8867e+00,
            3.8079e+00,  2.5640e+00],
          [ 8.1182e+00,  1.9262e-01,  5.3612e+00,  ..., -1.1460e+00,
            7.2472e+00,  1.0969e+00],
          [ 6.4297e+00,  3.4087e+00,  3.2971e+00,  ...,  3.7006e+00,
            1.2790e+00, -8.7348e+00],
          ...,
          [ 8.0043e+00, -7.6881e+00,  7.2996e+00,  ..., -4.0845e-01,
            1.7941e+00, -3.7360e-01],
          [-3.4967e+00,  4.5644e+00,  5.4447e+00,  ..., -2.8013e+00,
            8.1206e-01, -4.1381e+00],
          [-1.2101e+00,  1.3591e+00, -3.9865e+00,  ...,  5.4945e+00,
            5.7316e+00, -2.1428e+00]],

         [[-9.5532e-01, -1.4504e+00,  9.5772e+00,  ...,  8.5158e+00,
           -3.1989e+00,  7.0274e+00],
          [-1.0869e+00,  1.0710e+00,  2.3832e+00,  ...,  5.7870e+00,
            9.9784e-01,  6.7216e+00],
          [-1.4057e+00,  4.7372e+00, -2.4202e+00,  ...,  2.6021e+00,
            7.1666e-01,  2.9739e+00],
          ...,
          [-4.4602e+00,  1.1406e+01,  7.4940e-01,  ...,  7.2097e+00,
            2.1769e+00,  1.3647e+00],
          [ 5.3515e+00, -3.7577e+00, -1.2903e+00,  ...,  2.9092e+00,
            6.8235e+00, -3.6641e-01],
          [ 8.5835e+00,  3.6280e-01,  2.6331e+00,  ..., -1.5287e+00,
            3.7305e+00,  1.7017e+00]],

         [[ 5.4719e-01,  5.8945e+00,  3.6384e+00,  ...,  5.9290e-01,
            6.5401e+00, -8.3168e-01],
          [-2.3086e+00, -1.6883e+00, -2.9866e+00,  ...,  1.2742e-01,
           -3.3645e+00, -1.5333e+00],
          [ 2.3729e+00, -3.9553e+00,  1.1772e+00,  ..., -1.6573e+00,
            2.1480e+00, -3.0698e+00],
          ...,
          [ 1.6869e+00, -5.7966e+00, -1.7680e+00,  ..., -5.9824e+00,
           -3.7660e+00, -2.8763e+00],
          [-3.7292e+00, -3.2804e-01, -4.0741e+00,  ..., -4.4123e+00,
           -1.7123e+00, -4.7638e+00],
          [-7.0541e+00, -5.5711e+00, -6.8061e+00,  ..., -5.6458e+00,
           -7.3458e+00, -6.6388e+00]],

         ...,

         [[-2.3684e+00, -1.2042e+00, -8.8991e+00,  ..., -4.2127e+00,
            8.4223e-02, -4.2727e+00],
          [-5.4316e+00,  2.8190e-01, -4.2922e+00,  ..., -5.0320e+00,
           -9.1889e+00, -6.6783e+00],
          [-2.2074e+00, -9.1135e-01,  4.8699e-01,  ..., -3.9411e+00,
            6.8704e-01, -2.0734e+00],
          ...,
          [ 2.3566e-01, -5.5058e+00, -6.6349e+00,  ..., -7.6634e+00,
           -3.2290e+00, -1.3448e+00],
          [-1.2552e+00,  1.6721e-01, -4.2178e+00,  ..., -1.2164e+00,
           -8.0527e+00, -3.5983e-01],
          [-9.6257e-01, -1.1936e-01, -3.9775e+00,  ..., -4.3331e+00,
           -3.4952e+00, -6.4649e-01]],

         [[-2.6087e+00, -5.8321e-01, -4.4971e+00,  ..., -4.9800e+00,
            2.1208e+00, -4.5245e+00],
          [-1.4842e+00,  5.1948e-01,  1.0777e+00,  ..., -7.6349e-01,
           -4.0771e+00, -3.2840e+00],
          [-3.7402e-01,  4.2671e-01,  1.2273e-01,  ..., -4.9246e-01,
           -2.2199e+00, -4.1484e+00],
          ...,
          [-1.4252e+00, -3.2505e+00, -5.5020e+00,  ..., -4.7821e-01,
           -2.1398e+00, -1.0328e+00],
          [-5.7796e+00, -2.1981e+00, -7.0260e+00,  ..., -5.7466e+00,
           -2.3525e+00, -3.5453e+00],
          [-3.5804e+00, -1.9787e+00, -3.7982e+00,  ..., -8.7940e+00,
            2.6128e-01, -2.0655e+00]],

         [[ 4.0095e-01,  1.3515e-01, -1.0337e+01,  ..., -5.1152e+00,
            4.4089e+00, -5.3602e+00],
          [-3.2711e-01,  7.6293e-02,  3.2586e-01,  ..., -3.4257e+00,
            2.6616e-01, -4.6196e+00],
          [-2.1327e-01, -2.3186e+00,  2.8104e+00,  ..., -9.6214e-01,
            6.1060e-01, -5.7282e+00],
          ...,
          [ 4.5902e+00, -9.7073e+00, -1.3207e+00,  ..., -9.0174e+00,
           -9.8054e-02, -1.2954e+00],
          [-1.9223e+00,  1.9472e+00, -4.1010e+00,  ..., -2.4121e+00,
           -9.2364e-01,  3.7176e-02],
          [-2.3366e+00, -3.2244e+00, -3.6521e+00,  ..., -3.4147e+00,
           -1.5715e+00, -4.4342e+00]]],


        [[[-2.5378e+00,  3.2608e+00,  4.4965e-01,  ..., -1.3264e+00,
           -3.9152e+00, -4.7755e+00],
          [ 6.7685e+00,  2.8770e+00, -1.0622e+01,  ...,  3.7247e+00,
           -2.5822e+00, -4.6125e+00],
          [ 1.5560e+00, -1.1886e+01,  2.2942e+00,  ...,  4.7512e+00,
            1.2038e+00,  4.7101e+00],
          ...,
          [ 6.8173e-02,  1.6808e+00,  3.0122e+00,  ...,  4.7925e+00,
            3.4413e+00, -6.1975e+00],
          [ 3.5979e-01,  1.6340e+00, -6.6561e+00,  ..., -6.2519e-01,
           -7.7153e-01, -1.0528e+01],
          [ 3.0109e+00,  7.5131e-01,  1.5622e+00,  ...,  4.2193e+00,
           -2.1440e+00,  2.0547e+00]],

         [[ 9.0380e+00, -4.6405e+00,  1.0877e+00,  ...,  4.0065e+00,
            4.5041e+00,  2.3741e+00],
          [-3.3808e+00, -1.5415e+00,  1.0883e+01,  ...,  2.6561e-02,
            6.9931e+00,  3.4434e+00],
          [ 2.3052e+00,  1.9048e+01, -3.3737e+00,  ...,  2.2934e+00,
           -1.4045e+00, -1.6980e+00],
          ...,
          [ 6.9435e+00, -2.3090e+00,  2.6158e-01,  ...,  1.5535e+00,
           -1.7923e+00,  6.0902e+00],
          [ 1.5760e+00,  2.9722e+00,  1.4296e+01,  ...,  4.7134e+00,
            1.2776e+00,  1.1338e+01],
          [-1.7333e+00, -1.9201e+00, -1.1540e+00,  ...,  2.9943e+00,
            1.9958e+00,  7.0859e+00]],

         [[-1.1758e-01,  8.8850e+00,  7.0958e+00,  ...,  3.2172e+00,
            1.9240e+00,  2.8798e+00],
          [ 3.4787e+00,  2.7848e+00, -2.8159e+00,  ...,  3.0141e+00,
           -3.1759e+00, -1.5789e+00],
          [-5.0195e-01, -4.9523e+00,  1.5819e+00,  ...,  3.5549e+00,
           -2.6064e+00,  1.0034e+00],
          ...,
          [-3.8253e+00,  1.5192e+00, -7.3521e-01,  ...,  5.1912e-01,
           -2.0005e-01, -3.0163e+00],
          [-3.7866e+00, -5.6902e+00, -2.3657e+00,  ...,  3.5699e+00,
            7.2896e-01, -5.0548e+00],
          [-4.0153e+00, -8.0954e+00, -6.6185e+00,  ..., -7.8700e+00,
           -7.8903e+00, -3.6639e+00]],

         ...,

         [[-3.6395e+00, -2.3480e+00, -2.3475e+00,  ..., -6.3476e+00,
           -2.2719e+00, -3.0728e-01],
          [-1.0507e+00, -8.2144e+00, -6.0888e+00,  ..., -9.0607e+00,
           -3.2355e+00, -2.6327e+00],
          [-3.4441e+00, -1.0709e+01, -6.0631e-01,  ..., -8.8394e+00,
            3.6931e-01, -6.4683e+00],
          ...,
          [-1.7831e+00, -1.7127e+00, -1.2041e+00,  ..., -9.6801e+00,
           -5.0828e+00, -3.0705e+00],
          [-1.2312e+00,  1.5819e+00, -9.4848e+00,  ..., -3.4315e+00,
            1.6442e+00, -4.3877e+00],
          [-3.2998e-01, -1.7279e+00, -1.0364e+00,  ..., -3.4662e+00,
           -1.6369e+00, -5.5787e+00]],

         [[-6.2131e+00, -3.8308e+00, -2.3575e+00,  ..., -6.8802e-01,
           -3.2881e+00, -1.9603e+00],
          [ 8.5165e-01, -3.6808e+00, -2.3493e+00,  ..., -7.4835e+00,
           -2.1663e+00, -3.8986e+00],
          [-1.1513e+00,  3.7875e-01, -1.9749e+00,  ..., -1.3156e+01,
           -6.3181e-01, -3.5813e+00],
          ...,
          [-5.8868e+00,  1.0928e+00,  3.5021e-01,  ..., -1.0817e+01,
           -2.1675e+00, -4.0988e-01],
          [-9.0863e-01,  1.9460e-01,  2.2699e-01,  ...,  1.3075e-01,
           -2.7057e+00, -2.6742e+00],
          [-1.5481e+00, -2.9707e+00, -3.4295e+00,  ..., -2.2443e-01,
           -1.0947e+00, -7.4968e+00]],

         [[-4.9729e+00,  3.7500e+00, -1.3418e+00,  ..., -2.6551e+00,
           -2.0662e+00,  1.4086e+00],
          [ 7.0536e+00,  2.2636e+00, -9.4439e+00,  ..., -1.8370e+00,
           -4.9959e+00, -8.0171e-01],
          [-2.7448e-01, -1.1506e+01,  4.1355e+00,  ..., -4.8487e+00,
           -4.0891e+00,  2.3634e+00],
          ...,
          [-1.2493e+00,  5.2933e-01, -1.1684e+00,  ..., -5.6619e+00,
            7.3091e-01, -5.4969e+00],
          [ 6.9305e-02, -1.0491e+00, -8.9900e+00,  ..., -1.3288e+00,
           -1.0795e+00, -8.1355e+00],
          [ 1.4307e+00, -2.1308e-01, -2.1390e+00,  ..., -1.9080e+00,
           -2.8320e+00, -7.0195e+00]]],


        [[[ 2.9129e+00, -3.2524e+00, -1.0498e+01,  ..., -9.3783e-01,
            6.5297e-01, -2.0500e+00],
          [ 8.1225e+00,  1.4358e+00, -7.3510e-03,  ...,  9.4304e-01,
           -7.0998e-01, -8.6658e+00],
          [ 5.9673e+00,  3.2915e+00,  4.3534e+00,  ...,  4.2182e+00,
           -1.3153e+01, -5.7487e+00],
          ...,
          [-6.2518e+00, -4.8582e+00,  6.2799e+00,  ..., -8.3154e-01,
            4.0259e+00, -6.4322e+00],
          [ 7.1037e+00, -8.5388e-01,  4.6802e+00,  ...,  5.0693e+00,
           -9.4300e+00,  6.5311e-01],
          [-8.7696e+00,  2.4000e+00, -8.8924e-02,  ...,  6.5733e+00,
           -1.1505e+00, -2.8300e-01]],

         [[ 2.9804e+00,  1.4091e+00,  1.1575e+01,  ...,  5.6818e-01,
            4.8548e+00,  1.5143e-01],
          [-3.8385e+00,  8.4773e-01, -1.6827e+00,  ...,  2.3126e+00,
            8.7318e-01,  8.9006e+00],
          [-2.5339e-02, -3.2646e-01,  5.0536e-01,  ..., -1.8991e+00,
            1.4900e+01,  2.2718e+00],
          ...,
          [ 1.5411e+01,  9.7844e+00,  5.1071e+00,  ...,  4.5280e+00,
           -2.5426e-01,  1.0773e+01],
          [-2.6787e+00,  2.8455e+00, -3.8178e+00,  ..., -1.4451e+00,
            1.7451e+01, -1.4249e+00],
          [ 1.6547e+01, -1.9263e-01, -4.3781e-01,  ..., -1.2596e+00,
            3.7611e-01, -3.8341e+00]],

         [[ 1.0818e+00,  3.1222e+00,  1.5981e+00,  ...,  1.7457e+00,
            1.5172e+00,  5.2124e+00],
          [-1.4462e-02, -2.9712e+00,  1.3128e+00,  ..., -2.0578e+00,
           -9.5090e-02, -4.7197e+00],
          [-1.1285e+00,  1.2692e+00, -2.0608e+00,  ..., -3.6723e-01,
           -5.5803e+00,  1.3684e+00],
          ...,
          [-5.1631e+00, -6.4088e+00, -4.8315e+00,  ...,  1.5361e-03,
            7.7202e-01, -5.8011e+00],
          [ 3.3883e+00,  6.6355e-01,  2.2215e+00,  ...,  5.8554e-01,
           -4.2502e+00,  4.0990e-02],
          [-8.4053e+00, -3.5686e+00, -5.0168e+00,  ..., -6.7569e+00,
           -8.6903e+00, -3.3849e+00]],

         ...,

         [[-4.2581e+00, -1.6043e+00, -6.3839e+00,  ..., -4.4372e+00,
           -2.3815e+00, -5.3646e+00],
          [-5.2819e+00, -1.7557e+00, -4.0093e+00,  ..., -6.7299e-01,
            2.9654e+00, -4.4950e+00],
          [-5.8555e+00, -6.1766e-01, -1.9686e+00,  ..., -6.5745e+00,
           -7.6376e+00,  4.3026e-01],
          ...,
          [-1.0421e+01, -4.6293e+00, -8.6746e+00,  ..., -3.5965e+00,
           -3.5933e+00, -4.8813e+00],
          [-7.2667e+00, -8.3722e-01,  2.3402e+00,  ..., -4.3745e+00,
           -7.7938e+00, -3.0501e+00],
          [-9.8982e+00, -8.5274e-01, -3.9499e+00,  ..., -2.3490e+00,
           -1.4945e-01,  3.4413e+00]],

         [[-4.0149e+00,  4.9003e-01, -1.1769e+00,  ...,  1.4570e+00,
           -1.6056e+00, -5.8829e+00],
          [-5.5961e+00, -1.0684e+00, -5.2302e+00,  ..., -3.3129e-01,
           -7.2735e-01, -1.1385e+00],
          [-3.6494e+00,  5.4787e-01, -9.0548e-01,  ..., -9.4075e+00,
           -1.7863e+00, -1.5484e+00],
          ...,
          [-1.7944e+00, -1.9112e+00, -9.4494e+00,  ..., -2.6003e+00,
           -1.8443e+00, -4.4106e+00],
          [-2.7612e+00, -1.8931e+00, -1.0150e+00,  ..., -4.4055e+00,
           -2.4169e+00, -8.5292e+00],
          [-2.3432e+00, -1.8556e+00, -3.7089e+00,  ..., -3.2623e+00,
           -8.9070e-01,  5.8870e-01]],

         [[-1.8713e+00, -3.5337e-01, -7.7096e+00,  ..., -3.8883e+00,
           -2.3640e+00,  1.1545e+00],
          [-3.7782e-01, -1.6465e+00, -1.6819e+00,  ..., -7.6526e-01,
           -5.1353e-02, -7.6864e+00],
          [-1.2148e+00,  1.1784e+00, -6.8672e-01,  ..., -1.2983e+00,
           -1.0557e+01,  1.1041e-01],
          ...,
          [-1.1753e+01, -5.4576e+00, -7.4540e+00,  ..., -2.3526e+00,
            2.0062e+00, -6.5911e+00],
          [ 4.0558e+00, -1.6913e+00, -3.7656e-01,  ..., -8.1791e-01,
           -8.4219e+00, -4.8191e-01],
          [-1.0964e+01, -9.6215e-01, -7.2972e-01,  ..., -1.8072e+00,
           -2.2855e+00,  1.0610e+00]]],


        [[[ 2.5016e+00,  1.0698e+00,  3.4521e+00,  ...,  5.0118e+00,
            7.5608e+00,  1.4694e+00],
          [ 7.0833e-01,  1.3762e+00, -4.7640e+00,  ..., -2.0714e+00,
            8.8311e-01,  2.1845e+00],
          [ 2.9299e+00, -6.0453e-02, -1.0535e+01,  ...,  5.4147e+00,
            6.2731e+00, -1.1893e+01],
          ...,
          [ 6.8906e+00,  3.5560e+00,  1.7076e+00,  ...,  2.3938e+00,
           -8.0700e+00,  4.7419e+00],
          [-5.8548e+00,  4.8004e+00, -7.5251e-01,  ..., -1.1322e+01,
            1.1121e-01, -6.3947e-01],
          [ 7.1562e+00,  5.7367e+00, -4.9174e+00,  ...,  7.0329e+00,
            2.6265e+00, -6.3648e+00]],

         [[ 4.4402e+00, -6.9526e-01, -1.2699e+00,  ..., -2.2463e+00,
            4.1943e-01,  1.7225e-01],
          [ 1.0682e+01,  4.8871e+00,  6.9055e+00,  ...,  8.7761e+00,
            1.0974e+00, -5.2440e-01],
          [-1.0705e+00,  2.0356e-01,  1.5952e+01,  ...,  2.8583e+00,
           -1.0515e-01,  1.6879e+01],
          ...,
          [-7.0857e-01, -4.3137e+00,  3.6930e+00,  ..., -5.1122e-01,
            1.5402e+01,  3.1482e-01],
          [ 1.5361e+01,  4.4810e+00,  3.3494e+00,  ...,  1.9350e+01,
           -3.9128e-01, -5.3630e-01],
          [ 7.0751e-02,  2.5153e-01,  9.9632e+00,  ..., -5.4473e-01,
           -1.2144e+00,  4.2313e+00]],

         [[ 1.1954e+00,  6.4219e+00,  6.5672e+00,  ...,  7.1422e+00,
            4.0349e+00,  1.7659e+00],
          [-2.3926e-01, -2.8691e+00, -4.8337e+00,  ...,  1.7903e+00,
            9.1861e-01,  1.5375e+00],
          [-1.2605e+00,  2.0248e+00, -2.9877e+00,  ...,  4.0647e+00,
           -5.4739e-01, -5.7719e+00],
          ...,
          [ 8.5466e-01,  3.3956e+00,  3.6725e-01,  ...,  3.6085e+00,
           -2.8826e+00, -1.1086e+00],
          [-4.1737e+00,  2.7180e+00,  1.7408e+00,  ..., -5.1662e+00,
            1.8879e+00,  4.7089e+00],
          [-3.7262e+00, -4.5568e+00, -8.7538e+00,  ..., -6.0021e+00,
           -5.4607e+00, -5.4281e+00]],

         ...,

         [[-4.3234e+00,  1.6179e+00, -6.8555e+00,  ..., -1.1388e+01,
           -5.0748e+00, -5.2572e+00],
          [-1.0714e+01, -1.7800e+00, -5.1540e+00,  ..., -6.4717e+00,
           -1.8327e+00, -1.2003e+00],
          [-4.3838e+00,  3.7503e+00, -1.0334e+01,  ..., -3.5575e+00,
           -7.5437e+00, -8.1095e+00],
          ...,
          [-8.7238e+00,  5.9201e-01, -1.2579e-01,  ..., -4.6013e+00,
           -9.6982e+00, -8.7206e+00],
          [-1.2267e+01, -8.0808e+00, -3.3668e+00,  ..., -1.0132e+01,
           -2.9706e+00, -1.1482e+00],
          [-8.7687e+00, -8.1628e+00, -4.8981e+00,  ..., -7.9001e+00,
           -3.4371e+00, -2.3546e+00]],

         [[-4.6445e+00,  7.1691e-01,  7.1731e-01,  ..., -7.1943e+00,
           -4.0470e+00, -3.7757e+00],
          [-3.8608e+00, -1.0421e+00,  2.7530e-01,  ...,  9.0852e-01,
           -2.4247e-01,  2.1944e-02],
          [-3.5391e+00, -1.5460e+00, -1.0164e-01,  ..., -5.6156e+00,
           -1.9198e+00, -1.1008e+00],
          ...,
          [-1.6786e+00,  6.8684e-01, -3.0361e+00,  ...,  1.7271e+00,
            7.5542e-01, -6.6224e+00],
          [-1.9133e+00, -7.6415e+00,  4.5468e-01,  ..., -2.5753e-01,
           -3.3249e+00,  1.3465e+00],
          [-8.0335e+00, -1.1284e+01, -1.7011e+00,  ..., -6.2486e+00,
           -3.2033e+00, -5.3966e+00]],

         [[-2.7029e+00,  3.3049e+00,  3.1050e+00,  ...,  1.2178e+00,
           -1.1030e+01, -5.2787e-01],
          [-9.8279e+00, -2.0378e+00, -5.8701e+00,  ..., -5.9798e+00,
           -2.4024e-01,  1.1307e+00],
          [ 1.3140e+00, -7.4284e-01, -1.4497e+01,  ..., -4.9141e+00,
            9.5263e-01, -9.2134e+00],
          ...,
          [ 2.4497e+00,  3.9600e+00, -2.8287e+00,  ...,  2.6477e+00,
           -1.2291e+01, -2.3906e+00],
          [-1.4567e+01, -5.7395e+00, -1.0856e+00,  ..., -9.9037e+00,
           -8.2304e-02,  3.9799e+00],
          [-3.0177e+00, -5.0472e+00, -5.7972e+00,  ..., -1.0227e+00,
           -1.2476e+00, -3.1531e+00]]]], grad_fn=<AddBackward0>), (1, 11): tensor([[[[-5.2457e-01]],

         [[-9.0641e-02]],

         [[ 6.9309e-03]],

         [[-4.1651e-02]],

         [[ 3.6035e-02]],

         [[ 7.6587e-02]],

         [[-4.3537e-02]],

         [[ 8.2914e-02]],

         [[ 7.5191e-02]],

         [[-3.5259e-02]],

         [[ 9.7200e-03]],

         [[ 7.6519e-02]],

         [[-1.0540e-01]],

         [[-9.4722e-03]],

         [[-1.0252e-01]],

         [[-9.5942e-02]],

         [[-6.3120e-02]],

         [[ 8.3527e-03]],

         [[ 9.0732e-02]],

         [[-9.4510e-03]],

         [[ 7.8720e-02]],

         [[-2.7347e-02]],

         [[-3.4990e-02]],

         [[ 1.4953e+00]],

         [[ 9.3071e-02]],

         [[-3.8515e-03]],

         [[ 8.3714e-03]],

         [[ 7.5222e-01]],

         [[ 7.9497e-02]],

         [[-3.2848e-02]],

         [[-1.3847e+00]],

         [[-9.0579e-02]],

         [[-8.4702e-02]],

         [[-3.5936e-02]],

         [[ 6.4979e-01]],

         [[ 1.0354e-01]],

         [[ 7.1519e-02]],

         [[ 3.4473e-02]],

         [[ 1.8043e-01]],

         [[-2.2056e-02]],

         [[-4.5158e-02]],

         [[-4.6025e-03]],

         [[-1.6268e-02]],

         [[ 1.0126e-02]],

         [[-6.3374e-02]],

         [[ 9.3107e-02]],

         [[-6.3685e-03]],

         [[ 1.2761e-01]],

         [[-2.8283e-02]],

         [[-1.1529e-02]],

         [[-1.0910e-01]],

         [[-9.6132e-02]],

         [[ 3.8084e-02]],

         [[-7.1576e-02]],

         [[-4.8643e-02]],

         [[-1.3422e-02]],

         [[-1.6516e-02]],

         [[ 1.5640e-02]],

         [[ 6.5359e-04]],

         [[-1.3652e-01]],

         [[ 6.3765e-02]],

         [[-1.8865e-02]],

         [[ 1.4003e-02]],

         [[-6.0155e-02]],

         [[-4.4193e-02]],

         [[ 6.0450e-02]],

         [[-2.3319e-02]],

         [[-3.4301e-02]],

         [[ 3.3352e-03]],

         [[-4.2756e-02]],

         [[-6.1809e-02]],

         [[ 3.6859e-02]],

         [[-1.4811e+00]],

         [[-3.1079e-01]],

         [[ 9.4004e-02]],

         [[ 1.3108e-02]],

         [[ 4.9802e-02]],

         [[-1.0885e-01]],

         [[-1.7630e-02]],

         [[-1.1146e-01]]],


        [[[-1.4515e+00]],

         [[-1.2795e-01]],

         [[ 3.4611e-04]],

         [[-1.1858e-02]],

         [[ 1.0673e-02]],

         [[ 2.7754e-02]],

         [[-1.6700e-02]],

         [[ 1.3671e-01]],

         [[ 1.5448e-03]],

         [[-4.0310e-02]],

         [[-2.3125e-02]],

         [[ 8.5489e-02]],

         [[-1.0249e-01]],

         [[-7.2909e-03]],

         [[-1.0817e-01]],

         [[-7.8327e-02]],

         [[ 3.6495e-02]],

         [[-2.1137e-03]],

         [[ 8.9906e-02]],

         [[-4.1041e-02]],

         [[ 6.6720e-02]],

         [[-3.0323e-02]],

         [[-9.8313e-02]],

         [[ 2.3465e+00]],

         [[ 1.3456e-01]],

         [[ 4.3833e-02]],

         [[ 4.8033e-02]],

         [[ 1.2038e+00]],

         [[ 1.3081e-01]],

         [[ 8.6614e-01]],

         [[-1.7027e+00]],

         [[-2.5907e-01]],

         [[-8.2633e-02]],

         [[-1.3535e-02]],

         [[ 1.2002e+00]],

         [[ 1.4329e-01]],

         [[ 7.0418e-02]],

         [[ 4.6614e-02]],

         [[ 1.5534e-01]],

         [[ 3.0151e-02]],

         [[ 1.1418e-03]],

         [[-2.7738e-02]],

         [[-3.1918e-02]],

         [[ 1.8135e-02]],

         [[-8.0968e-02]],

         [[ 5.4231e-02]],

         [[-1.4887e-02]],

         [[ 1.2174e-01]],

         [[-5.5556e-03]],

         [[-5.5448e-03]],

         [[-2.4155e-02]],

         [[-1.0438e-01]],

         [[ 2.5149e-03]],

         [[-1.0605e-01]],

         [[-5.0420e-02]],

         [[-1.0955e-02]],

         [[-8.0149e-03]],

         [[ 5.7242e-02]],

         [[-3.0045e-03]],

         [[-1.3625e-01]],

         [[ 8.9857e-02]],

         [[ 1.6135e-03]],

         [[ 5.6447e-03]],

         [[-1.0309e-01]],

         [[-3.7386e-02]],

         [[ 2.9250e-02]],

         [[-8.5455e-02]],

         [[-6.0839e-02]],

         [[-2.9519e-02]],

         [[-1.8547e-02]],

         [[-1.0303e-02]],

         [[ 1.2949e-02]],

         [[-1.5479e+00]],

         [[ 5.9475e-02]],

         [[ 9.3700e-02]],

         [[ 7.5539e-03]],

         [[ 7.0730e-02]],

         [[-1.0236e-01]],

         [[ 1.0816e-03]],

         [[-1.4112e-01]]],


        [[[-8.0570e-01]],

         [[-4.9268e-02]],

         [[ 2.3565e-02]],

         [[-3.7457e-02]],

         [[ 6.1992e-02]],

         [[ 7.9479e-02]],

         [[-5.6199e-02]],

         [[ 1.3817e-01]],

         [[ 6.8712e-02]],

         [[-8.7010e-02]],

         [[-3.9026e-02]],

         [[ 1.1440e-01]],

         [[-1.3158e-01]],

         [[-1.3754e-02]],

         [[-5.0816e-02]],

         [[-7.0937e-02]],

         [[ 1.0374e-02]],

         [[ 2.5079e-02]],

         [[ 8.8696e-02]],

         [[-4.0659e-02]],

         [[ 4.5634e-02]],

         [[-1.3016e-02]],

         [[-8.0217e-02]],

         [[ 1.4438e+00]],

         [[ 9.5005e-02]],

         [[-2.2090e-03]],

         [[-1.1135e-02]],

         [[ 6.4710e-01]],

         [[ 7.7375e-02]],

         [[ 5.5975e-02]],

         [[-1.2375e+00]],

         [[-2.2754e-02]],

         [[-1.1392e-01]],

         [[-2.0087e-02]],

         [[ 1.2146e+00]],

         [[ 1.3295e-01]],

         [[ 9.0176e-02]],

         [[ 4.9637e-02]],

         [[ 1.8386e-01]],

         [[ 2.8876e-02]],

         [[-8.3324e-02]],

         [[ 2.1065e-02]],

         [[-2.8856e-03]],

         [[-1.8431e-02]],

         [[-5.9949e-02]],

         [[ 6.5565e-02]],

         [[ 2.9046e-02]],

         [[ 1.4447e-01]],

         [[-6.4821e-02]],

         [[ 5.8598e-03]],

         [[-4.5724e-02]],

         [[-8.1505e-02]],

         [[-6.7232e-03]],

         [[-1.1918e-01]],

         [[-1.2851e-02]],

         [[-2.8473e-02]],

         [[-4.6490e-02]],

         [[ 5.0919e-02]],

         [[-4.9510e-03]],

         [[-1.5798e-01]],

         [[ 1.0105e-01]],

         [[-7.4230e-03]],

         [[-4.9707e-05]],

         [[-1.0887e-01]],

         [[-3.6714e-03]],

         [[ 5.4465e-02]],

         [[-6.0318e-02]],

         [[-2.1823e-03]],

         [[-5.2520e-02]],

         [[-7.7618e-02]],

         [[-1.1952e-02]],

         [[ 6.6764e-02]],

         [[-1.3330e+00]],

         [[-5.3365e-02]],

         [[ 9.9060e-02]],

         [[ 4.0359e-02]],

         [[ 6.6137e-02]],

         [[-1.4468e-01]],

         [[-3.3664e-02]],

         [[-1.2537e-01]]],


        [[[-6.9426e-01]],

         [[-4.4863e-02]],

         [[ 2.2523e-02]],

         [[-4.7155e-03]],

         [[ 5.5015e-02]],

         [[ 7.9360e-02]],

         [[-1.2255e-02]],

         [[ 1.2986e-01]],

         [[ 6.6702e-02]],

         [[-8.5852e-02]],

         [[-3.6752e-02]],

         [[ 3.7085e-02]],

         [[-1.4229e-01]],

         [[ 3.2262e-02]],

         [[-6.6824e-02]],

         [[-6.5876e-02]],

         [[-3.7177e-02]],

         [[ 6.0351e-02]],

         [[ 5.4006e-02]],

         [[-1.3893e-02]],

         [[ 9.5680e-02]],

         [[-2.3285e-02]],

         [[-4.7369e-02]],

         [[ 1.2449e+00]],

         [[ 7.2566e-02]],

         [[-2.0827e-02]],

         [[ 3.0519e-02]],

         [[ 1.0390e+00]],

         [[ 1.0937e-01]],

         [[ 1.5196e-01]],

         [[-1.4263e+00]],

         [[-3.3065e-01]],

         [[-1.0648e-01]],

         [[-4.3167e-02]],

         [[ 1.4001e+00]],

         [[ 1.4733e-01]],

         [[ 1.0521e-01]],

         [[-2.1732e-02]],

         [[ 8.2775e-02]],

         [[ 6.1806e-03]],

         [[-1.1490e-01]],

         [[ 3.5223e-02]],

         [[-2.5000e-02]],

         [[-3.6146e-02]],

         [[-6.2931e-02]],

         [[ 4.4221e-02]],

         [[ 9.2059e-03]],

         [[ 1.7048e-01]],

         [[-5.4687e-03]],

         [[-8.8526e-03]],

         [[-1.7545e-02]],

         [[-8.7997e-02]],

         [[ 9.5173e-02]],

         [[-1.0401e-01]],

         [[-2.9480e-02]],

         [[-2.3341e-02]],

         [[-2.3680e-02]],

         [[ 6.5220e-03]],

         [[ 2.1216e-02]],

         [[-1.7137e-01]],

         [[ 1.3781e-01]],

         [[ 3.6666e-02]],

         [[ 3.5770e-02]],

         [[-5.8884e-02]],

         [[ 2.6840e-02]],

         [[ 7.6601e-02]],

         [[-2.2913e-02]],

         [[ 2.9755e-02]],

         [[-5.8130e-02]],

         [[-5.0633e-02]],

         [[-4.0252e-02]],

         [[ 7.2196e-02]],

         [[-1.3212e+00]],

         [[ 3.3883e-01]],

         [[ 1.2068e-01]],

         [[ 2.6969e-02]],

         [[ 6.8377e-02]],

         [[-7.9761e-02]],

         [[-4.4104e-02]],

         [[-1.3036e-01]]],


        [[[-4.4211e-01]],

         [[-4.8531e-02]],

         [[-3.5263e-03]],

         [[-1.4494e-02]],

         [[ 6.3142e-02]],

         [[ 4.0937e-02]],

         [[ 1.8612e-02]],

         [[ 5.1083e-02]],

         [[ 3.2326e-03]],

         [[ 1.6770e-02]],

         [[ 2.6134e-02]],

         [[ 7.4758e-02]],

         [[-9.1505e-02]],

         [[-3.2799e-02]],

         [[-7.5923e-02]],

         [[-5.9663e-02]],

         [[-1.1412e-02]],

         [[ 4.6281e-02]],

         [[ 1.0953e-01]],

         [[-4.4402e-02]],

         [[ 2.5579e-02]],

         [[-3.3282e-02]],

         [[ 2.3869e-02]],

         [[ 2.3184e+00]],

         [[ 6.5172e-02]],

         [[-5.3314e-02]],

         [[-2.0621e-02]],

         [[ 1.0820e+00]],

         [[ 1.0450e-01]],

         [[-2.4184e-03]],

         [[-1.6176e+00]],

         [[-3.4672e-01]],

         [[-9.9723e-02]],

         [[-2.1342e-02]],

         [[ 1.6232e+00]],

         [[ 9.6315e-02]],

         [[ 8.7533e-02]],

         [[ 4.7890e-02]],

         [[ 1.7129e-01]],

         [[ 5.3281e-02]],

         [[-1.9626e-02]],

         [[ 3.9370e-02]],

         [[-2.6666e-02]],

         [[ 2.0070e-02]],

         [[-2.9020e-02]],

         [[ 2.5559e-02]],

         [[-2.8746e-02]],

         [[ 8.8259e-02]],

         [[-2.0874e-02]],

         [[-8.1804e-03]],

         [[-3.7313e-03]],

         [[-1.1712e-01]],

         [[ 1.9256e-02]],

         [[-6.3816e-02]],

         [[-1.8919e-02]],

         [[ 3.6255e-02]],

         [[-2.6700e-02]],

         [[ 7.4924e-02]],

         [[-1.2189e-02]],

         [[-1.2165e-01]],

         [[ 3.8856e-02]],

         [[-5.6185e-02]],

         [[ 7.0175e-03]],

         [[-7.9655e-02]],

         [[-3.2928e-02]],

         [[ 4.0587e-02]],

         [[-7.0626e-02]],

         [[ 1.3630e-02]],

         [[ 1.7611e-02]],

         [[-4.5643e-02]],

         [[-2.8755e-02]],

         [[ 3.4550e-02]],

         [[-7.7943e-01]],

         [[ 6.2508e-02]],

         [[ 1.0325e-01]],

         [[-2.4046e-02]],

         [[ 1.1975e-01]],

         [[-3.9672e-02]],

         [[ 5.5347e-02]],

         [[-1.0445e-01]]]], grad_fn=<AsStridedBackward1>)})
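Printing full activation tensors, as above, quickly becomes unwieldy. A more compact view summarizes each recorded tensor by its shape and a few statistics. The sketch below assumes the activations are available as a dict from layer keys (such as the (1, 11) key above) to tensors. The `summarize` helper and the random stand-in data are hypothetical. The stand-in is shaped like the (1, 11) entry printed above: 5 samples, 80 channels, 1x1 spatial.

```python
import torch


def summarize(activations):
    """Map each key to (shape, mean, std) instead of printing the full tensor."""
    return {
        key: (tuple(t.shape), t.mean().item(), t.std().item())
        for key, t in activations.items()
    }


# Hypothetical stand-in data; replace with the recorded activations dict.
activations = {(1, 11): torch.randn(5, 80, 1, 1)}
for key, (shape, mean, std) in summarize(activations).items():
    print(key, shape, f"mean={mean:.3f}", f"std={std:.3f}")
```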

Other functionalities#

In addition to recording, the following functionalities are available and documented in Recorder:

  • On-the-fly postprocessing of activations during inference (e.g. clipping).

  • Locally disabling recording during a forward pass.

  • Direct access to the recorded layers' Module objects.

  • Retrieval of the named parameters of the recorded layers.
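The first two items can be illustrated with plain PyTorch forward hooks. This sketch does not use Recorder's own API, whose exact parameters are documented in Recorder. The `ClipAndRecord` class and its `disabled` context manager are hypothetical names. The sketch relies on a documented property of Module.register_forward_hook: if the hook returns a value other than None, that value replaces the layer's output.

```python
import contextlib

import torch
from torch import nn


class ClipAndRecord:
    """Hypothetical hook object: records a layer's output and optionally clips it."""

    def __init__(self, max_value=None):
        self.max_value = max_value  # None means "no clipping"
        self.enabled = True  # toggled by the `disabled` context manager
        self.recorded = []

    def __call__(self, module, args, output):
        if not self.enabled:
            return None  # returning None leaves the layer's output untouched
        if self.max_value is not None:
            output = output.clamp(max=self.max_value)  # on-the-fly postprocessing
        self.recorded.append(output.detach())
        return output  # a non-None return value replaces the layer's output

    @contextlib.contextmanager
    def disabled(self):
        """Locally turn off clipping and recording inside a `with` block."""
        self.enabled = False
        try:
            yield
        finally:
            self.enabled = True


layer = nn.ReLU()
hook = ClipAndRecord(max_value=1.0)
handle = layer.register_forward_hook(hook)

x = torch.tensor([-2.0, 0.5, 3.0])
clipped = layer(x)  # clipped to max 1.0 and recorded
with hook.disabled():
    raw = layer(x)  # plain ReLU output, not recorded
handle.remove()  # detach the hook when done
```

Returning the modified tensor from the hook is what makes the postprocessing affect downstream layers, rather than only the stored copy.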

Implementation details#

This utility works by attaching hooks to every layer of net with Module.register_forward_hook. Since a layer is not aware of the context in which it is called, each hook carries a reference to rnet and, with it, enough context to know when to trigger. As a result, two different Recorder Nets can wrap the same net without any conflict. As an implementation detail, these references are weak, so that rnet is properly cleaned up once it is deleted.
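To make this mechanism concrete, here is a minimal plain-PyTorch sketch. `TinyRecorder` and its methods are hypothetical names for illustration. This is not scio's Recorder code, which may differ in many respects. It only reproduces the two properties described above: each hook reaches its recorder through a weak reference, so several recorders can share one net, and the hooks neither keep a deleted recorder alive nor record on its behalf.

```python
import gc
import weakref

import torch
from torch import nn


class TinyRecorder:
    """Hypothetical, stripped-down recorder; NOT scio's actual implementation."""

    def __init__(self, net):
        self.outputs = {}
        self_ref = weakref.ref(self)  # hooks only hold a weak reference to us
        self._handles = [
            module.register_forward_hook(self._make_hook(self_ref, name))
            for name, module in net.named_modules()
        ]

    @staticmethod
    def _make_hook(self_ref, name):
        def hook(module, args, output):
            recorder = self_ref()
            if recorder is not None:  # silently skip once the recorder is gone
                recorder.outputs[name] = output.detach()  # assumes tensor outputs
        return hook

    def __del__(self):
        # Remove only this recorder's hooks; other recorders are unaffected.
        for handle in self._handles:
            handle.remove()


net = nn.Sequential(nn.Linear(4, 3), nn.ReLU())
rec_a, rec_b = TinyRecorder(net), TinyRecorder(net)  # two recorders, one net
net(torch.ones(1, 4))  # both record independently

ref_b = weakref.ref(rec_b)
del rec_b
gc.collect()
assert ref_b() is None  # the hooks did not keep rec_b alive
net(torch.ones(1, 4))  # rec_a keeps recording, unaffected
```

Because the hooks hold no strong reference, deleting a recorder frees it immediately. The `self_ref() is None` check is a safety net in case a hook outlives its recorder.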

Total running time of the script: (0 minutes 0.158 seconds)
