Source code for domainlab.compos.nn_zoo.nn_torchvision
import torch.nn as nn
from domainlab.utils.logger import Logger
[docs]
class NetTorchVisionBase(nn.Module):
    """
    Base wrapper around a network fetched from torchvision.

    The constructor initialises the ``nn.Module`` machinery, sets
    ``self.net_torchvision`` to ``None`` and then calls :meth:`fetch_net`.
    Subclasses override :meth:`fetch_net`; presumably the override assigns
    the concrete network to ``self.net_torchvision`` (confirm against the
    subclasses, which are not part of this class).
    """

    def __init__(self, flag_pretrain):
        """
        :param flag_pretrain: passed unchanged to :meth:`fetch_net`;
            presumably selects pretrained weights (confirm in subclasses)
        """
        super().__init__()
        # Placeholder; stays None unless fetch_net replaces it.
        self.net_torchvision = None
        self.fetch_net(flag_pretrain)

    def fetch_net(self, flag_pretrain):
        """
        Hook called by the constructor; must be overridden by subclasses.

        :param flag_pretrain: value received by the constructor
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def forward(self, tensor):
        """
        delegate forward operation to self.net_torchvision

        :param tensor: input handed to the wrapped network as-is
        :return: whatever the wrapped network returns
        """
        return self.net_torchvision(tensor)

    def show(self):
        """
        Log the name of every parameter of the wrapped network that has
        ``requires_grad`` set, i.e. which layers will be optimized.
        """
        trainable_names = [
            name
            for name, param in self.net_torchvision.named_parameters()
            if param.requires_grad
        ]
        # Nothing trainable: return without touching the logger at all.
        if not trainable_names:
            return
        # Fetch the logger once instead of once per trainable parameter.
        logger = Logger.get_logger()
        for name in trainable_names:
            logger.info(f"layers that will be optimized: \t{name}")