from typing import Literal
import torch
from torch import nn
from mipcandy.data import convert_ids_to_logits, convert_logits_to_ids
from mipcandy.metrics import do_reduction, binary_dice, dice_similarity_coefficient, soft_dice
class FocalBCEWithLogits(nn.Module):
    """Sigmoid focal loss computed directly from logits.

    The element-wise loss is ``alpha_t * (1 - p_t) ** gamma * BCE(logits, targets)`` with
    ``p = sigmoid(logits)``, ``p_t = p * t + (1 - p) * (1 - t)`` and
    ``alpha_t = alpha * t + (1 - alpha) * (1 - t)``.

    :param alpha: weight applied to the positive term; ``1 - alpha`` weights the negative term.
    :param gamma: focusing exponent applied to ``(1 - p_t)``.
    :param reduction: reduction mode handed to ``do_reduction`` ("mean", "sum" or "none").
    """

    def __init__(self, alpha: float, gamma: float, *, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
        super().__init__()
        self.alpha: float = alpha
        self.gamma: float = gamma
        self.reduction: Literal["mean", "sum", "none"] = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss for ``logits`` against ``targets``.

        :param logits: raw (pre-sigmoid) predictions.
        :param targets: float targets with the same shape as ``logits``; hard 0/1 labels or soft
            labels in ``[0, 1]``.
        :return: the element-wise loss passed through ``do_reduction`` with ``self.reduction``.
        """
        # BCE straight from logits avoids an explicit log(sigmoid(.)).
        bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        # Interpolating with the target (instead of torch.where(targets.bool(), ...)) gives exactly
        # the same values for 0/1 targets, but stays consistent with the BCE term for soft targets,
        # where bool() would treat every non-zero target as a full positive.
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_t * (1 - p_t).pow(self.gamma) * bce
        return do_reduction(loss, self.reduction)
class Loss(nn.Module):
    """Base class for losses that can be switched into a validation mode.

    Assigning :attr:`validation_mode` propagates the flag to every nested :class:`Loss`.
    This includes losses kept inside plain container modules such as ``nn.ModuleList``,
    ``nn.ModuleDict`` or ``nn.Sequential``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._validation_mode: bool = False

    @property
    def validation_mode(self) -> bool:
        """Whether this loss is in validation mode (``False`` right after construction)."""
        return self._validation_mode

    @validation_mode.setter
    def validation_mode(self, value: bool) -> None:
        self._validation_mode = value
        self._propagate_validation_mode(self, value)

    @staticmethod
    def _propagate_validation_mode(module: nn.Module, value: bool) -> None:
        # Push ``value`` to the nearest Loss descendants of ``module``. Each Loss found is set
        # through its own setter, which continues the propagation beneath it. Non-Loss modules
        # (containers) are walked through so that losses stored inside them are reached too.
        for child in module.children():
            if isinstance(child, Loss):
                child.validation_mode = value
            else:
                Loss._propagate_validation_mode(child, value)
class SegmentationLoss(Loss):
    """Common base for the segmentation losses in this module.

    :param num_classes: number of segmentation classes.
    :param include_background: stored flag; subclasses read it to decide whether channel 0
        takes part in their loss terms.
    """

    def __init__(self, num_classes: int, include_background: bool) -> None:
        super().__init__()
        self.num_classes: int = num_classes
        self.include_background: bool = include_background

    def logitfy_no_grad(self, ids: torch.Tensor) -> torch.Tensor:
        """Expand a single-channel id map into per-class channels, without tracking gradients.

        Inputs are returned as a float tensor, unchanged otherwise, in two cases:
        ``num_classes`` is 1, or ``ids`` does not have exactly one channel along dim 1.
        Otherwise the ids are cast to ``int`` and expanded with ``convert_ids_to_logits``
        inside ``torch.no_grad()``. The expansion presumably yields one channel per class —
        confirm against ``convert_ids_to_logits``.
        """
        # num_classes is checked first so ids.shape[1] is never touched when it is 1.
        if self.num_classes == 1 or ids.shape[1] != 1:
            return ids.float()
        with torch.no_grad():
            return convert_ids_to_logits(ids.int(), self.num_classes)
class DiceCELossWithLogits(SegmentationLoss):
    """Weighted sum of cross-entropy and soft-Dice loss for multi-class logits.

    The loss is ``lambda_ce * CE + lambda_soft_dice * (1 - soft_dice)``. In validation mode the
    metrics dict additionally carries per-class Dice scores and an overall Dice score.

    :param num_classes: number of classes (size of the logits' channel dimension).
    :param lambda_ce: weight of the cross-entropy term.
    :param lambda_soft_dice: weight of the ``1 - soft_dice`` term.
    :param smooth: smoothing constant passed to ``soft_dice``.
    :param include_background: when ``False``, channel 0 is excluded from the soft-Dice term
        and from the per-class validation scores.
    """

    def __init__(self, num_classes: int, *, lambda_ce: float = 1, lambda_soft_dice: float = 1,
                 smooth: float = 1e-5, include_background: bool = True) -> None:
        super().__init__(num_classes, include_background)
        self.lambda_ce: float = lambda_ce
        self.lambda_soft_dice: float = lambda_soft_dice
        self.smooth: float = smooth

    def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the combined loss together with its two components.

        :param outputs: raw logits; softmax is taken over dim 1.
        :param labels: label map whose channel 0 (``labels[:, 0]``) is used as the CE class-id target.
        :return: ``(loss, {"soft dice": ..., "ce loss": ...})``.
        """
        # NOTE(review): CE reads labels[:, 0] as class ids, so labels are presumably a
        # single-channel id map (B, 1, ...) — confirm with callers.
        ce_term = nn.functional.cross_entropy(outputs, labels[:, 0].long())
        probs = outputs.softmax(1)
        targets = self.logitfy_no_grad(labels)
        if not self.include_background:
            # Drop channel 0 on both sides so it does not contribute to soft Dice.
            probs, targets = probs[:, 1:], targets[:, 1:]
        dice_term = soft_dice(probs, targets, smooth=self.smooth)
        stats = {"soft dice": dice_term.item(), "ce loss": ce_term.item()}
        total = self.lambda_ce * ce_term + self.lambda_soft_dice * (1 - dice_term)
        return total, stats

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
        """Return ``(loss, metrics)``; in validation mode everything runs under ``torch.no_grad()``.

        Validation mode adds ``"dice {i}"`` for each class ``i`` considered (starting at 1 when
        background is excluded) plus an overall ``"dice"`` entry.
        """
        if self.validation_mode:
            with torch.no_grad():
                loss, stats = self._forward(outputs, labels)
                predicted = convert_logits_to_ids(outputs)
                first_class = 0 if self.include_background else 1
                stats.update({
                    f"dice {k}": binary_dice(predicted == k, labels == k).item()
                    for k in range(first_class, self.num_classes)
                })
                # NOTE(review): the overall score uses the full logitfied maps, so
                # include_background is not applied here — confirm this is intended.
                stats["dice"] = dice_similarity_coefficient(
                    self.logitfy_no_grad(predicted), self.logitfy_no_grad(labels)
                ).item()
            return loss, stats
        return self._forward(outputs, labels)
class DiceBCELossWithLogits(SegmentationLoss):
    """Weighted sum of binary cross-entropy and soft-Dice loss for single-channel logits.

    The loss is ``lambda_bce * BCE + lambda_soft_dice * (1 - soft_dice)``, where BCE is computed
    from the raw logits and soft Dice from ``sigmoid(logits)``.

    :param lambda_bce: weight of the BCE term.
    :param lambda_soft_dice: weight of the ``1 - soft_dice`` term.
    :param smooth: smoothing constant passed to ``soft_dice``.
    :param min_percentage_per_class: stored on the instance; not read by any method of this class.
    """

    def __init__(self, *, lambda_bce: float = 1, lambda_soft_dice: float = 1,
                 smooth: float = 1e-5, min_percentage_per_class: float | None = None) -> None:
        # Binary setting: a single class, background always included.
        super().__init__(1, True)
        self.lambda_bce: float = lambda_bce
        self.lambda_soft_dice: float = lambda_soft_dice
        self.smooth: float = smooth
        self.min_percentage_per_class: float | None = min_percentage_per_class

    def _forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
        """Compute the combined loss together with its two components.

        :param outputs: raw logits, same shape as ``labels``.
        :param labels: binary targets; cast to float.
        :return: ``(loss, {"soft dice": ..., "bce loss": ...})``.
        """
        labels = labels.float()
        # BCE from the raw logits: binary_cross_entropy_with_logits is numerically stable and
        # autocast-safe. binary_cross_entropy on sigmoid(outputs) clamps its log terms at -100
        # and raises inside autocast regions.
        bce = nn.functional.binary_cross_entropy_with_logits(outputs, labels)
        probs = outputs.sigmoid()
        dice = soft_dice(probs, labels, smooth=self.smooth)
        metrics = {"soft dice": dice.item(), "bce loss": bce.item()}
        c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - dice)
        return c, metrics

    def forward(self, outputs: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
        """Return ``(loss, metrics)``.

        In validation mode everything runs under ``torch.no_grad()``. The metrics then also
        carry ``"dice"``, computed by ``binary_dice`` on the boolean maps from
        ``convert_logits_to_ids`` and ``labels``.
        """
        if not self.validation_mode:
            return self._forward(outputs, labels)
        with torch.no_grad():
            c, metrics = self._forward(outputs, labels)
            predicted = convert_logits_to_ids(outputs).bool()
            metrics["dice"] = binary_dice(predicted, labels.bool()).item()
        return c, metrics