# Source code for colibri.models.autoencoder

""" Autoencoder Architecture """

from . import custom_layers
import torch.nn as nn


class Autoencoder(nn.Module):
    r"""Autoencoder model.

    The autoencoder is a neural network trained to learn a latent
    representation of the input data. It is composed of an encoder, which
    compresses the input into a latent-space representation, and a decoder,
    which reconstructs the input from that representation. The model is
    usually trained to minimize the reconstruction error between the input
    and the reconstruction:

    .. math::
        \mathcal{L} = \left\| \mathbf{x} - D(E(\mathbf{x})) \right\|_2^2

    where :math:`\mathbf{x}` is the input data and
    :math:`\hat{\mathbf{x}} = D(E(\mathbf{x}))` is the reconstruction, with
    :math:`E(\cdot)` and :math:`D(\cdot)` the encoder and decoder networks,
    respectively.

    Implementation based on the formulation of the authors in
    https://dl.acm.org/doi/book/10.5555/3086952
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        features=None,
        last_activation="sigmoid",
        reduce_spatial=False,
        **kwargs,
    ):
        r"""
        Args:
            in_channels (int): Number of input channels. Defaults to 1.
            out_channels (int): Number of output channels. Defaults to 1.
            features (list, optional): Number of features at each level of the
                autoencoder. Defaults to [32, 64, 128, 256].
            last_activation (str, optional): Activation function for the last
                layer. Defaults to 'sigmoid'.
            reduce_spatial (bool): If True, the encoder/decoder also reduce and
                restore the spatial dimensions (down/up-sampling blocks);
                otherwise only the channel dimension changes.
        """
        super(Autoencoder, self).__init__()

        # Avoid the mutable-default-argument pitfall: the list default is
        # created once per call, not shared across instances.
        if features is None:
            features = [32, 64, 128, 256]
        levels = len(features)

        self.inc = custom_layers.convBlock(in_channels, features[0], mode="CBRCBR")

        if reduce_spatial:
            # Spatially down-sampling encoder and up-sampling decoder.
            self.downs = nn.ModuleList(
                [
                    custom_layers.downBlock(features[i], features[i + 1])
                    for i in range(levels - 1)
                ]
            )
            # Decoder mirrors the encoder: (f[i+1] -> f[i]) for i = levels-2 .. 0.
            self.ups = nn.ModuleList(
                [
                    custom_layers.upBlockNoSkip(features[i + 1], features[i])
                    for i in range(levels - 2, -1, -1)
                ]
            )
            self.bottle = custom_layers.convBlock(features[-1], features[-1])
        else:
            # Channel-only encoder/decoder: spatial size is preserved.
            self.downs = nn.ModuleList(
                [
                    custom_layers.convBlock(features[i], features[i + 1], mode="CBRCBR")
                    for i in range(levels - 1)
                ]
            )
            self.bottle = custom_layers.convBlock(features[-1], features[-1])
            self.ups = nn.ModuleList(
                [
                    custom_layers.convBlock(features[i + 1], features[i])
                    for i in range(levels - 2, -1, -1)
                ]
            )

        self.outc = custom_layers.outBlock(features[0], out_channels, last_activation)

    def forward(self, inputs, get_latent=False, **kwargs):
        r"""Forward pass of the autoencoder.

        Args:
            inputs (torch.Tensor): Input tensor.
            get_latent (bool): If True, also return the latent representation.

        Returns:
            torch.Tensor: Output tensor, or a tuple ``(output, latent)`` when
            ``get_latent`` is True.
        """
        x = self.inc(inputs)
        for down in self.downs:
            x = down(x)
        latent = self.bottle(x)
        x = latent
        for up in self.ups:
            x = up(x)
        out = self.outc(x)
        if get_latent:
            return out, latent
        return out