[Solved] RuntimeError: function ALSQPlusBackward returned a gradient different than None at position 3, but the corresponding forward input was not a Variable

class ALSQPlus(Function):
    """Custom autograd Function that quantizes `weight` with a scale `alpha` and an offset `beta`.

    forward computes clamp(Round((weight - beta) / alpha), Qn, Qp) * alpha + beta;
    backward returns gradients for weight, alpha and beta.

    NOTE(review): as listed, this class reproduces the RuntimeError quoted after
    the code; the cause is the order of the values returned by backward.
    NOTE(review): only the per_channel branch is present in this listing. With a
    falsy per_channel, w_q (forward) and q_w / grad_alpha / grad_beta (backward)
    are never assigned.
    NOTE(review): forward/backward carry no @staticmethod decorator here, although
    both take ctx as first argument -- confirm against the PyTorch version in use.
    NOTE(review): Round is not defined in this listing; presumably a rounding
    autograd Function defined elsewhere -- verify.
    """

    def forward(ctx, weight, alpha, g, Qn, Qp, per_channel, beta):
        """Return the quantize-dequantized weight, in the original shape of `weight`.

        Inputs, in order: weight, alpha, g, Qn, Qp, per_channel, beta.
        backward must return one value per input in this same order.
        """
        # assert alpha > 0, "alpha={}".format(alpha)
        # Tensors needed by backward go through save_for_backward; the remaining
        # arguments are stashed on ctx as a plain tuple.
        ctx.save_for_backward(weight, alpha, beta)
        ctx.other = g, Qn, Qp, per_channel
        if per_channel:
            sizes = weight.size()
            # Flatten to (size[0], -1), then transpose so the first dimension of
            # weight ends up last, where alpha and beta are broadcast against it.
            weight = weight.contiguous().view(weight.size()[0], -1)
            weight = torch.transpose(weight, 0, 1)
            alpha = torch.broadcast_to(alpha, weight.size())
            beta = torch.broadcast_to(beta, weight.size())
            # Shift and scale, round, clamp to [Qn, Qp] ...
            w_q = Round.apply(torch.div((weight - beta), alpha)).clamp(Qn, Qp)
            # ... then map back by undoing the scale and the shift.
            w_q = w_q * alpha + beta
            # Undo the transpose and restore the original shape.
            w_q = torch.transpose(w_q, 0, 1)
            w_q = w_q.contiguous().view(sizes)
        return w_q

    def backward(ctx, grad_weight):
        """Compute gradients with respect to weight, alpha and beta.

        `grad_weight` is the gradient flowing in for forward's output.
        """
        weight, alpha, beta = ctx.saved_tensors
        g, Qn, Qp, per_channel = ctx.other
        if per_channel:
            sizes = weight.size()
            # Same flatten + transpose as in forward.
            weight = weight.contiguous().view(weight.size()[0], -1)
            weight = torch.transpose(weight, 0, 1)
            alpha = torch.broadcast_to(alpha, weight.size())
            # Unlike forward, beta is not passed through broadcast_to here; the
            # subtraction relies on implicit broadcasting.
            q_w = (weight - beta)/alpha
            q_w = torch.transpose(q_w, 0, 1)
            q_w = q_w.contiguous().view(sizes)
        # Float masks (1.0 / 0.0) marking where the scaled weight falls relative
        # to the clamp range [Qn, Qp].
        # NOTE(review): the next two lines are indented one space deeper than the
        # lines around them, which Python rejects as an unexpected indent.
        smaller = (q_w < Qn).float() # 1.0 where q_w is below Qn, else 0.0
         bigger = (q_w > Qp).float() # 1.0 where q_w is above Qp, else 0.0
         between = 1.0-smaller -bigger # 1.0 where q_w lies inside [Qn, Qp], else 0.0
        if per_channel:
            # alpha gradient: Qn / Qp where clamped, Round(q_w) - q_w inside the
            # range, all multiplied by the incoming gradient and the factor g.
            grad_alpha = ((smaller * Qn + bigger * Qp + 
                between * Round.apply(q_w) - between * q_w)*grad_weight * g)
            # Reduce to one value per index of the first dimension.
            grad_alpha = grad_alpha.contiguous().view(grad_alpha.size()[0], -1).sum(dim=1)
            # beta gradient: only clamped positions contribute; .sum() reduces over
            # every element, so the result is a one-element tensor.
            grad_beta = ((smaller + bigger) * grad_weight * g).sum().unsqueeze(dim=0)
            grad_beta = grad_beta.contiguous().view(grad_beta.size()[0], -1).sum(dim=1)
        # The incoming gradient passes through to weight only inside [Qn, Qp].
        grad_weight = between * grad_weight
        # The returned gradients must line up one-to-one with forward's inputs:
        # (weight, alpha, g, Qn, Qp, per_channel, beta).
        # NOTE: the order below is the erroneous one that raises the RuntimeError
        # quoted after the code: grad_beta sits in the third slot, which belongs
        # to g, while beta is the seventh input. The corrected return statement
        # is given at the end of the post.
        return grad_weight, grad_alpha, grad_beta, None, None, None, None

RuntimeError: function ALSQPlusBackward returned a gradient different than None at position 3, but the corresponding forward input was not a Variable

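The gradients returned by the `backward` method of a `Function` must be in the same order as the inputs of `forward`: one value per input, with `None` for every input that is not a tensor. Here `grad_beta` was returned in the third position, which corresponds to `g` (not a tensor), while `beta` is the seventh input of `forward`.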

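Solution: change the last line of `backward` so that `grad_beta` is returned in the seventh position, matching `beta` in `forward`, and the slots for `g`, `Qn`, `Qp` and `per_channel` are `None`:

    return grad_weight, grad_alpha, None, None, None, None, grad_beta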

