Source code for domainlab.compos.utils_conv_get_flat_dim

import torch


def get_flat_dim(module, i_channel: int, i_h: int, i_w: int, batchsize: int = 5) -> int:
    """Return the per-sample size of ``module``'s output once flattened.

    A random image batch of shape ``(batchsize, i_channel, i_h, i_w)`` is
    pushed through ``module`` and the sizes of every non-batch axis of the
    result are multiplied together. The value is what a following fully
    connected layer needs as its number of input features.

    The probe tensor is created with ``torch.randn`` using torch's default
    device and dtype, so ``module`` must accept such an input.

    :param module: callable (typically a convolutional network) applied to
        the probe batch; it must return a tensor whose first axis is the
        batch axis
    :param i_channel: number of channels of the probe image
    :param i_h: height of the probe image
    :param i_w: width of the probe image
    :param batchsize: number of copies of the probe image in the batch
    :return: product of all output axis sizes except the batch axis
    """
    img = torch.randn(i_channel, i_h, i_w)
    # Tile the single random image along a new leading batch axis.
    batch = img.repeat(batchsize, 1, 1, 1)

    # Only the output shape is of interest, so skip autograd bookkeeping.
    with torch.no_grad():
        conv_output = module(batch)

    # Multiply every non-batch axis so that outputs of any rank are handled
    # (already-flat 2-D, 3-D sequence-like, 4-D feature maps, ...).
    flat_dim = 1
    for axis_size in conv_output.shape[1:]:
        flat_dim *= axis_size
    return flat_dim