An example of computing gradients of non-leaf variables in PyTorch

Time:2021-3-9

In PyTorch, gradients are normally kept only for leaf nodes, that is, nodes D and E in the figure below. For non-leaf nodes, that is, nodes C and B, the gradient computed during the backward pass is not explicitly kept (because, in general, only leaf nodes need to be updated), which saves a large amount of GPU memory. During debugging, however, we sometimes need to monitor the gradients of intermediate variables to make sure the network is working as intended, and for that we need to print the gradients of non-leaf nodes. There are two ways to do this.
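The leaf/non-leaf distinction can be checked directly. The following is a minimal sketch that reuses the x, y, z computation from the examples in this article (it does not reproduce the nodes of the figure): after backward(), the leaf x holds a gradient while the non-leaf y does not. Reading .grad on a non-leaf tensor returns None, and recent PyTorch versions also emit a UserWarning, which is silenced here.

```python
import warnings

import torch

# x is created directly by the user, so it is a leaf; y and z come from operations, so they are not
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()

print(x.is_leaf, y.is_leaf)  # True False

# The leaf gradient is stored after backward()
print(x.grad)

# The gradient with respect to y was computed during backward() but not kept
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # accessing .grad of a non-leaf tensor may warn
    print(y.grad)  # None
```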

Registering a hook function

Tensor.register_hook [2] registers a hook function for the backward pass: the hook is called every time a gradient with respect to that tensor is computed. It is often used to print the gradient of a non-leaf node while debugging. With this method you can also customize how the gradient of a particular layer is modified [3]. For printing the gradient of a non-leaf node, the code is as follows:


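import torch

def hook_y(grad):
    # Called during backward() with d(out)/dy; print it instead of storing it
    print(grad)

# Variable is deprecated; requires_grad=True on the tensor is the modern equivalent
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y)

out = z.mean()
out.backward()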

The output is as follows:


tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])
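As mentioned above, a hook can also change the gradient instead of only observing it: if the hook returns a tensor, autograd uses that value in place of the original gradient for the rest of the backward pass. The sketch below is illustrative only; the function name double_grad and the factor of 2 are arbitrary choices for demonstration. Doubling the gradient at y doubles the gradient that reaches the leaf x, from 4.5 to 9.0. register_hook returns a handle whose remove() method detaches the hook again.

```python
import torch

def double_grad(grad):
    # Hypothetical hook: returning a tensor replaces the gradient flowing back through y
    return grad * 2

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3

handle = y.register_hook(double_grad)
out = z.mean()
out.backward(retain_graph=True)  # keep the graph so it can be reused below
doubled = x.grad.clone()
print(doubled)  # every entry is 9.0 instead of 4.5

# Detach the hook, clear the accumulated leaf gradient, and run backward again
handle.remove()
x.grad = None
out.backward()
print(x.grad)  # every entry is back to 4.5
```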

retain_grad()

Tensor.retain_grad() explicitly keeps the gradient of a non-leaf node, at the cost of extra GPU memory. A hook function, by contrast, prints the gradient directly during the backward pass without storing it, so it adds no memory overhead. Still, retain_grad() is more convenient to use than a hook. The code is as follows:


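import torch

# Variable is deprecated; requires_grad=True on the tensor is the modern equivalent
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.retain_grad()  # keep y.grad after backward(); must be called before backward()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)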

The output is as follows:


tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])
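Both methods print 4.5 for every entry, and the value can be checked by hand. With x all ones, each entry of y is y_ij = x_ij + 2 = 3, and out averages the four entries of z = 3y^2:

```latex
\text{out} = \frac{1}{4}\sum_{i,j} 3\,y_{ij}^2,
\qquad
\frac{\partial\,\text{out}}{\partial y_{ij}} = \frac{6\,y_{ij}}{4} = \frac{3}{2}\,y_{ij} = \frac{3}{2}\cdot 3 = 4.5
```

Since each y_ij depends on x_ij with derivative 1, the leaf gradient x.grad has the same value, 4.5, in every entry.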

