Source code for mipcandy.data.io

from gc import collect, get_objects
from math import floor
from os import PathLike

import SimpleITK as SpITK
import torch
from safetensors.torch import save_file, load_file

from mipcandy.data.convertion import auto_convert
from mipcandy.data.geometric import ensure_num_dimensions
from mipcandy.types import Device, AmbiguousShape


[docs] def fast_save(x: torch.Tensor, path: str | PathLike[str]) -> None: save_file({"payload": x if x.is_contiguous() else x.contiguous()}, path)
[docs] def fast_load(path: str | PathLike[str], *, device: Device = "cpu") -> torch.Tensor: return load_file(path, device)["payload"]
[docs] def resample_to_isotropic(image: SpITK.Image, *, target_iso: float | None = None, interpolator: int = SpITK.sitkBSpline) -> SpITK.Image: dim = image.GetDimension() old_spacing = image.GetSpacing() old_size = image.GetSize() origin = image.GetOrigin() direction = image.GetDirection() if target_iso is None: target_iso = min(old_spacing) new_spacing = (target_iso,) * dim new_size = tuple(max(1, floor(old_spacing[i] * (old_size[i] - 1) / new_spacing[i] + 1)) for i in range(dim)) return SpITK.Resample( image, new_size, SpITK.Transform(), interpolator, origin, new_spacing, direction, 0, image.GetPixelID() )
[docs] def load_image(path: str | PathLike[str], *, is_label: bool = False, align_spacing: bool = False, target_iso: float | None = None, device: Device = "cpu") -> torch.Tensor: file = SpITK.ReadImage(path) if align_spacing: file = resample_to_isotropic(file, target_iso=target_iso, interpolator=SpITK.sitkNearestNeighbor if is_label else SpITK.sitkBSpline) img = torch.tensor(SpITK.GetArrayFromImage(file), dtype=torch.long if is_label else torch.float, device=device) if path.endswith(".nii.gz") or path.endswith(".nii") or path.endswith(".mha"): img = ensure_num_dimensions(img, 4, append_before=False).permute(3, 0, 1, 2) return img.squeeze(1) if img.shape[1] == 1 else img if path.endswith(".png") or path.endswith(".jpg") or path.endswith(".jpeg"): return ensure_num_dimensions(img, 3, append_before=False).permute(2, 0, 1) raise NotImplementedError(f"Unsupported file type: {path}")
[docs] def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None: if path.endswith(".nii.gz") or path.endswith(".nii") or path.endswith(".mha"): image = ensure_num_dimensions(image, 4).permute(1, 2, 3, 0) return SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy(), isVector=False), path) if path.endswith(".png") or path.endswith(".jpg") or path.endswith(".jpeg"): image = auto_convert(ensure_num_dimensions(image, 3)).to(torch.uint8).permute(1, 2, 0) return SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy(), isVector=True), path) raise NotImplementedError(f"Unsupported file type: {path}")
[docs] def empty_cache(device: Device) -> None: match torch.device(device).type: case "cpu": collect() case "cuda": torch.cuda.empty_cache() case "mps": torch.mps.empty_cache()
[docs] def dump_allocated_tensors() -> tuple[float, list[tuple[ float, AmbiguousShape, torch.dtype, torch.device, bool, str]]]: """ :return: (total size in MB, [(size in MB, shape, dtype, device, requires_grad, grad_fn)]) """ tensors = [ (obj, obj.numel() * obj.element_size() / 1048576) for obj in get_objects() if isinstance(obj, torch.Tensor) ] tensors.sort(key=lambda t: t[1], reverse=True) return sum(t[1] for t in tensors), [ (sz, tuple(t.shape), t.dtype, t.device, t.requires_grad, str(t.grad_fn)) for t, sz in tensors ]