Source code for domainlab.compos.nn_zoo.net_gated
import torch.nn as nn
[docs]
class GatedDense(nn.Module):
    """Fully connected layer with a multiplicative sigmoid gate.

    Computes ``act(h(x)) * sigmoid(g(x))``, where ``h`` and ``g`` are two
    independent ``nn.Linear(input_size, output_size)`` layers. When
    ``activation`` is ``None`` no activation is applied to ``h(x)``.

    :param input_size: number of input features (last dimension of ``x``).
    :param output_size: number of output features.
    :param activation: optional callable (e.g. an ``nn.Module``) applied to
        the ungated branch ``h(x)``; ``None`` disables it.
    """

    def __init__(self, input_size, output_size, activation=None):
        super(GatedDense, self).__init__()
        self.activation = activation
        self.sigmoid = nn.Sigmoid()
        self.h = nn.Linear(input_size, output_size)  # value branch
        self.g = nn.Linear(input_size, output_size)  # gate branch

    def forward(self, x):
        """Return the gated projection of ``x``.

        :param x: tensor whose last dimension equals ``input_size``.
        :return: tensor of the same leading shape with last dimension
            ``output_size``.
        """
        h = self.h(x)
        if self.activation is not None:
            # Apply the activation to the projection computed above rather
            # than evaluating the linear layer self.h a second time.
            h = self.activation(h)
        g = self.sigmoid(self.g(x))
        return h * g
# ==========================================================================
[docs]
class GatedConv2d(nn.Module):
    """2D convolution with a multiplicative sigmoid gate.

    Computes ``act(h(x)) * sigmoid(g(x))``, where ``h`` and ``g`` are two
    ``nn.Conv2d`` layers built with identical hyper-parameters. When
    ``activation`` is ``None`` no activation is applied to ``h(x)``.

    :param input_channels: number of channels of the input.
    :param output_channels: number of channels produced by each branch.
    :param kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
    :param stride: convolution stride, forwarded to ``nn.Conv2d``.
    :param padding: convolution padding, forwarded to ``nn.Conv2d``.
    :param dilation: convolution dilation, forwarded to ``nn.Conv2d``.
    :param activation: optional callable applied to the ungated branch.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        activation=None,
    ):
        super().__init__()
        self.activation = activation
        self.sigmoid = nn.Sigmoid()
        # Both branches share exactly the same convolution configuration.
        conv_spec = (
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding,
            dilation,
        )
        self.h = nn.Conv2d(*conv_spec)  # value branch
        self.g = nn.Conv2d(*conv_spec)  # gate branch

    def forward(self, x):
        """Return the gated convolution of ``x``.

        :param x: input accepted by ``nn.Conv2d`` with ``input_channels``
            channels (e.g. shape (N, C, H, W)).
        :return: elementwise product of the value and gate branches.
        """
        value = self.h(x)
        value = value if self.activation is None else self.activation(value)
        gate = self.sigmoid(self.g(x))
        return value * gate
# ==============================================================================
[docs]
class Conv2d(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` with an optional output activation.

    :param input_channels: number of channels of the input.
    :param output_channels: number of channels produced by the convolution.
    :param kernel_size: convolution kernel size, forwarded to ``nn.Conv2d``.
    :param stride: convolution stride, forwarded to ``nn.Conv2d``.
    :param padding: convolution padding, forwarded to ``nn.Conv2d``.
    :param dilation: convolution dilation, forwarded to ``nn.Conv2d``.
    :param activation: optional callable applied to the convolution output;
        ``None`` returns the raw convolution output.
    :param bias: whether the convolution has a learnable bias.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        stride,
        padding,
        dilation=1,
        activation=None,
        bias=True,
    ):
        super().__init__()
        self.activation = activation
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Convolve ``x`` and apply the activation if one was given.

        :param x: input accepted by ``nn.Conv2d`` with ``input_channels``
            channels (e.g. shape (N, C, H, W)).
        :return: the (optionally activated) convolution output.
        """
        features = self.conv(x)
        if self.activation is None:
            return features
        return self.activation(features)