import torch
from mipcandy.types import Device, Reduction
[docs]
def _args_check(outputs: torch.Tensor, labels: torch.Tensor, *, dtype: torch.dtype | None = None,
device: Device | None = None) -> tuple[torch.dtype, Device]:
if outputs.shape != labels.shape:
raise ValueError(f"Outputs ({outputs.shape}) and labels ({labels.shape}) must have the same shape")
if (outputs_dtype := outputs.dtype) != labels.dtype or dtype and outputs_dtype != dtype:
raise TypeError(f"Outputs({outputs_dtype}) and labels ({labels.dtype}) must both be {dtype}")
if (outputs_device := outputs.device) != labels.device:
raise RuntimeError(f"Outputs ({outputs.device}) and labels ({labels.device}) must be on the same device")
if device and outputs_device != device:
raise RuntimeError(f"Tensors are expected to be on {device}, but instead they are on {outputs.device}")
return outputs_dtype, outputs_device
[docs]
def do_reduction(x: torch.Tensor, method: Reduction) -> torch.Tensor:
match method:
case "mean":
return x.mean()
case "median":
return x.median()
case "sum":
return x.sum()
case "none":
return x
[docs]
def binary_dice(outputs: torch.Tensor, labels: torch.Tensor, *, if_empty: float = 1,
reduction: Reduction = "mean") -> torch.Tensor:
"""
:param outputs: boolean class ids (B, 1, ...)
:param labels: boolean class ids (B, 1, ...)
:param if_empty: the value to return if both outputs and labels are empty
:param reduction: the reduction method to apply to the dice score
"""
_args_check(outputs, labels, dtype=torch.bool)
axes = tuple(range(2, outputs.ndim))
volume_sum = outputs.sum(axes) + labels.sum(axes)
if volume_sum == 0:
return torch.tensor(if_empty, dtype=torch.float)
return do_reduction(2 * (outputs & labels).sum(axes) / volume_sum, reduction)
[docs]
def dice_similarity_coefficient(outputs: torch.Tensor, labels: torch.Tensor, *, if_empty: float = 1,
reduction: Reduction = "mean") -> torch.Tensor:
"""
:param outputs: one-hot (B, N, ...)
:param labels: one-hot (B, N, ...)
:param if_empty: the value to return if both outputs and labels are empty
:param reduction: the reduction method to apply to the dice score
"""
_args_check(outputs, labels, dtype=torch.float)
axes = tuple(range(2, outputs.ndim))
tp = (outputs * labels).sum(axes)
fp = (outputs * (1 - labels)).sum(axes)
fn = ((1 - outputs) * labels).sum(axes)
volume_sum = 2 * tp + fp + fn
if (volume_sum == 0).any():
return torch.tensor(if_empty, dtype=torch.float)
return do_reduction(2 * tp / volume_sum, reduction)
[docs]
def soft_dice(outputs: torch.Tensor, labels: torch.Tensor, *, smooth: float = 1, batch_dice: bool = True,
reduction: Reduction = "mean") -> torch.Tensor:
"""
:param outputs: logits (B, C, ...)
:param labels: logits (B, C, ...)
:param smooth: the smoothness term to avoid division by zero
:param batch_dice: whether to compute dice score for each batch separately
:param reduction: the reduction method to apply to the dice score
"""
_args_check(outputs, labels, dtype=torch.float)
axes = tuple(range(2, outputs.ndim))
if batch_dice:
axes = (0,) + axes
label_sum = labels.sum(axes)
intersection = (outputs * labels).sum(axes)
output_sum = outputs.sum(axes)
return do_reduction((2 * intersection + smooth) / (label_sum + output_sum + smooth), reduction)