Source code for mipcandy.data.geometric
from typing import Literal, Sequence
import torch
[docs]
def ensure_num_dimensions(x: torch.Tensor, num_dimensions: int, *, append_before: bool = True) -> torch.Tensor:
    """Reshape ``x`` so that it has exactly ``num_dimensions`` dimensions.

    Missing dimensions are filled with size-1 axes; surplus dimensions are
    dropped. Both happen at the front of the shape when ``append_before`` is
    true and at the back otherwise. Dropping goes through ``reshape``, so it
    only succeeds when the kept shape still holds every element -- in
    practice the dropped axes must have size 1.

    :param x: tensor to reshape
    :param num_dimensions: target number of dimensions
    :param append_before: operate on the leading axes (``True``) or on the
        trailing axes (``False``)
    :return: ``x`` itself when it already has the requested rank, otherwise
        the result of ``x.reshape`` with the adjusted shape
    """
    d = num_dimensions - x.ndim
    if d == 0:
        return x
    shape = tuple(x.shape)
    if d > 0:
        padding = (1,) * d
        return x.reshape(padding + shape if append_before else shape + padding)
    # d < 0: there are -d surplus axes. Slice by explicit counts rather than
    # by ``-num_dimensions`` so that the trailing case keeps the first
    # ``num_dimensions`` axes and a target rank of 0 yields an empty shape.
    kept = shape[-d:] if append_before else shape[:num_dimensions]
    return x.reshape(kept)
[docs]
def orthographic_views(x: torch.Tensor, reduction: Literal["mean", "sum"] = "mean") -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collapse each of the last three axes of ``x`` in turn.

    :param x: tensor with at least three dimensions; the callers' naming
        suggests the last three are depth, height and width -- TODO confirm
    :param reduction: ``"mean"`` to average along the collapsed axis,
        ``"sum"`` to add along it
    :return: three tensors, obtained by reducing ``dim=-3``, ``dim=-2`` and
        ``dim=-1`` respectively
    :raises ValueError: if ``reduction`` is neither ``"mean"`` nor ``"sum"``
    """
    if reduction not in ("mean", "sum"):
        # Without this guard an unknown value would silently return ``None``.
        raise ValueError(f"Unsupported reduction {reduction!r}; expected 'mean' or 'sum'")
    reduce = x.mean if reduction == "mean" else x.sum
    return reduce(dim=-3), reduce(dim=-2), reduce(dim=-1)
[docs]
def aggregate_orthographic_views(d: torch.Tensor, h: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Combine three views into one tensor by broadcasted multiplication.

    A size-1 axis is inserted into each view -- at position ``-3`` for ``d``,
    ``-2`` for ``h`` and ``-1`` for ``w`` -- and the three expanded tensors
    are multiplied element-wise, relying on broadcasting to fill the
    inserted axes.

    :param d: view that gets a new axis at position ``-3``
    :param h: view that gets a new axis at position ``-2``
    :param w: view that gets a new axis at position ``-1``
    :return: the broadcasted product of the three expanded views
    """
    depth_view = torch.unsqueeze(d, -3)
    height_view = torch.unsqueeze(h, -2)
    width_view = torch.unsqueeze(w, -1)
    product = depth_view * height_view
    return product * width_view
[docs]
def crop(t: torch.Tensor, bbox: Sequence[int]) -> torch.Tensor:
    """Slice ``t`` along every axis after the first two.

    ``bbox`` is read as consecutive ``(start, stop)`` pairs, one pair per
    axis starting at axis 2; the first two axes are always kept whole
    (presumably batch and channel -- TODO confirm). Four values crop two
    axes, six values crop three, and any other even length works the same
    way.

    :param t: tensor to crop; needs at least ``2 + len(bbox) // 2`` axes
    :param bbox: flat sequence ``(start_0, stop_0, start_1, stop_1, ...)``
    :return: the sliced tensor
    :raises ValueError: if ``bbox`` has an odd number of entries
    """
    if len(bbox) % 2:
        raise ValueError(f"bbox must consist of (start, stop) pairs, got {len(bbox)} values")
    # One slice per (start, stop) pair, applied after the two untouched axes.
    spatial = tuple(slice(bbox[i], bbox[i + 1]) for i in range(0, len(bbox), 2))
    return t[(slice(None), slice(None)) + spatial]