Source code for mipcandy.data.convertion

import torch

from mipcandy.common import Normalize


def convert_ids_to_logits(ids: torch.Tensor, num_classes: int, *, channel_dim: int = 1) -> torch.Tensor:
    """
    Expand a tensor of class ids into a one-hot float tensor along the channel dimension.

    :param ids: class ids (..., 1, ...); a non-floating-point tensor with values in [0, num_classes)
    :param num_classes: number of classes
    :param channel_dim: the index of the channel dimension
    :return: logits (..., num_classes, ...), float32 on the same device as `ids`, holding 1 at each class id
        and 0 elsewhere
    :raises TypeError: if `ids` is a floating point tensor or contains negative values
    :raises ValueError: if any id is greater than or equal to `num_classes`
    """
    if torch.is_floating_point(ids) or (ids < 0).any():
        raise TypeError("Class ids must be positive integers")
    # `scatter_` requires an int64 index; convert once and reuse it for the range check below
    indices = ids.long()
    # Reject ids that do not fit into `num_classes` channels up front; otherwise `scatter_` fails with a generic
    # out-of-bounds index error that does not mention class ids at all
    if (indices >= num_classes).any():
        raise ValueError(f"Class ids must be smaller than the number of classes ({num_classes})")
    shape = list(ids.shape)
    shape[channel_dim] = num_classes
    logits = torch.zeros(shape, device=ids.device, dtype=torch.float32)
    logits.scatter_(channel_dim, indices, 1)
    return logits
def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:
    """
    Collapse the channel dimension of a logits tensor into class ids.

    :param logits: logits (..., num_classes, ...)
    :param channel_dim: the index of the channel dimension
    :return: class ids (..., 1, ...); with two or more channels this is the argmax over `channel_dim` (kept as a
        size-1 dimension), with fewer than two channels the values themselves are rounded and cast with `.int()`
    """
    # Multi-channel input: the id is the index of the largest value along the channel dimension
    if logits.shape[channel_dim] >= 2:
        return logits.argmax(channel_dim, keepdim=True)
    # Fewer than two channels: there is nothing to take an argmax over, so round the values to integers instead
    rounded = logits.round()
    return rounded.int()
def auto_convert(image: torch.Tensor) -> torch.Tensor:
    """
    Convert an image tensor to integer values.

    :param image: the image tensor to convert
    :return: the converted image, cast with `.int()`; if every value of `image` lies within [0, 1] the image is
        multiplied by 255, otherwise it is passed through `Normalize(domain=(0, 255))`
    """
    lowest = image.min()
    highest = image.max()
    if 0 <= lowest <= highest <= 1:
        # Values already lie within [0, 1]: a plain rescale is enough
        scaled = image * 255
    else:
        # NOTE(review): `Normalize` comes from `mipcandy.common`; presumably it maps the image into the given
        # domain -- confirm against its implementation
        scaled = Normalize(domain=(0, 255))(image)
    return scaled.int()