Source code for domainlab.compos.nn_zoo.net_gated

import torch.nn as nn


class GatedDense(nn.Module):
    """Fully connected layer with a multiplicative sigmoid gate.

    The output is ``activation(h(x)) * sigmoid(g(x))``, where ``h`` and ``g``
    are two separate ``nn.Linear`` layers built with the same input and
    output sizes. When ``activation`` is ``None`` the ``h`` branch is used
    without any non-linearity.
    """

    def __init__(self, input_size, output_size, activation=None):
        """
        :param input_size: number of input features for both linear layers
        :param output_size: number of output features for both linear layers
        :param activation: optional callable applied to the output of ``h``;
            ``None`` leaves the ``h`` branch linear
        """
        super(GatedDense, self).__init__()

        self.activation = activation
        self.sigmoid = nn.Sigmoid()
        self.h = nn.Linear(input_size, output_size)
        self.g = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the gated projection of ``x``.

        :param x: input passed to both linear layers
        :return: element-wise product of the (optionally activated) ``h``
            branch and the sigmoid of the ``g`` branch
        """
        h = self.h(x)
        if self.activation is not None:
            # Reuse the projection computed above rather than running the
            # linear layer a second time.
            h = self.activation(h)

        g = self.sigmoid(self.g(x))
        return h * g
# ==========================================================================
class GatedConv2d(nn.Module):
    """2-D convolution with a multiplicative sigmoid gate.

    Two ``nn.Conv2d`` layers (``h`` and ``g``) are built from the same
    arguments. The result is ``activation(h(x)) * sigmoid(g(x))``, or
    ``h(x) * sigmoid(g(x))`` when no activation is given.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        activation=None,
    ):
        """
        :param input_channels: input channel count for both convolutions
        :param output_channels: output channel count for both convolutions
        :param kernel_size: kernel size forwarded to ``nn.Conv2d``
        :param stride: stride forwarded to ``nn.Conv2d``
        :param padding: padding forwarded to ``nn.Conv2d``
        :param dilation: dilation forwarded to ``nn.Conv2d``
        :param activation: optional callable applied to the output of ``h``
        """
        super().__init__()

        self.activation = activation
        self.sigmoid = nn.Sigmoid()

        # Both branches share exactly the same convolution configuration.
        conv_args = (
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding,
            dilation,
        )
        self.h = nn.Conv2d(*conv_args)
        self.g = nn.Conv2d(*conv_args)

    def forward(self, x):
        """Return the gated convolution of ``x``.

        :param x: input passed to both convolutions
        :return: element-wise product of the (optionally activated) ``h``
            branch and the sigmoid of the ``g`` branch
        """
        features = self.h(x)
        if self.activation is not None:
            features = self.activation(features)

        gate = self.sigmoid(self.g(x))
        return features * gate
# ==============================================================================
class Conv2d(nn.Module):
    """``nn.Conv2d`` followed by an optional activation.

    With ``activation=None`` the module returns the raw convolution output.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        activation=None,
        bias=True,
    ):
        """
        :param input_channels: input channel count of the convolution
        :param output_channels: output channel count of the convolution
        :param kernel_size: kernel size forwarded to ``nn.Conv2d``
        :param stride: stride forwarded to ``nn.Conv2d``
        :param padding: padding forwarded to ``nn.Conv2d``
        :param dilation: dilation forwarded to ``nn.Conv2d``
        :param activation: optional callable applied to the convolution output
        :param bias: forwarded to ``nn.Conv2d`` as its ``bias`` flag
        """
        super().__init__()

        self.activation = activation
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Apply the convolution, then the activation if one was supplied.

        :param x: input passed to the convolution
        :return: convolution output, activated when ``activation`` is set
        """
        conv_out = self.conv(x)
        return conv_out if self.activation is None else self.activation(conv_out)