Diving inside Neural Networks

This tutorial provides a short practical overview of the Recorder class, which is designed to ease interactions with the internal states of PyTorch Neural Network objects (torch.nn.Module) during or after inference.

from scio.recorder import Recorder

Wrapping and visualizing your Neural Network

Let us first load an arbitrary Neural Network. We use a lightweight Tiniest architecture trained on CIFAR10 and hosted on the ThalesGroup/scio Torch Hub repository. We also prepare some random input data that is used throughout this tutorial.

import torch

inputs = torch.rand(5, 3, 32, 32)  # Random inputs with 5 samples
net = torch.hub.load("ThalesGroup/scio:hub", "tiniest", trust_repo=True, verbose=False)
net = net.to(inputs)

To wrap it into a Recorder net, rnet, one only needs to specify an input_size (including the batch dimension) or to provide input_data. Upon instantiation, the control flow of the model is analyzed and stored using the torchinfo library.

rnet = Recorder(net, input_data=inputs[[0]])
rnet  # Visualize layers
Recorder instance for the following network
============================================================================================================================================
Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Param %
============================================================================================================================================
Tiniest (Tiniest)                        [1, 3, 32, 32]            [1, 10]                   --                             --
├─Conv2d (conv1): 1-1                    [1, 3, 32, 32]            [1, 48, 32, 32]           1,344                       1.38%
├─LayerNorm2d (ln1): 1-2                 [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    └─LayerNorm (ln): 2-1               [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
├─Block (l1): 1-3                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-2             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-3             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-4             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-5             [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-1          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-6                 [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-7                 [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l2): 1-4                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-8             [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-9             [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-10            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-11            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-2          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-12                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-13                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Block (l3): 1-5                        [1, 48, 32, 32]           [1, 48, 32, 32]           48                          0.05%
│    └─Conv2d (dwconv1): 2-14            [1, 12, 32, 32]           [1, 12, 32, 32]           120                         0.12%
│    └─Conv2d (dwconv2): 2-15            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─Conv2d (dwconv3): 2-16            [1, 12, 32, 32]           [1, 12, 32, 32]           600                         0.62%
│    └─LayerNorm2d (ln): 2-17            [1, 48, 32, 32]           [1, 48, 32, 32]           --                             --
│    │    └─LayerNorm (ln): 3-3          [1, 32, 32, 48]           [1, 32, 32, 48]           96                          0.10%
│    └─Conv2d (fc1): 2-18                [1, 48, 32, 32]           [1, 96, 32, 32]           4,704                       4.82%
│    └─Conv2d (fc2): 2-19                [1, 48, 32, 32]           [1, 48, 32, 32]           2,352                       2.41%
├─Conv2d (dsconv): 1-6                   [1, 48, 32, 32]           [1, 80, 16, 16]           3,920                       4.02%
├─LayerNorm2d (ln2): 1-7                 [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    └─LayerNorm (ln): 2-20              [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
├─Block (l4): 1-8                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-21            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-22            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-23            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-24            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-4          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-25                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-26                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l5): 1-9                        [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-27            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-28            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-29            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-30            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-5          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-31                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-32                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─Block (l6): 1-10                       [1, 80, 16, 16]           [1, 80, 16, 16]           80                          0.08%
│    └─Conv2d (dwconv1): 2-33            [1, 20, 16, 16]           [1, 20, 16, 16]           200                         0.21%
│    └─Conv2d (dwconv2): 2-34            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─Conv2d (dwconv3): 2-35            [1, 20, 16, 16]           [1, 20, 16, 16]           1,000                       1.03%
│    └─LayerNorm2d (ln): 2-36            [1, 80, 16, 16]           [1, 80, 16, 16]           --                             --
│    │    └─LayerNorm (ln): 3-6          [1, 16, 16, 80]           [1, 16, 16, 80]           160                         0.16%
│    └─Conv2d (fc1): 2-37                [1, 80, 16, 16]           [1, 160, 16, 16]          12,960                     13.29%
│    └─Conv2d (fc2): 2-38                [1, 80, 16, 16]           [1, 80, 16, 16]           6,480                       6.64%
├─AdaptiveAvgPool2d (avgpool): 1-11      [1, 80, 16, 16]           [1, 80, 1, 1]             --                             --
├─Linear (fc): 1-12                      [1, 80]                   [1, 10]                   810                         0.83%
============================================================================================================================================
Total params: 97,530
Trainable params: 97,530
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 44.73
============================================================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 9.05
Params size (MB): 0.39
Estimated Total Size (MB): 9.45
============================================================================================================================================
Currently recording: None
============================================================================================================================================

Tip

To customize the summary, refer to the torchinfo.summary options. For example, depth=2 limits the displayed layer tree to two levels of nesting.

Note

For models with dynamic control flow (e.g. data-dependent branching in forward()), refer to the force_static_flow argument of Recorder.
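To make "dynamic control flow" concrete, the plain-PyTorch sketch below defines a hypothetical model whose forward() takes a data-dependent branch. It then lists the submodules that actually run, using forward pre-hooks. Because the executed layers differ from one input to the next, a layer table built from a single forward pass may not describe every input. This example does not use or describe force_static_flow itself.

```python
import torch
from torch import nn


class Branchy(nn.Module):
    """Hypothetical toy model whose executed layers depend on the input."""

    def __init__(self):
        super().__init__()
        self.small = nn.Linear(4, 4)
        self.large = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 4))

    def forward(self, x):
        # Data-dependent branch: which submodules run changes with the input
        if x.sum() > 0:
            return self.large(x)
        return self.small(x)


def called_layers(model: nn.Module, x: torch.Tensor) -> list:
    """Return the names of the submodules called during one forward pass, in call order."""
    names = []
    handles = [
        module.register_forward_pre_hook(lambda mod, args, name=name: names.append(name))
        for name, module in model.named_modules()
        if name  # skip the root model
    ]
    try:
        with torch.no_grad():
            model(x)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unmodified
    return names


model = Branchy()
print(called_layers(model, torch.ones(1, 4)))   # ['large', 'large.0', 'large.1', 'large.2']
print(called_layers(model, -torch.ones(1, 4)))  # ['small']
```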

In many ways, this wrapper is transparent to the user. For example, data can be processed with rnet(inputs), just as with net(inputs).

Selecting layers of interest

The penultimate line of the summary above lists the layers that are currently set to be recorded (stored in rnet.recording). By default, right after instantiation, there are none.

print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: None

This selection can be set freely using rnet.record() with the depth-idx identifiers from the summary, passed as (depth, idx) tuples (e.g. 1-9 becomes (1, 9)). For example, the following specifies that the outputs of the first Block (1-3) and of the penultimate layer (1-11) should be recorded.

rnet.record((1, 3), (1, 11))
print(repr(rnet).split("\n")[-2])  # Show penultimate summary line
Currently recording: 1-3, 1-11
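For intuition about where the depth-idx identifiers come from, the following sketch approximates the numbering with plain PyTorch forward pre-hooks. It assumes that depth is the nesting level of a submodule and that idx counts the calls made at that depth in execution order. This assumption is consistent with the summary above (for example, 2-1 is the LayerNorm inside ln1 and 2-2 is dwconv1 inside l1), but the helper depth_idx_labels is hypothetical and is not torchinfo's or scio's actual code.

```python
from collections import defaultdict

import torch
from torch import nn


def depth_idx_labels(model: nn.Module, x: torch.Tensor) -> list:
    """Return [((depth, idx), name), ...] for every submodule call, in call order.

    Hypothetical helper: depth is the nesting level of the submodule's attribute
    name, and idx counts the calls made at that depth during one forward pass.
    """
    counters = defaultdict(int)  # depth -> number of calls seen so far at that depth
    labels = []
    handles = []
    for name, module in model.named_modules():
        if not name:
            continue  # skip the root model (referred to as 0-1 in this tutorial)
        depth = name.count(".") + 1

        def pre_hook(mod, args, name=name, depth=depth):
            counters[depth] += 1
            labels.append(((depth, counters[depth]), name))

        handles.append(module.register_forward_pre_hook(pre_hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unmodified
    return labels


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 8, kernel_size=3, padding=1)),
    nn.AdaptiveAvgPool2d(1),
)
for (depth, idx), name in depth_idx_labels(net, torch.rand(1, 3, 16, 16)):
    print(f"{depth}-{idx}: {name}")
```

On this toy model, the loop prints 1-1: 0, 1-2: 1, 2-1: 1.0, 2-2: 1.1 and 1-3: 2. Depth-2 indices keep counting across parents, as in the summary of the Tiniest network.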

Note

Although it does not appear in the summary, the identifier 0-1 can be selected to refer to the entire model.

Warning

torchinfo.summary can only detect torch.nn.Module calls. Consequently, for activation functions to be visible as layers in the summary, the Neural Network must call their Module implementation (e.g. nn.ReLU) rather than their functional counterpart (e.g. torch.nn.functional.relu), and declare them as attributes. A separate attribute is not needed for every activation call: a single self.relu = nn.ReLU() can be used multiple times in forward().
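The following plain-PyTorch sketch illustrates this warning with forward hooks, the same kind of Module-level mechanism that cannot see functional calls. The two model classes and the module_calls helper are hypothetical. A single nn.ReLU attribute shows up once per call, while torch.nn.functional.relu never appears.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FunctionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 8)
        self.fc3 = nn.Linear(8, 2)

    def forward(self, x):
        # F.relu is a plain function call: no module hook can observe it
        return self.fc3(F.relu(self.fc2(F.relu(self.fc1(x)))))


class ModuleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 8)
        self.fc3 = nn.Linear(8, 2)
        self.relu = nn.ReLU()  # declared once as an attribute...

    def forward(self, x):
        # ...and reused for both activations
        return self.fc3(self.relu(self.fc2(self.relu(self.fc1(x)))))


def module_calls(model: nn.Module, x: torch.Tensor) -> list:
    """List the class names of submodules, in the order their forward passes complete."""
    calls = []
    handles = [
        module.register_forward_hook(lambda mod, args, out: calls.append(type(mod).__name__))
        for name, module in model.named_modules()
        if name  # skip the root model
    ]
    try:
        model(x)
    finally:
        for handle in handles:
            handle.remove()  # leave the model unmodified
    return calls


x = torch.rand(2, 4)
print(module_calls(FunctionalNet(), x))  # ['Linear', 'Linear', 'Linear']
print(module_calls(ModuleNet(), x))      # ['Linear', 'ReLU', 'Linear', 'ReLU', 'Linear']
```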

Capturing internal states

Once the layers to record are set, every forward pass automatically stores their outputs in the rnet.activations mapping, whose keys are the (depth, idx) 2-tuples.
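For intuition, a comparable behaviour can be approximated with plain PyTorch hooks. MiniRecorder below is a hypothetical, simplified stand-in and not scio's implementation. It assumes the numbering convention sketched earlier for depth-idx (nesting depth, then call index per depth), treats (0, 1) as the whole model, and, like the output shown below, exposes the stored outputs through a read-only mappingproxy without detaching them.

```python
from collections import defaultdict
from types import MappingProxyType

import torch
from torch import nn


class MiniRecorder:
    """Hypothetical, simplified stand-in for Recorder (NOT scio's implementation)."""

    def __init__(self, model: nn.Module):
        self.model = model
        self.recording = set()  # (depth, idx) keys to store
        self._activations = {}
        self._counters = defaultdict(int)  # depth -> calls seen so far in this pass
        self._stack = []  # keys of the submodule calls currently in progress
        for name, module in model.named_modules():
            if name:  # the root model, (0, 1), is handled in __call__
                module.register_forward_pre_hook(self._make_pre_hook(name.count(".") + 1))
                module.register_forward_hook(self._post_hook)

    def _make_pre_hook(self, depth):
        def pre_hook(module, args):
            self._counters[depth] += 1
            self._stack.append((depth, self._counters[depth]))
        return pre_hook

    def _post_hook(self, module, args, output):
        key = self._stack.pop()  # nested calls finish in LIFO order
        if key in self.recording:
            self._activations[key] = output  # kept attached to the autograd graph

    def record(self, *keys):
        self.recording = set(keys)

    @property
    def activations(self):
        return MappingProxyType(self._activations)  # read-only view

    def __call__(self, x):
        # Reset per-pass state so that indices restart at 1 for every forward pass
        self._counters.clear()
        self._stack.clear()
        self._activations = {}
        output = self.model(x)
        if (0, 1) in self.recording:
            self._activations[(0, 1)] = output
        return output


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),                 # 1-1
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 4, kernel_size=1)),  # 1-2 (contains 2-1, 2-2)
    nn.AdaptiveAvgPool2d(1),                                   # 1-3
    nn.Flatten(),                                              # 1-4
    nn.Linear(4, 10),                                          # 1-5
)
rec = MiniRecorder(net)
rec.record((1, 2), (2, 1), (0, 1))
out = rec(torch.rand(5, 3, 32, 32))
for key, act in rec.activations.items():
    print(key, tuple(act.shape))
```

Here the stored shapes are (5, 8, 32, 32) for (2, 1), (5, 4, 32, 32) for (1, 2) and (5, 10) for (0, 1). The mappingproxy rejects item assignment, so recorded states cannot be overwritten by accident.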

out = rnet(inputs)  # Forward pass, records the activations
rnet.activations
mappingproxy({(1, 3): tensor([[[[ 6.5233e-01,  1.2725e+00,  3.6750e+00,  ..., -2.0159e+00,
            1.6833e+00, -7.0795e+00],
          [-2.3511e+00,  3.0193e+00,  5.5817e+00,  ..., -2.4651e+00,
            5.0006e+00, -1.2303e-01],
          [ 4.8228e+00, -2.6506e+00,  1.9564e+00,  ..., -2.3619e+00,
            3.6713e-01,  3.3659e+00],
          ...,
          [ 3.2584e+00,  7.2038e+00, -1.3254e-02,  ...,  1.4127e+00,
            2.7081e+00, -5.4466e-01],
          [ 2.0626e+00,  1.3517e+00,  3.4653e+00,  ...,  6.4322e+00,
            4.3682e+00, -4.7631e+00],
          [-3.0282e-01,  4.7805e+00,  4.1829e-01,  ..., -7.3475e+00,
           -1.0656e+01, -6.3892e-01]],

         [[ 2.0016e+00, -4.2851e+00, -3.3361e+00,  ...,  9.0451e-01,
           -2.5546e-02,  1.2962e+01],
          [ 7.5434e+00,  1.7444e-01, -4.1723e-01,  ...,  5.9211e+00,
           -3.0610e-01, -1.9664e-01],
          [ 1.0832e+00,  6.9586e+00, -7.1043e-01,  ..., -1.0445e-01,
            3.6905e+00, -2.0549e+00],
          ...,
          [ 1.3292e+00,  5.6252e+00,  1.0228e+00,  ...,  2.7499e+00,
           -8.0648e-01, -2.5378e+00],
          [ 8.7772e-01,  6.7162e-02,  8.5031e-01,  ..., -1.0823e+00,
           -3.4787e+00,  4.8770e+00],
          [ 2.7268e+00, -6.9700e-01,  3.4253e+00,  ...,  1.8982e+01,
            1.3100e+01,  3.4777e+00]],

         [[ 3.9734e+00,  7.0858e+00,  5.2050e+00,  ..., -1.7372e-01,
            4.6883e+00,  4.0503e+00],
          [-3.6360e+00,  3.2495e-01,  6.5608e-01,  ..., -3.9364e+00,
           -2.6859e+00, -3.2185e-01],
          [-1.1749e-01, -4.2827e+00, -4.4952e-01,  ..., -1.1595e+00,
           -3.6082e+00,  1.8130e+00],
          ...,
          [ 1.5995e+00,  1.7333e+00,  2.5619e+00,  ..., -2.2567e+00,
           -3.7918e+00,  8.4707e-01],
          [ 7.6828e-01, -9.5656e-01,  2.9186e+00,  ...,  3.0933e+00,
            3.6908e+00,  3.5129e-01],
          [-6.6618e+00, -6.2583e+00, -8.9655e+00,  ..., -8.5055e+00,
           -9.3281e+00, -5.7110e+00]],

         ...,

         [[-2.0452e+00,  1.6465e+00, -3.6077e+00,  ..., -1.7322e+00,
           -4.7165e+00, -9.8097e+00],
          [-4.0488e+00, -4.5992e+00, -9.9427e+00,  ..., -5.1120e+00,
            1.4924e-01, -1.6572e+00],
          [ 5.2877e-01, -5.2715e+00, -1.6590e+00,  ...,  2.9039e-01,
           -2.0551e+00, -7.5272e+00],
          ...,
          [-1.6548e+00, -5.9707e+00, -4.4551e+00,  ..., -4.1366e+00,
           -3.3673e+00, -3.4288e+00],
          [-1.3946e+00, -2.4655e+00, -4.4540e+00,  ..., -1.2424e+01,
           -7.0328e+00, -2.9884e+00],
          [-1.5960e+00, -4.8778e-01, -3.2412e+00,  ..., -1.1939e+01,
           -9.6983e+00, -1.5338e+00]],

         [[-1.3770e+00,  2.5613e+00,  3.6752e-01,  ..., -4.8071e+00,
           -1.3603e-01, -1.9606e+00],
          [-3.8776e+00, -6.0525e+00, -1.1574e+01,  ...,  7.5650e-01,
            1.1499e+00, -6.9600e-01],
          [-1.0761e+00,  2.3007e-01,  1.6208e+00,  ..., -3.2745e-01,
           -1.2642e+00, -6.1888e+00],
          ...,
          [ 4.3056e-01, -4.9257e+00, -2.1079e+00,  ..., -1.6968e-01,
           -1.1581e-01, -4.0252e+00],
          [-4.9029e-01, -2.8119e+00, -1.3632e+00,  ..., -9.6341e+00,
           -8.7947e+00, -7.4184e-01],
          [-1.9741e+00,  5.0125e-01, -4.8286e+00,  ..., -1.1642e+00,
           -2.4537e+00, -9.1049e-01]],

         [[-1.0088e+00,  3.3511e+00,  1.9951e+00,  ...,  7.8124e-01,
            6.5472e-01, -7.3793e+00],
          [-1.0471e+00, -2.9371e+00, -4.3393e+00,  ..., -3.6256e+00,
            1.4546e+00, -1.4636e-01],
          [-3.1720e-01, -3.1569e+00,  2.4827e-01,  ..., -1.3969e+00,
           -4.5368e+00, -4.5059e-01],
          ...,
          [ 5.6192e-02, -1.1771e+01, -7.3454e-01,  ..., -3.9873e+00,
           -7.3943e-01,  4.5994e-01],
          [ 6.9071e-01,  9.9550e-01,  3.9379e+00,  ..., -9.3957e-02,
            5.8742e-01, -2.5478e+00],
          [-2.3552e+00,  2.5087e+00, -3.6868e+00,  ..., -9.6194e+00,
           -6.6895e+00, -2.7953e+00]]],


        [[[-7.3062e-01,  4.1227e+00, -1.8534e+00,  ..., -1.4031e+00,
            2.4605e+00, -3.2796e-01],
          [ 5.3100e+00, -4.0370e+00,  4.0779e+00,  ..., -3.1971e+00,
           -9.1014e+00, -3.9647e+00],
          [ 7.8537e+00, -8.1526e+00, -2.3895e+00,  ..., -2.2035e-01,
            4.4271e+00,  2.2666e+00],
          ...,
          [-1.4318e+00,  1.2989e-01, -4.2285e+00,  ..., -7.5293e-01,
            6.3534e+00, -4.3295e+00],
          [ 1.0216e-01,  4.2370e+00, -4.1434e+00,  ..., -8.2557e-01,
            8.1967e+00, -6.7572e+00],
          [ 5.5297e+00,  1.8557e+00,  3.9495e+00,  ..., -2.4411e+00,
            3.9152e+00, -6.3006e+00]],

         [[ 5.1019e+00,  4.3368e+00, -2.6981e-01,  ..., -1.2600e-02,
           -1.8705e+00,  1.4273e+00],
          [ 1.8983e-01,  6.0136e+00, -2.4732e+00,  ...,  1.6463e+00,
            1.2504e+01,  4.0633e-01],
          [-1.6422e+00,  1.4972e+01,  1.2799e+00,  ...,  6.1390e-01,
           -2.8601e+00, -2.3114e+00],
          ...,
          [ 9.4956e+00,  3.2206e+00,  1.1456e+01,  ..., -2.9516e-01,
            5.2304e+00,  2.8423e+00],
          [ 4.0394e+00, -1.8296e+00,  1.2325e+01,  ..., -9.0142e-01,
            5.5555e-01,  9.7610e+00],
          [ 4.1251e+00,  2.2094e-01,  3.9263e-01,  ...,  4.3591e+00,
           -6.1109e-01,  3.7285e+00]],

         [[-9.2616e-01,  4.8996e+00,  2.2038e-01,  ...,  3.8011e+00,
            7.7366e+00,  2.8519e+00],
          [-2.1451e+00, -4.6895e+00,  3.4184e+00,  ..., -6.8061e-01,
           -5.8211e+00, -1.6821e-01],
          [ 2.2074e+00, -3.4175e+00,  1.0254e+00,  ..., -1.9705e+00,
           -1.1869e-01, -3.5983e-01],
          ...,
          [-4.8427e+00, -3.6016e-01,  1.9653e+00,  ...,  2.5899e+00,
            4.9322e+00,  1.5865e+00],
          [ 1.5283e-01,  5.9435e-01, -3.4333e+00,  ..., -9.8196e-01,
           -2.3765e+00, -4.0314e+00],
          [-4.8860e+00, -7.7206e+00, -8.1581e+00,  ..., -6.0939e+00,
           -8.4689e+00, -8.0877e+00]],

         ...,

         [[-2.6553e+00, -4.8703e+00, -8.4347e-01,  ..., -1.3372e+00,
           -2.3852e+00, -4.1670e+00],
          [ 3.0502e-01, -1.9758e+00, -7.5766e-01,  ..., -1.8331e+00,
           -8.5537e+00,  4.7931e-02],
          [-1.1037e+01, -6.7490e+00, -3.1965e+00,  ...,  9.2125e-01,
           -6.9282e+00, -4.6404e+00],
          ...,
          [-5.5029e+00, -4.4675e+00, -7.4941e+00,  ..., -1.0316e+00,
           -6.1296e+00, -1.7304e+00],
          [-4.3681e+00, -7.3960e+00, -8.0245e+00,  ...,  2.3347e+00,
           -4.7002e+00, -3.4558e+00],
          [-3.9489e+00,  3.0523e+00, -3.9450e+00,  ..., -1.3849e+00,
            1.9543e+00, -1.1938e+00]],

         [[-5.7532e+00, -7.2714e+00, -4.0360e+00,  ..., -2.2289e-01,
           -7.1593e+00, -2.9728e+00],
          [ 1.8120e+00, -1.2933e+00, -4.5048e+00,  ..., -3.8416e-01,
           -5.7972e-01, -1.3680e+00],
          [-9.5057e+00, -1.7900e+00, -2.3940e+00,  ..., -2.4502e+00,
           -3.1517e+00, -5.8292e+00],
          ...,
          [-3.5026e+00, -5.9044e+00, -1.3900e+00,  ..., -2.1667e+00,
           -5.1183e+00, -2.4302e+00],
          [ 3.2499e-01, -8.0495e+00, -1.7420e+00,  ..., -7.7651e-01,
           -2.1016e+00, -8.0108e-01],
          [ 7.1137e-01,  5.5817e-02, -1.5151e+00,  ..., -2.9268e-01,
           -4.3607e-01, -4.8940e+00]],

         [[-1.9933e+00, -5.4509e+00,  2.3585e-01,  ..., -2.5857e-01,
            1.1802e+00,  9.4774e-01],
          [ 3.4265e-01, -4.8233e+00,  4.3390e+00,  ..., -1.4111e+00,
           -1.2383e+01,  8.4576e-01],
          [-1.4839e+00, -9.7594e+00,  1.7826e+00,  ..., -8.0777e-01,
            2.8530e+00, -2.1039e-01],
          ...,
          [-4.9097e+00, -1.0676e+00, -5.7658e+00,  ..., -2.0357e-01,
           -7.3938e+00,  1.9967e+00],
          [-4.1636e-01, -1.1093e+00, -7.6709e+00,  ..., -1.5491e+00,
           -5.8253e-02, -4.8818e+00],
          [-1.5166e+00, -1.5905e+00, -1.6248e+00,  ..., -2.8731e+00,
           -2.0117e+00, -3.3845e+00]]],


        [[[-5.9192e+00,  4.2036e+00,  3.4209e+00,  ...,  2.8022e+00,
            1.8121e+00,  2.9062e+00],
          [ 7.4623e+00,  2.9832e+00, -2.3139e+00,  ..., -2.2023e+00,
            6.6131e-01, -3.1860e+00],
          [ 6.0408e+00, -3.0043e+00,  4.6770e-02,  ..., -3.7701e+00,
            3.4826e+00,  3.0656e+00],
          ...,
          [ 8.0958e+00,  1.7608e-02,  5.2744e-01,  ..., -1.1528e+00,
           -8.5997e+00, -7.0205e+00],
          [-3.3167e+00, -1.0717e+00,  3.9274e+00,  ...,  1.5429e+00,
           -9.1712e+00,  1.9269e+00],
          [ 4.0280e+00,  2.9898e+00, -1.1646e+00,  ..., -7.5879e+00,
            4.8128e+00, -7.7549e+00]],

         [[ 1.0683e+01, -2.4404e+00,  1.7317e+00,  ..., -1.3659e+00,
           -2.0939e+00, -1.5179e+00],
          [-8.4931e-01,  6.4297e-01,  6.3456e+00,  ...,  1.5607e+00,
            6.9650e-02,  7.4600e+00],
          [-1.8805e+00,  1.7074e+00, -1.0821e+00,  ...,  4.9930e+00,
            3.5487e-01, -2.3461e+00],
          ...,
          [-4.0633e+00, -1.5812e+00, -1.6448e+00,  ..., -1.2773e+00,
            1.1291e+01,  6.8508e+00],
          [ 1.4816e+01,  4.6878e+00, -1.8894e+00,  ...,  1.2750e+00,
            1.7430e+01, -9.0080e-01],
          [-7.3151e-01,  2.9363e+00, -1.9774e+00,  ...,  1.6337e+01,
           -2.3246e-01,  9.8385e+00]],

         [[-1.0607e+00,  5.9837e+00,  7.4401e-01,  ...,  8.2282e+00,
            6.8660e+00,  6.6937e+00],
          [-5.4667e-01,  6.2833e-02, -1.4221e+00,  ...,  2.2680e+00,
            1.5257e+00, -5.6832e+00],
          [ 3.9410e+00,  7.3042e-01,  2.4384e+00,  ...,  7.6438e-01,
           -7.7620e-03,  3.4148e+00],
          ...,
          [ 2.6569e+00, -3.3457e-01,  2.1153e+00,  ..., -1.2069e+00,
           -3.9180e+00, -5.4361e+00],
          [-4.8030e+00,  1.4880e+00,  5.1481e+00,  ..., -5.0601e-02,
           -2.1836e+00,  1.2549e+00],
          [-3.0122e+00, -6.9644e+00, -5.7939e+00,  ..., -9.0861e+00,
           -5.5394e+00, -5.9919e+00]],

         ...,

         [[-4.5980e+00, -7.7655e+00, -4.6116e+00,  ..., -7.6056e+00,
           -9.9377e-01, -1.0555e+01],
          [-6.1375e+00, -2.8210e+00, -5.0584e+00,  ..., -3.7663e+00,
            1.3283e-01, -2.9088e+00],
          [-1.3960e-01, -9.8522e-01,  3.8749e-01,  ..., -3.6240e+00,
           -4.6810e+00, -1.9846e+00],
          ...,
          [-9.9973e+00, -3.7210e+00,  1.5144e+00,  ...,  3.1846e+00,
           -8.8899e+00, -3.7720e+00],
          [-9.4362e+00, -3.9496e+00, -1.0421e+01,  ..., -2.8981e+00,
           -1.3408e+01, -6.9782e+00],
          [-2.5041e+00, -1.8554e+00, -3.1249e+00,  ..., -8.1617e+00,
           -8.5065e+00, -5.2079e+00]],

         [[-4.3261e+00, -8.6874e+00,  7.8414e-01,  ..., -7.0039e+00,
            1.3562e-01, -6.5936e+00],
          [-2.8993e+00, -1.7060e+00,  1.5017e+00,  ..., -2.7355e+00,
           -5.9267e-01, -6.2765e+00],
          [-2.8051e+00, -3.7840e+00, -1.9246e+00,  ...,  4.1752e-02,
           -8.8691e-01, -5.2326e-01],
          ...,
          [-6.9509e+00, -7.2668e-01, -2.8867e+00,  ..., -4.1812e+00,
           -1.4468e+00, -1.5905e+00],
          [-9.8756e-01,  1.4465e+00, -9.4484e+00,  ..., -3.7408e+00,
           -3.5955e-01, -7.3358e+00],
          [-3.0290e+00, -7.2690e-02, -2.3467e+00,  ..., -2.9267e+00,
           -8.3142e+00, -6.5755e-01]],

         [[-3.4004e+00,  1.4781e+00, -1.3803e+00,  ..., -1.0499e+00,
            1.0945e+00,  1.3589e+00],
          [ 1.9862e+00, -1.6888e-01, -6.8863e+00,  ..., -9.6377e-01,
            2.0018e+00, -5.3712e+00],
          [ 2.4903e+00, -1.0219e+00, -3.3611e-01,  ..., -2.2725e+00,
            1.3307e+00,  5.9890e+00],
          ...,
          [ 2.0407e-01, -4.7059e-02,  9.7970e-01,  ..., -4.2242e+00,
           -1.0915e+01, -6.0606e+00],
          [-9.6519e+00, -1.1645e+00, -8.6258e-01,  ..., -1.6677e+00,
           -1.5523e+01, -5.7993e-01],
          [ 2.6275e-01, -1.1228e+00,  8.5746e-01,  ..., -8.0541e+00,
           -1.5399e+00, -8.1798e+00]]],


        [[[ 5.8936e-01,  7.8243e-01, -7.1179e+00,  ..., -3.7642e+00,
            6.5379e+00, -2.4196e+00],
          [ 1.3411e-01, -8.5906e+00,  7.0834e-01,  ..., -1.2447e+01,
            2.2846e+00, -3.2347e+00],
          [ 7.2328e+00,  5.0532e+00,  4.6450e+00,  ...,  4.0002e+00,
            3.8318e+00, -1.3907e+01],
          ...,
          [ 6.2913e+00, -1.8038e+00, -2.9926e+00,  ...,  4.2868e+00,
           -1.8576e+00, -2.5001e+00],
          [ 8.5990e+00,  2.2179e+00, -1.3448e+00,  ..., -3.3233e+00,
            7.3457e-01, -1.3801e+00],
          [ 6.7728e+00,  3.5462e+00,  2.2410e+00,  ..., -2.3807e+00,
           -1.8719e+00, -3.1368e-01]],

         [[ 5.0711e+00, -1.1276e+00,  2.3973e+00,  ...,  8.8975e+00,
           -1.0609e+00,  1.0654e-01],
          [ 6.3718e+00,  7.3765e+00, -8.3533e-01,  ...,  2.0430e+01,
           -1.6735e+00,  1.0811e-01],
          [-2.7160e+00, -1.6037e+00, -3.4057e+00,  ...,  6.0701e+00,
            7.5746e+00,  1.4070e+01],
          ...,
          [ 2.8536e+00,  5.7896e+00,  5.0008e+00,  ...,  2.5695e+00,
           -2.6709e-01,  1.8222e+00],
          [-1.9213e+00, -1.0731e-01,  2.0868e+00,  ...,  4.8657e+00,
            3.6648e+00,  5.7251e+00],
          [-7.4048e-02,  5.0601e+00, -1.5218e+00,  ...,  7.1122e-01,
           -1.4880e+00,  1.6471e+00]],

         [[ 6.7612e-01,  5.2518e+00,  1.0126e+00,  ...,  3.9457e+00,
            6.6549e+00,  4.8504e+00],
          [-2.4306e+00, -1.8958e+00,  2.9191e+00,  ..., -5.8503e+00,
            2.8005e+00,  3.5782e-01],
          [ 4.7547e+00,  1.7572e+00,  3.4305e+00,  ..., -3.0324e+00,
           -2.1111e+00, -5.0665e+00],
          ...,
          [ 1.1894e+00, -2.1654e+00, -5.4536e+00,  ..., -1.5198e+00,
           -2.2588e-01,  3.7329e+00],
          [-9.0491e-01, -2.1042e+00, -4.0123e+00,  ..., -1.6335e+00,
           -1.3892e+00, -1.6607e+00],
          [-5.0064e+00, -7.4857e+00, -5.5036e+00,  ..., -7.8977e+00,
           -6.8192e+00, -6.6739e+00]],

         ...,

         [[-2.9253e+00,  3.5005e+00, -3.5539e+00,  ..., -6.5340e+00,
           -8.8980e+00,  2.7477e-01],
          [-7.4264e+00, -8.0098e+00, -3.7769e+00,  ..., -8.8385e+00,
           -4.5811e+00,  9.3861e-01],
          [-8.0120e+00, -2.2455e+00, -4.4846e+00,  ..., -4.6475e+00,
           -5.1621e+00, -9.3971e+00],
          ...,
          [-6.9831e+00, -2.8692e+00, -3.5771e+00,  ..., -2.8834e+00,
           -1.0022e+00, -3.0789e+00],
          [-8.5837e+00,  5.1946e-01,  1.2923e+00,  ..., -4.0153e+00,
           -2.3443e+00, -5.4136e+00],
          [-8.2162e+00, -4.1129e+00, -4.6675e+00,  ..., -9.6153e-01,
            5.9326e-01,  1.6929e+00]],

         [[-5.3795e+00, -2.6245e+00, -4.3451e+00,  ...,  1.1462e+00,
           -1.3452e+00, -3.3478e-01],
          [-3.3427e+00, -1.4048e+00, -4.6597e+00,  ..., -1.9735e+00,
           -5.9406e+00, -4.2700e-01],
          [-8.5716e+00, -4.3841e-02, -5.2244e+00,  ..., -5.7196e+00,
           -6.3830e+00, -1.3435e+00],
          ...,
          [-6.1542e+00, -2.9798e+00, -7.0204e-01,  ...,  1.2970e-01,
            1.9245e+00, -6.5425e+00],
          [-6.0461e+00,  7.4263e-01, -1.3138e+00,  ...,  1.9562e+00,
           -3.6258e+00, -4.8260e+00],
          [-5.8750e+00, -3.5743e+00, -2.4581e+00,  ..., -2.3343e+00,
           -1.9664e+00, -5.1732e-02]],

         [[-3.8448e+00, -1.1813e+00, -1.9804e+00,  ..., -5.8891e+00,
            3.7898e+00, -3.1666e-01],
          [-1.0531e+01, -6.9133e+00, -2.1756e+00,  ..., -1.0904e+01,
            9.4328e-01,  3.7365e-01],
          [ 1.1880e+00,  2.4892e+00,  4.7611e-01,  ..., -8.0032e+00,
           -6.4745e+00, -9.9672e+00],
          ...,
          [-3.3869e+00, -6.3937e-01, -2.0463e+00,  ..., -3.9541e-01,
            1.2610e-01,  1.8582e+00],
          [-1.5247e+00,  6.0222e-02,  1.3919e+00,  ..., -1.9552e+00,
            3.5866e-02, -6.6510e+00],
          [-4.0393e+00, -8.8218e+00,  9.4977e-01,  ..., -2.4028e+00,
           -1.8817e+00, -2.0748e+00]]],


        [[[ 5.1902e+00, -2.5074e+00, -2.3236e+00,  ..., -6.3354e+00,
            9.1413e-01, -1.3416e+01],
          [ 8.1971e+00,  1.4858e+00, -2.4290e-01,  ...,  2.7673e-01,
           -5.3310e-01,  2.2657e+00],
          [ 3.9276e+00,  5.2696e-01,  7.5711e-02,  ...,  3.6775e+00,
           -7.2514e+00,  2.5372e+00],
          ...,
          [-8.4791e+00, -3.0635e-01, -2.8241e+00,  ...,  3.3238e+00,
            4.6406e+00, -2.3960e-01],
          [ 5.7148e+00,  3.7471e-01, -2.9990e+00,  ...,  3.6763e+00,
           -7.5908e+00, -1.1047e+00],
          [-5.3574e+00, -8.5598e-01, -1.8702e+00,  ...,  9.8831e-01,
           -3.0815e+00,  8.9682e-01]],

         [[-7.9681e-01,  8.1125e+00,  3.8045e+00,  ...,  1.4377e+01,
            2.0649e+00,  1.9851e+01],
          [ 6.6132e+00,  2.5968e+00,  2.8381e-02,  ...,  1.4469e+00,
            1.1063e+00, -8.7605e-01],
          [ 6.6006e+00,  7.9083e-01, -9.8732e-01,  ...,  9.6654e-01,
            1.3830e+01,  6.0280e+00],
          ...,
          [ 1.4505e+01,  5.8782e-01,  6.1208e+00,  ...,  2.1617e-01,
           -2.2152e+00,  2.0203e+00],
          [-2.1795e+00,  4.9967e+00,  1.2258e+00,  ..., -7.7887e-01,
            1.8443e+01, -7.8966e-01],
          [ 1.5052e+01, -1.6728e-01, -1.4162e+00,  ...,  1.8614e+00,
            1.1371e+01,  1.1383e+00]],

         [[ 6.8054e+00,  2.6648e+00, -5.9114e-02,  ...,  1.5452e+00,
            4.2991e+00, -4.2143e-01],
          [-2.3147e+00, -5.9793e+00, -1.1817e-01,  ...,  3.4014e-01,
            1.3695e+00,  2.7101e+00],
          [-8.1830e-01,  2.1261e+00,  1.3695e+00,  ...,  1.4361e+00,
           -3.6925e+00,  4.4965e-01],
          ...,
          [-4.7109e+00,  5.5037e-01, -1.6684e+00,  ..., -2.9527e+00,
            1.3759e+00, -1.3377e+00],
          [ 4.6933e-01, -6.8781e-01, -3.2603e+00,  ...,  2.7682e-01,
           -7.7228e+00, -8.2130e-01],
          [-7.1306e+00, -7.8911e+00, -7.5318e+00,  ..., -4.3379e+00,
           -9.3014e+00, -3.8596e+00]],

         ...,

         [[-8.9817e+00, -8.1663e+00, -1.9369e+00,  ..., -8.9113e+00,
           -1.8504e+00, -1.2318e+01],
          [-8.1072e+00, -2.5226e+00,  3.1993e+00,  ...,  6.0087e-01,
            2.3217e+00, -7.7996e+00],
          [-7.7898e+00,  1.9552e+00,  6.7709e-01,  ..., -4.7968e+00,
           -8.9165e+00, -3.0918e+00],
          ...,
          [-7.3040e+00,  1.2128e+00, -2.8852e-01,  ..., -2.9824e+00,
           -1.2174e+00, -2.2805e+00],
          [-8.8611e+00, -4.5683e+00, -1.1759e+00,  ..., -9.6056e-01,
           -8.5215e+00, -6.1102e+00],
          [-8.6701e+00, -1.1889e+00,  1.2131e-01,  ...,  1.9976e+00,
           -2.8892e+00, -5.2790e+00]],

         [[-5.9191e+00, -5.4077e-01, -6.1802e+00,  ...,  5.7916e-01,
            2.8767e-01, -1.2585e+00],
          [-7.1150e+00, -1.4072e-02, -1.6010e+00,  ..., -8.3323e-01,
           -2.0631e+00, -9.1862e+00],
          [-4.1453e+00, -2.5330e+00, -1.1388e+00,  ..., -1.1534e+00,
           -2.8276e+00, -5.1779e+00],
          ...,
          [-2.9091e+00,  2.0146e-01, -5.9117e+00,  ..., -2.4761e+00,
           -9.8258e-01, -1.3773e+00],
          [-7.0199e+00, -6.5448e+00, -6.8515e-01,  ..., -9.4906e-01,
           -2.4331e+00, -1.0144e+01],
          [-2.5723e+00, -7.7172e-01, -2.4114e+00,  ..., -4.6692e+00,
           -4.9402e+00, -8.1389e+00]],

         [[ 1.4879e+00, -2.3059e+00, -9.6300e-01,  ..., -9.2393e+00,
           -2.7633e-01, -1.4581e+01],
          [-7.7094e+00, -1.4177e+00,  8.5891e-01,  ...,  2.6503e-01,
           -1.6040e+00, -1.8665e+00],
          [-9.7185e+00, -1.9287e-01,  7.3794e-01,  ...,  3.1534e-01,
           -1.3374e+01, -7.5156e+00],
          ...,
          [-4.7006e+00,  8.0546e-01, -2.0104e+00,  ..., -1.3869e+00,
            2.3705e+00, -2.0037e+00],
          [ 6.3267e-01, -7.1452e-01,  3.7883e-01,  ...,  3.5608e+00,
           -1.0132e+01, -3.7304e+00],
          [-7.5110e+00, -2.5800e+00, -1.2457e+00,  ..., -3.3402e+00,
           -4.3287e+00, -5.3510e+00]]]], grad_fn=<AddBackward0>), (1, 11): tensor([[[[-4.8212e-01]],

         [[-5.8952e-02]],

         [[ 2.4431e-02]],

         [[ 4.8122e-03]],

         [[ 1.1827e-02]],

         [[ 9.2532e-03]],

         [[ 3.1806e-02]],

         [[ 7.3927e-02]],

         [[ 6.3753e-02]],

         [[-4.6143e-02]],

         [[-3.3760e-02]],

         [[ 4.1940e-02]],

         [[-1.2214e-01]],

         [[-3.0096e-02]],

         [[-7.0030e-02]],

         [[-5.6574e-02]],

         [[-7.2440e-03]],

         [[ 1.3917e-02]],

         [[ 8.8828e-02]],

         [[-8.4762e-02]],

         [[ 6.4006e-02]],

         [[-7.2600e-03]],

         [[-5.6362e-02]],

         [[ 1.8911e+00]],

         [[ 1.0133e-01]],

         [[ 1.6228e-03]],

         [[ 4.2924e-02]],

         [[ 8.0931e-01]],

         [[ 1.4707e-01]],

         [[ 5.1171e-02]],

         [[-1.3607e+00]],

         [[-4.9574e-01]],

         [[-1.2056e-01]],

         [[-3.7938e-02]],

         [[ 8.4034e-01]],

         [[ 1.3793e-01]],

         [[ 1.0915e-01]],

         [[-1.7730e-02]],

         [[ 9.7807e-02]],

         [[ 4.8814e-02]],

         [[-5.6123e-02]],

         [[ 1.8888e-02]],

         [[-4.8763e-02]],

         [[-1.9005e-02]],

         [[-4.5759e-02]],

         [[ 5.0913e-02]],

         [[-9.6862e-03]],

         [[ 1.2879e-01]],

         [[ 5.0088e-03]],

         [[-1.0172e-02]],

         [[-7.6081e-03]],

         [[-7.6202e-02]],

         [[ 3.6588e-02]],

         [[-1.3640e-01]],

         [[-2.4451e-02]],

         [[ 4.3620e-03]],

         [[ 3.2869e-04]],

         [[ 3.9907e-02]],

         [[ 6.4439e-02]],

         [[-1.2589e-01]],

         [[ 9.6267e-02]],

         [[ 4.3577e-02]],

         [[ 1.6063e-02]],

         [[-7.7793e-02]],

         [[-1.9218e-02]],

         [[ 6.5340e-02]],

         [[-6.3679e-02]],

         [[-3.3674e-02]],

         [[-8.9329e-02]],

         [[-5.2894e-02]],

         [[-1.1498e-02]],

         [[ 4.2859e-02]],

         [[-1.6304e+00]],

         [[ 1.2883e-01]],

         [[ 9.3691e-02]],

         [[ 1.3553e-02]],

         [[ 5.4265e-02]],

         [[-2.7910e-02]],

         [[ 4.2478e-03]],

         [[-1.2281e-01]]],


        [[[-1.0707e+00]],

         [[-2.6321e-02]],

         [[ 2.9850e-02]],

         [[ 2.7604e-02]],

         [[ 2.9011e-02]],

         [[ 4.1728e-02]],

         [[ 5.1648e-02]],

         [[ 9.0549e-02]],

         [[ 3.9145e-02]],

         [[-4.0699e-02]],

         [[-2.3407e-02]],

         [[ 3.7422e-02]],

         [[-1.1937e-01]],

         [[-1.2759e-02]],

         [[-6.7126e-02]],

         [[-3.7855e-02]],

         [[-2.0908e-02]],

         [[ 4.5304e-02]],

         [[ 1.1141e-01]],

         [[-6.9791e-02]],

         [[ 1.7829e-02]],

         [[-1.2003e-02]],

         [[-2.9423e-02]],

         [[ 1.9270e+00]],

         [[ 8.2600e-02]],

         [[-1.7376e-02]],

         [[ 9.7004e-03]],

         [[ 4.6126e-01]],

         [[ 9.0155e-02]],

         [[ 8.9232e-01]],

         [[-1.7061e+00]],

         [[-1.1616e-01]],

         [[-1.3741e-01]],

         [[-2.3761e-02]],

         [[ 1.4548e+00]],

         [[ 1.2688e-01]],

         [[ 1.0433e-01]],

         [[-6.4277e-03]],

         [[ 1.0425e-01]],

         [[ 5.7487e-02]],

         [[-1.0029e-01]],

         [[ 8.4045e-03]],

         [[-2.4089e-02]],

         [[ 2.4195e-03]],

         [[-5.4022e-02]],

         [[ 2.5963e-02]],

         [[ 6.5183e-03]],

         [[ 1.1286e-01]],

         [[ 3.5302e-03]],

         [[-7.4671e-03]],

         [[ 3.8748e-02]],

         [[-6.9504e-02]],

         [[ 5.2977e-02]],

         [[-7.4818e-02]],

         [[-1.6362e-02]],

         [[-1.5918e-02]],

         [[-1.7679e-02]],

         [[ 1.9427e-02]],

         [[ 2.9266e-02]],

         [[-1.3562e-01]],

         [[ 1.1950e-01]],

         [[ 2.0752e-02]],

         [[-8.6824e-03]],

         [[-7.9381e-02]],

         [[ 1.6962e-02]],

         [[ 4.0761e-02]],

         [[-1.0492e-01]],

         [[ 1.2135e-02]],

         [[-1.0679e-01]],

         [[-2.0475e-02]],

         [[ 9.3614e-03]],

         [[ 2.9920e-02]],

         [[-1.0729e+00]],

         [[ 3.8340e-01]],

         [[ 9.2552e-02]],

         [[-8.3185e-03]],

         [[ 8.0000e-02]],

         [[-3.9019e-02]],

         [[ 3.6674e-02]],

         [[-1.2075e-01]]],


        [[[-7.5072e-01]],

         [[-5.4823e-02]],

         [[-1.4970e-02]],

         [[-4.0144e-02]],

         [[ 2.8744e-02]],

         [[ 6.5376e-02]],

         [[-1.8614e-02]],

         [[ 8.0005e-02]],

         [[ 2.7076e-02]],

         [[-3.8657e-02]],

         [[-3.1796e-02]],

         [[ 9.8985e-02]],

         [[-9.3073e-02]],

         [[-2.0905e-02]],

         [[-2.9689e-02]],

         [[-5.3493e-02]],

         [[ 6.2237e-03]],

         [[ 4.1819e-02]],

         [[ 6.4987e-02]],

         [[-3.8220e-02]],

         [[ 2.4581e-02]],

         [[-4.1277e-02]],

         [[-4.7817e-02]],

         [[ 1.7738e+00]],

         [[ 1.1574e-01]],

         [[-4.5813e-02]],

         [[-1.2059e-02]],

         [[ 8.7244e-01]],

         [[ 7.3906e-02]],

         [[-8.7500e-02]],

         [[-1.7406e+00]],

         [[-1.3740e-01]],

         [[-5.6365e-02]],

         [[-1.3872e-02]],

         [[ 7.0461e-01]],

         [[ 8.7310e-02]],

         [[ 1.0459e-01]],

         [[ 3.5917e-02]],

         [[ 1.4522e-01]],

         [[ 6.7422e-02]],

         [[-7.9781e-02]],

         [[ 4.1630e-02]],

         [[-4.6471e-02]],

         [[ 3.5472e-02]],

         [[-2.9973e-02]],

         [[ 4.0180e-02]],

         [[-1.2047e-02]],

         [[ 9.5932e-02]],

         [[-4.5369e-02]],

         [[-2.5814e-02]],

         [[-1.9315e-02]],

         [[-9.8329e-02]],

         [[ 5.8369e-02]],

         [[-1.1962e-01]],

         [[-3.5700e-02]],

         [[-1.5899e-02]],

         [[-4.3251e-02]],

         [[ 5.6479e-02]],

         [[ 1.4259e-02]],

         [[-1.2109e-01]],

         [[ 8.3271e-02]],

         [[ 1.3155e-02]],

         [[ 1.8504e-02]],

         [[-7.3442e-02]],

         [[-5.6152e-03]],

         [[ 9.5700e-02]],

         [[-5.4335e-02]],

         [[ 1.3616e-02]],

         [[-2.7832e-02]],

         [[-1.8983e-02]],

         [[-4.4270e-02]],

         [[ 6.5219e-02]],

         [[-8.8316e-01]],

         [[ 1.9295e-01]],

         [[ 8.8603e-02]],

         [[ 6.2901e-02]],

         [[ 2.3009e-02]],

         [[-4.2782e-02]],

         [[-8.5840e-03]],

         [[-1.0142e-01]]],


        [[[-1.1146e+00]],

         [[-5.8133e-02]],

         [[ 2.0659e-02]],

         [[ 6.3766e-03]],

         [[ 5.1625e-03]],

         [[ 3.6851e-02]],

         [[-2.3279e-03]],

         [[ 1.3387e-01]],

         [[ 4.4322e-02]],

         [[-5.4313e-02]],

         [[-2.7488e-02]],

         [[ 7.1881e-02]],

         [[-1.2224e-01]],

         [[-7.6094e-03]],

         [[-1.0609e-01]],

         [[-8.5367e-02]],

         [[ 1.8635e-02]],

         [[ 7.8259e-03]],

         [[ 6.2709e-02]],

         [[-7.1854e-02]],

         [[ 4.2441e-02]],

         [[-1.1480e-02]],

         [[-8.7642e-02]],

         [[ 2.1765e+00]],

         [[ 9.6011e-02]],

         [[ 2.2086e-02]],

         [[ 4.2160e-02]],

         [[ 7.9041e-01]],

         [[ 1.1993e-01]],

         [[ 3.3617e-01]],

         [[-1.4075e+00]],

         [[-3.3878e-01]],

         [[-1.1112e-01]],

         [[ 2.0954e-03]],

         [[ 1.2112e+00]],

         [[ 1.3614e-01]],

         [[ 7.8139e-02]],

         [[ 1.0703e-02]],

         [[ 1.2752e-01]],

         [[ 5.2833e-02]],

         [[-6.0520e-02]],

         [[-3.2265e-02]],

         [[ 2.8184e-03]],

         [[-7.7502e-03]],

         [[-6.3597e-02]],

         [[ 5.1120e-02]],

         [[-1.7001e-02]],

         [[ 1.1653e-01]],

         [[-2.6885e-03]],

         [[-5.5297e-03]],

         [[ 3.3320e-03]],

         [[-7.6868e-02]],

         [[ 2.0787e-02]],

         [[-8.7813e-02]],

         [[-2.9039e-02]],

         [[-3.4621e-02]],

         [[-8.7955e-03]],

         [[ 2.3810e-02]],

         [[ 2.1641e-02]],

         [[-1.5007e-01]],

         [[ 1.2207e-01]],

         [[ 3.0461e-02]],

         [[ 8.2840e-03]],

         [[-1.0523e-01]],

         [[-1.2943e-02]],

         [[ 3.0175e-02]],

         [[-8.0922e-02]],

         [[-2.3520e-02]],

         [[-7.0612e-02]],

         [[-2.1879e-02]],

         [[ 2.4740e-02]],

         [[ 5.8875e-02]],

         [[-7.6240e-01]],

         [[-2.0177e-01]],

         [[ 9.4671e-02]],

         [[ 1.7415e-02]],

         [[ 6.4323e-02]],

         [[-9.2020e-02]],

         [[ 1.1823e-02]],

         [[-1.1909e-01]]],


        [[[-1.2166e+00]],

         [[-7.7967e-02]],

         [[ 2.5106e-02]],

         [[ 2.0526e-03]],

         [[ 3.4703e-03]],

         [[ 1.4131e-02]],

         [[-3.8334e-03]],

         [[ 6.1465e-02]],

         [[ 4.2720e-02]],

         [[-1.4046e-02]],

         [[-4.7764e-02]],

         [[ 8.4652e-02]],

         [[-1.1927e-01]],

         [[-4.6535e-02]],

         [[-6.8880e-02]],

         [[-7.3123e-02]],

         [[ 1.9078e-02]],

         [[ 8.9287e-03]],

         [[ 6.5274e-02]],

         [[-7.7005e-02]],

         [[ 1.9824e-02]],

         [[-5.4538e-03]],

         [[-6.4757e-02]],

         [[ 2.3328e+00]],

         [[ 9.6518e-02]],

         [[-1.4625e-02]],

         [[ 2.8340e-02]],

         [[ 9.7249e-01]],

         [[ 1.1618e-01]],

         [[ 3.4606e-01]],

         [[-1.2764e+00]],

         [[-6.0303e-01]],

         [[-8.0727e-02]],

         [[-3.5762e-03]],

         [[ 1.0174e+00]],

         [[ 1.2546e-01]],

         [[ 9.4252e-02]],

         [[ 2.3340e-02]],

         [[ 1.4983e-01]],

         [[ 6.2312e-02]],

         [[-4.7162e-02]],

         [[-1.6018e-02]],

         [[-4.1424e-02]],

         [[ 2.5610e-02]],

         [[-6.5602e-02]],

         [[ 4.1556e-02]],

         [[-2.6147e-02]],

         [[ 7.8778e-02]],

         [[-4.0165e-03]],

         [[ 2.2457e-03]],

         [[-7.5596e-03]],

         [[-8.1376e-02]],

         [[ 7.2851e-03]],

         [[-1.0132e-01]],

         [[-3.9784e-02]],

         [[-2.1593e-02]],

         [[-1.5245e-02]],

         [[ 6.6059e-02]],

         [[ 2.0387e-02]],

         [[-1.3475e-01]],

         [[ 8.8745e-02]],

         [[ 6.8566e-03]],

         [[-3.4470e-02]],

         [[-1.0417e-01]],

         [[-4.8426e-02]],

         [[ 5.1008e-02]],

         [[-7.8170e-02]],

         [[-5.0224e-02]],

         [[-8.4004e-02]],

         [[-2.2815e-02]],

         [[ 3.6521e-02]],

         [[ 4.8454e-02]],

         [[-7.2690e-01]],

         [[ 4.5278e-01]],

         [[ 7.4350e-02]],

         [[ 3.1784e-02]],

         [[ 5.7295e-02]],

         [[-6.1810e-02]],

         [[ 2.7856e-02]],

         [[-1.4902e-01]]]], grad_fn=<AsStridedBackward1>)})

Other functionalities

In addition to recording, the following functionalities are available and documented in Recorder:

  • On-the-fly postprocessing of activations during inference (e.g. clipping).

  • Locally disabling recording during the forward pass.

  • Direct access to the recorded layers’ Module objects.

  • Retrieval of the named parameters of the recorded layers.
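The Recorder-specific calls for these features are documented in Recorder and are not reproduced here. As a rough, plain-PyTorch illustration of the first two ideas, the sketch below uses Module.register_forward_hook to clip a layer's activations on the fly, plus a context manager that switches recording off locally. The ClipAndRecord class and its clip, enabled and paused names are hypothetical and are not part of scio's API. The same sketch also shows how a hooked layer's Module object and its named parameters are reachable with standard PyTorch calls.

```python
from contextlib import contextmanager

import torch
from torch import nn


class ClipAndRecord:
    """Hypothetical helper, not scio's Recorder: records and clips one layer's output."""

    def __init__(self, layer: nn.Module, clip: float) -> None:
        self.clip = clip
        self.enabled = True
        self.recorded = []  # list of recorded (clipped) activations
        # Keep the handle so that the hook can be removed later
        self.handle = layer.register_forward_hook(self._hook)

    def _hook(self, module, args, output):
        if not self.enabled:
            return None  # returning None leaves the layer's output untouched
        clipped = output.clamp(-self.clip, self.clip)  # on-the-fly postprocessing
        self.recorded.append(clipped.detach())
        return clipped  # a non-None return value replaces the layer's output

    @contextmanager
    def paused(self):
        """Locally disable recording (and clipping) inside a ``with`` block."""
        previous, self.enabled = self.enabled, False
        try:
            yield
        finally:
            self.enabled = previous


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
rec = ClipAndRecord(net[0], clip=0.1)

x = torch.randn(3, 4)
net(x)  # net[0]'s output is recorded and clipped to [-0.1, 0.1]
with rec.paused():
    net(x)  # neither recorded nor clipped
print(len(rec.recorded))  # 1

# Module object and named parameters of the hooked layer, in plain PyTorch
params = dict(net[0].named_parameters(prefix="0"))
print(sorted(params))  # ['0.bias', '0.weight']

rec.handle.remove()  # detach the hook when done
```

Because a forward hook that returns a tensor replaces the layer's output, the clipping also propagates to the downstream layers. This is the behaviour one would expect from postprocessing "during inference", as opposed to editing recorded copies after the fact.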

Implementation details

This utility works by attaching hooks to every layer of net with Module.register_forward_hook. A layer is not aware of the context in which it is called, so each hook carries a reference to rnet, and with it enough context to know when to trigger. As a result, two different Recorder Nets can wrap the same net without any conflict. As an implementation detail, these references are weak. The hooks therefore do not keep rnet alive, and everything is properly cleaned up when rnet is deleted.
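The mechanism described above can be sketched in a few lines of plain PyTorch. This is a simplified illustration under our own assumptions, not scio's actual implementation. The hypothetical MiniRecorder class:

- registers one forward hook per submodule;
- gives each hook only a weakref.ref to its recorder, so the hook fires only while that particular recorder is running;
- uses weakref.finalize to remove the hooks once the recorder is garbage-collected.

```python
import weakref

import torch
from torch import nn


def _make_hook(recorder_ref, name):
    # The hook holds only a weak reference, so it never keeps the recorder alive
    def hook(module, args, output):
        recorder = recorder_ref()
        # Trigger only when *this* recorder is the one running the forward pass
        if recorder is not None and recorder.running:
            recorder.activations[name] = output.detach()
    return hook


def _remove_all(handles):
    for handle in handles:
        handle.remove()


class MiniRecorder:
    """Hypothetical, simplified stand-in for a Recorder-like wrapper."""

    def __init__(self, net: nn.Module) -> None:
        self.net = net
        self.running = False
        self.activations = {}
        ref = weakref.ref(self)
        handles = [module.register_forward_hook(_make_hook(ref, name))
                   for name, module in net.named_modules() if name]  # skip the root
        # Once this recorder is garbage-collected, remove its hooks from net
        weakref.finalize(self, _remove_all, handles)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        self.activations = {}
        self.running = True
        try:
            return self.net(x)
        finally:
            self.running = False


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
rnet_a, rnet_b = MiniRecorder(net), MiniRecorder(net)  # same net, two wrappers

x = torch.randn(3, 4)
rnet_a(x)
recorded_keys = sorted(rnet_a.activations)
print(recorded_keys)       # ['0', '1', '2']
print(rnet_b.activations)  # {} since rnet_b was not running

del rnet_a  # the finalizer removes rnet_a's hooks; rnet_b's hooks remain
net(x)      # plain call: no recorder is running, nothing is recorded
```

Only rnet_b's hooks remain attached after rnet_a is deleted. Since the finalizer callback references the handles but never the recorder itself, it does not prevent the recorder from being collected.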
