Source code for domainlab.compos.nn_zoo.net_adversarial
from torch import nn
from torch.autograd import Function
[docs]
class Flatten(nn.Module):
    """Collapse every dimension after dimension 0 into a single dimension.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``;
    a 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, x):
        """Return ``x`` reshaped to two dimensions, keeping dimension 0.

        Unlike ``Tensor.view``, this accepts non-contiguous inputs (for
        example the result of ``permute``/``transpose``); a copy is made only
        when a view is impossible. Empty batches (``N == 0``) are supported.

        :param x: tensor with at least one dimension.
        :return: 2-D tensor whose first dimension equals ``x.shape[0]``.
        """
        if x.dim() == 1:
            # flatten(start_dim=1) rejects 1-D input; preserve the original
            # (N,) -> (N, 1) behaviour of view(N, -1).
            return x.unsqueeze(1)
        return x.flatten(start_dim=1)
[docs]
class AutoGradFunReverseMultiply(Function):
    """Identity on the forward pass; negated, ``alpha``-scaled gradient on backward.

    This is the gradient-reversal pattern: the input passes through
    unchanged, while the gradient flowing back is multiplied by ``-alpha``.
    ``alpha`` receives no gradient.

    Call via ``AutoGradFunReverseMultiply.apply(x, alpha)``.

    https://pytorch.org/docs/stable/autograd.html
    https://pytorch.org/docs/stable/notes/extending.html#extending-autograd
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Stash ``alpha`` on ``ctx`` and return a view of ``x`` with identical shape."""
        ctx.alpha = alpha
        # A view (rather than ``x`` itself) so autograd records this node.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(-grad_output * alpha, None)``; ``None`` is alpha's gradient."""
        reversed_grad = -grad_output * ctx.alpha
        return reversed_grad, None
[docs]
class AutoGradFunMultiply(Function):
    """Identity on the forward pass; gradient scaled by ``alpha`` on backward.

    The input passes through unchanged, while the gradient flowing back is
    multiplied by ``alpha`` (no sign flip). ``alpha`` receives no gradient.

    Call via ``AutoGradFunMultiply.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Stash ``alpha`` on ``ctx`` and return a view of ``x`` with identical shape."""
        ctx.alpha = alpha
        # A view (rather than ``x`` itself) so autograd records this node.
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_output * alpha, None)``; ``None`` is alpha's gradient."""
        scaled_grad = grad_output * ctx.alpha
        return scaled_grad, None