def _alsqplus_channel_last(t):
    """Flatten ``t`` to ``(t.size(0), -1)`` and transpose it to ``(-1, t.size(0))``.

    Moving dim 0 (the per-channel axis used by ``ALSQPlus``) to the last
    position lets a 1-D per-channel ``alpha``/``beta`` broadcast along it.
    """
    return torch.transpose(t.contiguous().view(t.size(0), -1), 0, 1)


def _alsqplus_reduce_like(grad, param, per_channel):
    """Sum the elementwise gradient ``grad`` down to the shape of ``param``.

    With ``per_channel`` and a multi-element ``param``, sum over every dim
    except dim 0 (one value per channel). Otherwise sum everything into a
    single value. The result is reshaped to ``param.shape`` so the returned
    gradient matches the corresponding ``forward`` input.
    """
    if per_channel and param.numel() > 1:
        reduced = grad.contiguous().view(grad.size(0), -1).sum(dim=1)
    else:
        reduced = grad.sum()
    return reduced.reshape(param.shape)


class ALSQPlus(Function):
    """Fake-quantize ``weight`` with a learnable scale ``alpha`` and offset ``beta``.

    Forward computes ``clamp(Round((weight - beta) / alpha), Qn, Qp) * alpha + beta``.
    Backward is straight-through style:

    * the weight gradient passes only where ``(weight - beta) / alpha`` lies
      inside ``[Qn, Qp]``;
    * the ``alpha`` and ``beta`` gradients are multiplied by ``g``.

    Call as ``ALSQPlus.apply(weight, alpha, g, Qn, Qp, per_channel, beta)``.

    NOTE(review): ``Round`` is defined elsewhere in this file. It is presumably
    a rounding op with a straight-through gradient; confirm against its definition.
    """

    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp, per_channel, beta):
        """Quantize-dequantize ``weight``.

        Args:
            weight: tensor to fake-quantize.
            alpha: scale tensor. When ``per_channel`` is true it must
                broadcast against the ``(-1, weight.size(0))`` channel-last
                view, e.g. one value per ``weight.size(0)`` slice.
            g: multiplier applied to the ``alpha``/``beta`` gradients in
                ``backward``.
            Qn: lower clamp bound of the rounded values.
            Qp: upper clamp bound of the rounded values.
            per_channel: if true, ``alpha``/``beta`` are applied per slice
                along dim 0 of ``weight``.
            beta: offset tensor, broadcast the same way as ``alpha``.

        Returns:
            Tensor with the same shape as ``weight``.
        """
        ctx.save_for_backward(weight, alpha, beta)
        ctx.other = g, Qn, Qp, per_channel
        if per_channel:
            sizes = weight.size()
            w = _alsqplus_channel_last(weight)
            alpha_b = torch.broadcast_to(alpha, w.size())
            beta_b = torch.broadcast_to(beta, w.size())
            w_q = Round.apply(torch.div(w - beta_b, alpha_b)).clamp(Qn, Qp)
            w_q = w_q * alpha_b + beta_b
            # Undo the channel-last transpose and restore the original shape.
            return torch.transpose(w_q, 0, 1).contiguous().view(sizes)
        w_q = Round.apply(torch.div(weight - beta, alpha)).clamp(Qn, Qp)
        return w_q * alpha + beta

    @staticmethod
    def backward(ctx, grad_weight):
        """Return one gradient per ``forward`` input, in ``forward``'s order.

        The order is ``(weight, alpha, g, Qn, Qp, per_channel, beta)``.
        ``g``, ``Qn``, ``Qp`` and ``per_channel`` are non-differentiable
        hyper-parameters, so their slots are ``None``.
        """
        weight, alpha, beta = ctx.saved_tensors
        g, Qn, Qp, per_channel = ctx.other
        if per_channel:
            sizes = weight.size()
            w = _alsqplus_channel_last(weight)
            alpha_b = torch.broadcast_to(alpha, w.size())
            # beta broadcasts along the last (channel) axis of the channel-last view.
            q_w = torch.transpose((w - beta) / alpha_b, 0, 1).contiguous().view(sizes)
        else:
            q_w = (weight - beta) / alpha

        # 0/1 masks in the input's dtype: below Qn, above Qp, and inside [Qn, Qp].
        smaller = (q_w < Qn).to(q_w.dtype)
        bigger = (q_w > Qp).to(q_w.dtype)
        between = 1.0 - smaller - bigger

        # Out of range, d(out)/d(alpha) is the clamp bound. In range it is
        # Round(q) - q, treating Round as identity for the gradient.
        grad_alpha = (smaller * Qn + bigger * Qp
                      + between * (Round.apply(q_w) - q_w)) * grad_weight * g
        # Out of range, the output is Qn*alpha + beta or Qp*alpha + beta, so
        # d(out)/d(beta) = 1. In range, beta cancels, giving 0.
        grad_beta = (smaller + bigger) * grad_weight * g

        grad_alpha = _alsqplus_reduce_like(grad_alpha, alpha, per_channel)
        grad_beta = _alsqplus_reduce_like(grad_beta, beta, per_channel)

        # beta is the 7th forward input, so its gradient goes last. Returning
        # it in slot 3 (the non-tensor ``g``) raises
        # "returned a gradient different than None at position 3".
        return between * grad_weight, grad_alpha, None, None, None, None, grad_beta
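Running this code raises:

`RuntimeError: function ALSQPlusBackward returned a gradient different than None at position 3, but the corresponding forward input was not a Variable`

Cause: `backward` must return exactly one gradient per `forward` input (not counting `ctx`), in the same order as `forward`'s parameters: `weight, alpha, g, Qn, Qp, per_channel, beta`. Position 3 is `g`, which is not a tensor, yet the code returns `grad_beta` there; `beta` is actually the 7th input.

Fix: change the last line of `backward` to `return grad_weight, grad_alpha, None, None, None, None, grad_beta`.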
Read More:
- [Solved] pytorch loss.backward() Error: RuntimeError: Function AddBackward0 returned an invalid gradient at index 1…
- RuntimeError: stack expects each tensor to be equal size, but got [x] at entry 0 and [x] at entry 1
- Pytorch torch.cuda.FloatTensor Error: RuntimeError: one of the variables needed for gradient computation has…
- [Solved] Pytorch error: RuntimeError: one of the variables needed for gradient computation
- [Solved] RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
- [Solved] PyTorch Error: RuntimeError: one of the variables needed for gradient computation has been modified by an
- [Solved] RuntimeError (note: full exception trace is shown but execution is paused at: <module>)
- How to Solve RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu
- Using Postman to Test a Django Interface: RuntimeError: You called this URL via POST, but the URL doesn't end in a slash
- [Solved] RuntimeError: cublas runtime error : resource allocation failed at
- [Solved] RuntimeError: expected scalar type Long but found Float
- [Solved] Pytorch Error: RuntimeError: expected scalar type Double but found Float
- [Solved] RuntimeError: cuda runtime error (801) : operation not supported at
- [Solved] SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 10-11: malformed
- [Solved] SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 6-7: malformed
- [Solved] RuntimeError: cuda runtime error (100) : no CUDA-capable device is detected at
- SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UX
- SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \UX
- [Solved] bushi RuntimeError: version_ <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /pytorch/caffe2/s
- Python Error: SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: