Source code for colibri.models.custom_layers

""" Utilities for building layers. """


import torch
import torch.nn as nn


class Activation(nn.Module):
    """Activation layer."""

    def __init__(self, activation="relu"):
        """Activation layer.

        Args:
            activation (str or nn.Module, optional): Activation function,
                such as nn.ReLU(), or the string name of a built-in
                activation function, such as "relu". Defaults to "relu".
        """
        super(Activation, self).__init__()
        if isinstance(activation, str):
            self.act_fn = self.get_activation(activation)
        else:
            self.act_fn = activation
    def get_activation(self, name):
        """Get an activation function by name.

        Args:
            name (str): Name of the activation function.

        Returns:
            nn.Module: Activation function.
        """
        activations = {
            "relu": nn.ReLU(),
            "sigmoid": nn.Sigmoid(),
            "softmax": nn.Softmax(dim=1),
            "tanh": nn.Tanh(),
            "identity": nn.Identity(),
        }
        if name in activations:
            return activations[name]
        raise ValueError(f"Unknown activation function: {name}")
    def forward(self, x):
        """Compute the activation function.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        return self.act_fn(x)
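# Usage sketch (illustrative, not part of the original module): Activation
# accepts either a built-in name or any callable module; the tensor shape
# below is arbitrary.
_act = Activation("relu")
assert _act(torch.randn(2, 3)).shape == (2, 3)  # shape preserved, negatives zeroed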
class convBlock(nn.Module):
    """Convolutional block.

    Default configuration: Conv2d => BatchNorm => ReLU ("CBR").
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        bias=False,
        mode="CBR",
        factor=2,
    ):
        """Convolutional block.

        Args:
            in_channels (int, optional): number of input channels. Defaults to 1.
            out_channels (int, optional): number of output channels. Defaults to 1.
            kernel_size (int, optional): size of the kernel. Defaults to 3.
            bias (bool, optional): whether to use bias or not. Defaults to False.
            mode (str, optional): sequence of layers in the block; possible
                values are 'C' (Conv2d), 'B' (BatchNorm2d), 'R' (ReLU),
                'U' (Upsample), 'M' (MaxPool2d) and 'A' (AvgPool2d).
                Defaults to "CBR".
            factor (int, optional): factor for upsampling/downsampling. Defaults to 2.
        """
        super(convBlock, self).__init__()
        self.layers = nn.ModuleList()
        pad_size = kernel_size // 2
        conv_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=pad_size,
            bias=bias,
        )
        first_conv = True
        for c in mode:
            layer = self.build_layer(c, conv_kwargs, factor)
            self.layers.append(layer)
            # After the first convolution, subsequent convolutions map
            # out_channels to out_channels.
            if c == "C" and first_conv:
                first_conv = False
                conv_kwargs["in_channels"] = conv_kwargs["out_channels"]
    def forward(self, x):
        """Forward pass of the convBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        for layer in self.layers:
            x = layer(x)
        return x
    def build_layer(self, c, params, factor):
        """Build a layer from its mode character.

        Args:
            c (str): mode character of the layer.
            params (dict): keyword arguments for the convolution.
            factor (int): factor for upsampling/downsampling.

        Returns:
            nn.Module: Layer.
        """
        batch_norm_params = dict(num_features=params["out_channels"])
        params_mapping = {
            "C": (nn.Conv2d, params),
            "B": (nn.BatchNorm2d, batch_norm_params),
            "R": (nn.ReLU, None),
            # "U" scales the spatial dimensions by `factor`.
            "U": (nn.Upsample, dict(scale_factor=factor)),
            "M": (nn.MaxPool2d, dict(kernel_size=(factor, factor))),
            "A": (nn.AvgPool2d, dict(kernel_size=(factor, factor))),
        }
        if c in params_mapping:
            layer_cls, layer_params = params_mapping[c]
            return layer_cls(**layer_params) if layer_params else layer_cls()
        raise ValueError(f"Unknown layer type: {c}")
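# Usage sketch (illustrative, not part of the original module): a "CBR"
# block keeps spatial size (padding = kernel_size // 2); appending "M"
# downsamples by `factor`.
_blk = convBlock(in_channels=3, out_channels=8, mode="CBRM")
assert _blk(torch.randn(1, 3, 16, 16)).shape == (1, 8, 8, 8)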
class downBlock(nn.Module):
    """Spatial downsampling followed by a convBlock."""

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super(downBlock, self).__init__()
        self.pool_conv = convBlock(in_channels, out_channels, mode="MCBRCBR")

    def forward(self, x):
        return self.pool_conv(x)
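# Usage sketch (illustrative): mode "MCBRCBR" pools first, halving the
# spatial size, then applies two Conv => BatchNorm => ReLU stages.
_down = downBlock(8, 16)
assert _down(torch.randn(1, 8, 16, 16)).shape == (1, 16, 8, 8)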
class upBlock(nn.Module):
    """Spatial upsampling followed by a convBlock, with a skip connection."""

    def __init__(self, in_channels):
        """
        Args:
            in_channels (int): number of input channels.
        """
        super(upBlock, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, bias=False),
        )
        self.conv_block = nn.Sequential(
            convBlock(in_channels * 2, in_channels),
            convBlock(in_channels, in_channels),
        )
    def forward(self, x1, x2):
        """Forward pass of the upBlock.

        Args:
            x1 (torch.Tensor): Input tensor to be upsampled.
            x2 (torch.Tensor): Skip-connection tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x1 = self.up(x1)
        # Pad x1 so its spatial dimensions match x2 before concatenating.
        diffY = x2.shape[-2] - x1.shape[-2]
        diffX = x2.shape[-1] - x1.shape[-1]
        x1 = nn.functional.pad(
            x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)
        )
        return self.conv_block(torch.cat([x2, x1], dim=1))
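# Usage sketch (illustrative): for upBlock(C), x1 carries 2*C channels at
# half resolution and x2 carries C channels at full resolution.
_up = upBlock(8)
_x1, _x2 = torch.randn(1, 16, 8, 8), torch.randn(1, 8, 16, 16)
assert _up(_x1, _x2).shape == (1, 8, 16, 16)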
class upBlockNoSkip(nn.Module):
    """Spatial upsampling followed by a convBlock, without a skip connection."""

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
        """
        super(upBlockNoSkip, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv_block = nn.Sequential(
            convBlock(in_channels, out_channels),
            convBlock(out_channels, out_channels),
        )
    def forward(self, x1):
        """Forward pass of the upBlockNoSkip.

        Args:
            x1 (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x1 = self.up(x1)
        return self.conv_block(x1)
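# Usage sketch (illustrative): upsampling by 2 without a skip connection.
_upns = upBlockNoSkip(16, 8)
assert _upns(torch.randn(1, 16, 8, 8)).shape == (1, 8, 16, 16)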
class outBlock(nn.Module):
    """Convolutional block with a 1x1 kernel and optional activation."""

    def __init__(self, in_channels, out_channels, activation=None):
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            activation (str, optional): activation function. Defaults to None.
        """
        super(outBlock, self).__init__()
        self.conv = convBlock(in_channels, out_channels, kernel_size=1, mode="C")
        self.act = Activation(activation) if activation else nn.Identity()
    def forward(self, x):
        """Forward pass of the outBlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.conv(x)
        return self.act(x)
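# Usage sketch (illustrative): a 1x1 projection head; with "sigmoid" the
# outputs land in (0, 1).
_out = outBlock(8, 1, activation="sigmoid")
assert _out(torch.randn(1, 8, 16, 16)).shape == (1, 1, 16, 16)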