import torch.nn as nn
from torchvision import models as torchvisionmodels
from domainlab.compos.nn_zoo.nn import LayerId
from domainlab.compos.nn_zoo.nn_torchvision import NetTorchVisionBase
from domainlab.utils.logger import Logger
[docs]
class AlexNetBase(NetTorchVisionBase):
    """
    Wrapper around the torchvision AlexNet model.

    ``fetch_net`` builds ``torchvision.models.alexnet`` and stores it as
    ``self.net_torchvision``. For reference, the module structure is shown
    below. The printout shows the last classifier layer (``classifier[6]``)
    with ``out_features=7``, i.e. after it has been resized to a
    7-dimensional output (as :class:`Alex4DeepAll` does). The stock
    torchvision model built by ``fetch_net`` has ``out_features=1000`` there.

    .. code-block:: python

        AlexNet(
          (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
            (1): ReLU(inplace=True)
            (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
            (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
            (4): ReLU(inplace=True)
            (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
            (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (7): ReLU(inplace=True)
            (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (9): ReLU(inplace=True)
            (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (11): ReLU(inplace=True)
            (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
          )
          (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
          (classifier): Sequential(
            (0): Dropout(p=0.5, inplace=False)
            (1): Linear(in_features=9216, out_features=4096, bias=True)
            (2): ReLU(inplace=True)
            (3): Dropout(p=0.5, inplace=False)
            (4): Linear(in_features=4096, out_features=4096, bias=True)
            (5): ReLU(inplace=True)
            (6): Linear(in_features=4096, out_features=7, bias=True)
          )
        )
    """

    def fetch_net(self, flag_pretrain):
        """
        Build the torchvision AlexNet and store it as ``self.net_torchvision``.

        :param flag_pretrain: passed as ``pretrained`` to
            ``torchvision.models.alexnet``; when true, torchvision loads its
            ImageNet-pretrained weights.
        """
        # NOTE: torchvision >= 0.13 deprecates ``pretrained`` in favour of
        # ``weights``; the legacy keyword is kept so older torchvision
        # versions keep working.
        self.net_torchvision = torchvisionmodels.alexnet(pretrained=flag_pretrain)
[docs]
class Alex4DeepAll(AlexNetBase):
    """
    AlexNet whose last classifier layer outputs ``dim_y`` values.

    After the base class has built the torchvision AlexNet, the last layer
    of its classifier (``classifier[6]``) is checked. If it does not already
    have ``dim_y`` output features, it is replaced by a freshly initialized
    ``nn.Linear`` with the same number of input features and ``dim_y``
    output features. The original and new sizes are logged.

    :param flag_pretrain: forwarded to the base class constructor
        (presumably handed on to ``fetch_net`` to choose pretrained
        weights -- TODO confirm in ``NetTorchVisionBase``).
    :param dim_y: desired output dimension of the last layer
        (e.g. the number of classes).
    """

    def __init__(self, flag_pretrain, dim_y):
        super().__init__(flag_pretrain)
        last_layer = self.net_torchvision.classifier[6]
        if last_layer.out_features != dim_y:
            logger = Logger.get_logger()
            logger.info(f"original alex net out dim {last_layer.out_features}")
            # Keep the input width so the preceding 4096-wide layer still
            # connects; only the output size changes (weights are re-initialized).
            self.net_torchvision.classifier[6] = nn.Linear(
                last_layer.in_features, dim_y
            )
            logger.info(f"re-initialized to {dim_y}")
[docs]
class AlexNetNoLastLayer(AlexNetBase):
    """
    AlexNet with its last classifier layer replaced by an identity layer.

    Removing the final linear layer lets the VAE-based model put its own
    classifier on top and end up with the same layer depth as the ERM
    model, which keeps the comparison between the two fair.

    :param flag_pretrain: forwarded unchanged to the base class constructor.
    """

    def __init__(self, flag_pretrain):
        super().__init__(flag_pretrain)
        head = self.net_torchvision.classifier
        # slot 6 is the final Linear layer of the torchvision classifier
        head[6] = LayerId()