torch.autograd.backward
torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)
- retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and can often be worked around in a much more efficient way. Defaults to the value of create_graph.
- create_graph (bool, optional) – If True, a graph of the derivative will be constructed, allowing higher-order derivative products to be computed. Defaults to False.
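To make these parameters concrete, here is a minimal sketch (the values are illustrative, not from the original post): create_graph=True makes the computed gradient itself differentiable, which enables higher-order derivatives, and grad_tensors supplies the vector in the vector-Jacobian product when the output is not a scalar.
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3

# create_graph=True keeps the derivative itself differentiable,
# which is what allows a second differentiation.
(g,) = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 3*x**2 = 12
(g2,) = torch.autograd.grad(g, x)                     # d2y/dx2 = 6*x = 12

# grad_tensors plays the role of the vector in the vector-Jacobian
# product for non-scalar outputs.
z = torch.ones(3, requires_grad=True)
torch.autograd.backward(z * 2, grad_tensors=torch.ones(3))
print(g.item(), g2.item(), z.grad)                    # 12.0 12.0 tensor([2., 2., 2.])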
retain_graph=True (when to use it?)
retain_graph is a parameter we rarely need in everyday training, but it is required in a few special cases:
1. when one network has two outputs that each need their own backward pass: output1.backward(), output2.backward()
2. when one network has two losses that each need their own backward pass: loss1.backward(), loss2.backward()
Take case 2 as an example. If the code is written like this, the error mentioned at the beginning of the post will appear:
loss1.backward()
loss2.backward()
Correct code:
loss1.backward(retain_graph=True) # keep the intermediate graph buffers after backward
loss2.backward() # the graph is freed here, ready for the next iteration
optimizer.step() # update parameters
With retain_graph=True the intermediate buffers are kept alive, so the backward() calls of the two losses do not interfere with each other.
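As a self-contained illustration of case 2 (the tensors and losses below are made-up toy values): both losses depend on the same intermediate result h, so the first backward call must keep the graph alive.
import torch

w = torch.tensor(1.0, requires_grad=True)
x = torch.tensor(2.0)

h = w * x                          # shared intermediate result
loss1 = h ** 2
loss2 = (h - 1) ** 2               # both losses go through the same graph

loss1.backward(retain_graph=True)  # keep the graph buffers for the second pass
loss2.backward()                   # reuses the graph, then frees it
print(w.grad)                      # gradients accumulate: 8 + 4 = 12
Dropping retain_graph=True from the first call makes the second backward() fail with a RuntimeError along the lines of "Trying to backward through the graph a second time".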
Supplement: when the two losses belong to two different networks and each needs its own backward pass: loss1.backward(), loss2.backward()
# With two networks, define a separate optimizer for each network
optimizer1 = torch.optim.SGD(net1.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
.....
# train: the part of the loop that backpropagates the losses
loss1 = criterion(net1(inputs), labels) # placeholder: net1's loss
loss2 = criterion(net2(inputs), labels) # placeholder: net2's loss
optimizer1.zero_grad() # zero net1's gradients
loss1.backward(retain_graph=True) # keep the intermediate buffers after backward
optimizer1.step()
optimizer2.zero_grad() # zero net2's gradients
loss2.backward()
optimizer2.step()
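As the documentation quoted at the top notes, retain_graph=True can often be worked around in a more efficient way; when the two updates do not depend on each other's gradients, summing the losses and backpropagating once is a common alternative (a sketch reusing the names from the snippet above):
optimizer1.zero_grad()
optimizer2.zero_grad()
total_loss = loss1 + loss2 # one combined backward pass over the shared graph
total_loss.backward()      # the graph is freed normally, no retain_graph needed
optimizer1.step()
optimizer2.step()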
Appendix: step-by-step explanation
optimizer.zero_grad()
Initialize the gradients to zero
(because the gradient of a batch's loss with respect to the weights is the sum of each sample's loss gradient with respect to the weights, and .grad accumulates across backward calls)
corresponding to d_weights = [0] * n
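The accumulation behavior is easy to verify; without zero_grad(), successive backward passes add into .grad (a minimal sketch with made-up values):
import torch

w = torch.tensor(1.0, requires_grad=True)
(w * 2).backward()
(w * 3).backward()   # separate graphs, so retain_graph is not needed here
print(w.grad)        # tensor(5.) -- gradients accumulate until zero_grad()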
output = net(inputs)
The predicted value is obtained by forward propagation
loss = Loss(output, labels)
Compute the loss
loss.backward()
Backpropagate to compute the gradients
corresponding to d_weights = [d_weights[j] + (label[k] - output) * input[k][j] for j in range(n)]
optimizer.step()
Update all parameters
corresponding to weights = [weights[k] + alpha * d_weights[k] for k in range(n)]
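Putting the four steps together, here is a pure-Python sketch of the same loop (the data, model, and learning rate are made up for illustration: a linear model with a squared-error loss whose gradient matches the d_weights formula above):
inputs = [[1.0, 2.0], [2.0, 1.0]]   # made-up training samples
labels = [5.0, 4.0]                 # targets satisfied exactly by weights [1, 2]
weights = [0.0, 0.0]                # model parameters
alpha = 0.1                         # learning rate
n = len(weights)

for epoch in range(100):
    d_weights = [0.0] * n           # optimizer.zero_grad()
    for k in range(len(inputs)):
        # forward pass: output = net(inputs)
        output = sum(weights[j] * inputs[k][j] for j in range(n))
        # loss.backward(): accumulate the squared-error gradient per sample
        d_weights = [d_weights[j] + (labels[k] - output) * inputs[k][j]
                     for j in range(n)]
    # optimizer.step(): update all parameters
    weights = [weights[j] + alpha * d_weights[j] for j in range(n)]

print(weights)                      # converges toward [1.0, 2.0]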