Source code for mipcandy.data.geometric

from typing import Literal, Sequence

import torch


def ensure_num_dimensions(x: torch.Tensor, num_dimensions: int, *, append_before: bool = True) -> torch.Tensor:
    """Reshape ``x`` so that it has exactly ``num_dimensions`` dimensions.

    Missing dimensions are inserted as size-1 axes. Surplus dimensions are dropped via ``reshape``,
    which only succeeds when the dropped axes carry no extra elements (in practice: size-1 axes).

    :param x: input tensor.
    :param num_dimensions: desired number of dimensions; must be non-negative.
    :param append_before: if True, axes are added/removed on the leading side; otherwise on the
        trailing side.
    :return: ``x`` itself when it already has ``num_dimensions`` dimensions, otherwise the result of
        ``x.reshape(...)`` (a view when possible, else a copy).
    :raises ValueError: if ``num_dimensions`` is negative.
    :raises RuntimeError: raised by ``reshape`` when the surplus axes cannot be dropped.
    """
    if num_dimensions < 0:
        raise ValueError(f"num_dimensions must be non-negative, got {num_dimensions}")
    d = num_dimensions - x.ndim
    if d == 0:
        return x
    shape = tuple(x.shape)
    if append_before:
        # Pad with leading 1s, or keep only the trailing ``num_dimensions`` axes. ``shape[-d:]`` (with
        # -d == ndim - num_dimensions > 0) also handles num_dimensions == 0, unlike ``shape[-0:]``.
        new_shape = (1,) * d + shape if d > 0 else shape[-d:]
    else:
        # Pad with trailing 1s, or keep only the leading ``num_dimensions`` axes.
        new_shape = shape + (1,) * d if d > 0 else shape[:num_dimensions]
    return x.reshape(new_shape)
[docs] def orthographic_views(x: torch.Tensor, reduction: Literal["mean", "sum"] = "mean") -> tuple[ torch.Tensor, torch.Tensor, torch.Tensor]: match reduction: case "mean": return x.mean(dim=-3), x.mean(dim=-2), x.mean(dim=-1) case "sum": return x.sum(dim=-3), x.sum(dim=-2), x.sum(dim=-1)
def aggregate_orthographic_views(d: torch.Tensor, h: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Recombine three orthogonal projections into one tensor via a broadcast product.

    A singleton axis is inserted into each projection where its reduced axis used to be
    (``d`` at -3, ``h`` at -2, ``w`` at -1). The three results are then multiplied together,
    relying on broadcasting to restore the full trailing three dimensions.

    :param d: projection whose dimension -3 was reduced away.
    :param h: projection whose dimension -2 was reduced away.
    :param w: projection whose dimension -1 was reduced away.
    :return: ``d.unsqueeze(-3) * h.unsqueeze(-2) * w.unsqueeze(-1)``.
    """
    return torch.unsqueeze(d, -3) * torch.unsqueeze(h, -2) * torch.unsqueeze(w, -1)
def crop(t: torch.Tensor, bbox: Sequence[int]) -> torch.Tensor:
    """Crop the dimensions after the first two of ``t`` to a bounding box.

    ``bbox`` is a flat sequence of ``(start, stop)`` pairs, one pair per cropped dimension:
    ``(start0, stop0, start1, stop1, ...)``. Four values crop dims 2-3, six values crop dims 2-4,
    and any other even length crops that many dimensions. Bounds follow Python slice semantics
    (``stop`` is exclusive).

    :param t: tensor to crop; its first two dimensions are kept whole.
    :param bbox: flattened ``(start, stop)`` pairs.
    :return: the sliced tensor (a view of ``t``).
    :raises ValueError: if ``bbox`` is empty or has an odd number of values.
    """
    n = len(bbox)
    if n == 0 or n % 2:
        raise ValueError(f"bbox must contain a non-empty, even number of values, got {n}")
    spatial = tuple(slice(bbox[i], bbox[i + 1]) for i in range(0, n, 2))
    # The two leading dimensions (presumably batch and channel — confirm with callers) are untouched.
    return t[(slice(None), slice(None)) + spatial]