import torch
from torch.autograd import Function


class ALSQPlus(Function):
    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp, per_channel, beta):
        # assert alpha > 0, "alpha={}".format(alpha)
        ctx.save_for_backward(weight, alpha, beta)
        ctx.other = g, Qn, Qp, per_channel
        if per_channel:
            sizes = weight.size()
            weight = weight.contiguous().view(weight.size()[0], -1)
            weight = torch.transpose(weight, 0, 1)
            alpha = torch.broadcast_to(alpha, weight.size())
            beta = torch.broadcast_to(beta, weight.size())
            w_q = Round.apply(torch.div((weight - beta), alpha)).clamp(Qn, Qp)
            w_q = w_q * alpha + beta
            w_q = torch.transpose(w_q, 0, 1)
            w_q = w_q.contiguous().view(sizes)
        else:
            w_q = Round.apply(torch.div((weight - beta), alpha)).clamp(Qn, Qp)
            w_q = w_q * alpha + beta
        return w_q
    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha, beta = ctx.saved_tensors
        g, Qn, Qp, per_channel = ctx.other
        if per_channel:
            sizes = weight.size()
            weight = weight.contiguous().view(weight.size()[0], -1)
            weight = torch.transpose(weight, 0, 1)
            alpha = torch.broadcast_to(alpha, weight.size())
            q_w = (weight - beta) / alpha
            q_w = torch.transpose(q_w, 0, 1)
            q_w = q_w.contiguous().view(sizes)
        else:
            q_w = (weight - beta) / alpha
        smaller = (q_w < Qn).float()      # 1.0 where the value falls below the quantization range
        bigger = (q_w > Qp).float()       # 1.0 where the value falls above the quantization range
        between = 1.0 - smaller - bigger  # 1.0 where the value lies inside the quantization interval
        if per_channel:
            grad_alpha = ((smaller * Qn + bigger * Qp +
                           between * Round.apply(q_w) - between * q_w) * grad_weight * g)
            grad_alpha = grad_alpha.contiguous().view(grad_alpha.size()[0], -1).sum(dim=1)
            grad_beta = ((smaller + bigger) * grad_weight * g).sum().unsqueeze(dim=0)
            grad_beta = grad_beta.contiguous().view(grad_beta.size()[0], -1).sum(dim=1)
        else:
            grad_alpha = ((smaller * Qn + bigger * Qp +
                           between * Round.apply(q_w) - between * q_w) * grad_weight * g).sum().unsqueeze(dim=0)
            grad_beta = ((smaller + bigger) * grad_weight * g).sum().unsqueeze(dim=0)
        grad_weight = between * grad_weight
        # The returned gradients must correspond, one to one, to the forward arguments
        return grad_weight, grad_alpha, grad_beta, None, None, None, None
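The snippet relies on a Round helper that is not shown. A minimal straight-through-estimator version (a sketch of what it is assumed to do, defined before ALSQPlus) would be:

class Round(Function):
    # Straight-through estimator: round to the nearest integer in forward,
    # pass the incoming gradient through unchanged in backward.
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output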
Running this code raises:

RuntimeError: function ALSQPlusBackward returned a gradient different than None at position 3, but the corresponding forward input was not a Variable

The gradients returned by backward must follow the same order as the parameters of forward. Here forward takes (weight, alpha, g, Qn, Qp, per_channel, beta), so the third returned gradient is matched against g, which is a plain Python number rather than a tensor; returning grad_beta in that slot triggers the error. The fix is to return None for the non-tensor arguments g, Qn, Qp and per_channel, and to move grad_beta to the last position so it lines up with beta. Change the final line of backward to:

return grad_weight, grad_alpha, None, None, None, None, grad_beta
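With that change in place, a quick per-layer (per_channel=False) sanity check confirms that gradients flow to weight, alpha and beta without the error. The shapes, the Qn/Qp range and the value of g below are illustrative assumptions, not values from the original post:

weight = torch.randn(8, 4, requires_grad=True)
alpha = torch.tensor([0.1], requires_grad=True)
beta = torch.tensor([0.0], requires_grad=True)
Qn, Qp = -8, 7                               # e.g. a signed 4-bit range
g = 1.0 / (weight.numel() * Qp) ** 0.5       # LSQ-style gradient scale

w_q = ALSQPlus.apply(weight, alpha, g, Qn, Qp, False, beta)
w_q.sum().backward()   # no RuntimeError once the return order matches forward
print(weight.grad.shape, alpha.grad, beta.grad)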