由于pytorch版本更新,在1.3及以上的版本中要求forward方法必须为静态方法。而在许多较早版本的代码实现中,梯度反转通常是这样写的:
from torch.autograd import Function
class GradReverse(Function):
def __init__(self, lambd):
self.lambd = lambd
def forward(self, x):
return x.view_as(x)
def backward(self, grad_output):
return (grad_output