Source code for mipcandy.common.module.conv

from typing import override

import torch
from torch import nn

from mipcandy.layer import LayerT


[docs] class AbstractConvBlock(nn.Module): def __init__(self, in_ch: int, out_ch: int, kernel_size: int, *, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", conv: LayerT = ..., norm: LayerT = ..., act: LayerT = ...) -> None: super().__init__() self.conv: nn.Module = conv.assemble(in_ch, out_ch, kernel_size, stride, padding, dilation, groups, bias, padding_mode) self.norm: nn.Module = norm.assemble(in_ch=out_ch) self.act: nn.Module = act.assemble()
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.act(self.norm(self.conv(x)))
[docs] def _conv_block(default_conv: LayerT, default_norm: LayerT, default_act: LayerT) -> type[AbstractConvBlock]: class ConvBlock(AbstractConvBlock): def __init__(self, *args, **kwargs) -> None: if "conv" not in kwargs: kwargs["conv"] = default_conv if "norm" not in kwargs: kwargs["norm"] = default_norm if "act" not in kwargs: kwargs["act"] = default_act super().__init__(*args, **kwargs) return ConvBlock
ConvBlock2d: type[AbstractConvBlock] = _conv_block( LayerT(nn.Conv2d), LayerT(nn.BatchNorm2d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True) ) ConvBlock3d: type[AbstractConvBlock] = _conv_block( LayerT(nn.Conv3d), LayerT(nn.BatchNorm3d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True) )
[docs] class WSConv2d(nn.Conv2d):
[docs] @override def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight w = (w - w.mean(dim=(1, 2, 3), keepdim=True)) / (w.std(dim=(1, 2, 3), keepdim=True) + 1e-5) return nn.functional.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
[docs] class WSConv3d(nn.Conv3d):
[docs] @override def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight w = (w - w.mean(dim=(1, 2, 3, 4), keepdim=True)) / (w.std(dim=(1, 2, 3, 4), keepdim=True) + 1e-5) return nn.functional.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)