"""Source code for domainlab.compos.nn_zoo.net_adversarial."""

from torch import nn
from torch.autograd import Function


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single feature axis.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    A 1-D input of shape ``(N,)`` becomes ``(N, 1)``.
    """

    def forward(self, x):
        """Flatten ``x`` to shape ``(x.size(0), prod(x.shape[1:]))``.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. after ``transpose``/``permute``) work; it still returns a view
        whenever the memory layout allows one.
        """
        batch_size = x.size(0)
        # Explicit feature count instead of -1: with an empty batch (N == 0)
        # the -1 placeholder is ambiguous and reshape/view would raise.
        # torch.Size.numel() of an empty Size is 1, so (N,) -> (N, 1).
        num_features = x.shape[1:].numel()
        x = x.reshape(batch_size, num_features)
        return x
class AutoGradFunReverseMultiply(Function):
    """
    Gradient reversal: identity in the forward pass, and in the backward
    pass the incoming gradient is negated and scaled by ``alpha``.

    https://pytorch.org/docs/stable/autograd.html
    https://pytorch.org/docs/stable/notes/extending.html#extending-autograd
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Stash ``alpha`` on ``ctx`` and pass ``x`` through unchanged.

        ``view_as`` hands back a new tensor object that shares ``x``'s
        storage, rather than returning the input object itself.
        """
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-grad_output * alpha`` for ``x`` and ``None`` for ``alpha``."""
        flipped = grad_output.neg()
        grad_wrt_x = flipped * ctx.alpha
        # alpha gets no gradient.
        return grad_wrt_x, None
class AutoGradFunMultiply(Function):
    """
    Gradient scaling: identity in the forward pass, and in the backward
    pass the incoming gradient is multiplied by ``alpha`` (no sign flip).
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Stash ``alpha`` on ``ctx`` and pass ``x`` through unchanged (as a view)."""
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output * alpha`` for ``x`` and ``None`` for ``alpha``."""
        grad_wrt_x = grad_output * ctx.alpha
        # alpha gets no gradient.
        return grad_wrt_x, None