Source code for domainlab.compos.nn_zoo.net_classif

"Classifier"
import torch
import torch.nn as nn
from torch.nn import functional as F


class ClassifDropoutReluLinear(nn.Module):
    """first apply dropout, then relu, then linearly fully connected, without activation

    The returned values are raw logits: no softmax/sigmoid is applied to the
    output of the final linear layer.
    """

    def __init__(self, z_dim, target_dim, p_dropout=0.5):
        """
        :param z_dim: size of the last dimension of the input feature vector
            (``in_features`` of the linear layer)
        :param target_dim: number of output logits
            (``out_features`` of the linear layer)
        :param p_dropout: dropout probability; defaults to 0.5, which is the
            ``nn.Dropout()`` default the original implementation relied on
        """
        super().__init__()
        self.op_drop = nn.Dropout(p=p_dropout)
        self.op_linear = nn.Linear(z_dim, target_dim)
        # Xavier/Glorot uniform init for the weights, zero init for the bias.
        torch.nn.init.xavier_uniform_(self.op_linear.weight)
        # nn.init.zeros_ fills in place under no_grad; preferred over
        # mutating ``bias.data`` directly, which bypasses autograd tracking.
        nn.init.zeros_(self.op_linear.bias)

    def forward(self, z_vec):
        """
        :param z_vec: input features whose last dimension has size ``z_dim``
            (presumably ``(batch, z_dim)`` — confirm against callers)
        :return: logits whose last dimension has size ``target_dim``
        """
        hidden = F.relu(self.op_drop(z_vec))
        logit = self.op_linear(hidden)
        return logit