Source code for mipcandy.data.transform

from typing import Literal

import torch
from torch import nn

from mipcandy.types import Transform

# Stage names accepted by ``JointTransform``'s ``order`` argument. Each name is also the
# name of the ``JointTransform`` attribute that holds the transform for that stage, which
# is how ``JointTransform.forward`` looks the stage up.
type _Order = Literal["transform", "image_only", "label_only"]


[docs] class JointTransform(nn.Module): def __init__(self, *, transform: Transform | None = None, image_only: Transform | None = None, label_only: Transform | None = None, keys: tuple[str, str] = ("image", "label"), order: tuple[_Order, _Order, _Order] = ("transform", "image_only", "label_only")) -> None: super().__init__() self.transform: Transform | None = transform self.image_only: Transform | None = image_only self.label_only: Transform | None = label_only self._keys: tuple[str, str] = keys self._order: tuple[_Order, _Order, _Order] = order
[docs] def forward(self, image: torch.Tensor, label: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: ik, lk = self._keys data = {ik: image, lk: label} for t in self._order: transform = getattr(self, t) if not transform: continue match t: case "transform": data = transform(data) case "image_only": data[ik] = transform(data[ik]) case "label_only": data[lk] = transform(data[lk]) return data[ik], data[lk]
class MONAITransform(nn.Module):
    """Let a single transform be called on either a tensor or an image/label dict.

    A tensor input is passed straight to the wrapped transform. A dict input
    has the wrapped transform applied separately to its image entry and its
    label entry.
    """

    def __init__(self, transform: Transform, *, keys: tuple[str, str] = ("image", "label")) -> None:
        """Remember the wrapped transform and the dict keys to operate on.

        :param transform: the callable applied to each tensor
        :param keys: ``(image_key, label_key)`` looked up in dict inputs
        """
        super().__init__()
        self.transform: Transform = transform
        self._keys: tuple[str, str] = keys

    def forward(self, data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | dict[str, torch.Tensor]:
        """Apply the wrapped transform to ``data``.

        :param data: a tensor, or a dict holding the image and label entries
        :return: the transformed tensor for tensor input; for dict input, a new
            dict containing only the transformed image and label entries
        """
        if isinstance(data, torch.Tensor):
            return self.transform(data)

        image_key, label_key = self._keys
        # Read both entries up front so a missing key fails before any
        # transform call is made.
        source_image = data[image_key]
        source_label = data[label_key]

        result: dict[str, torch.Tensor] = {}
        result[image_key] = self.transform(source_image)
        result[label_key] = self.transform(source_label)
        return result