from abc import ABCMeta, abstractmethod
from os import PathLike
from typing import Any, Generator, Self, override
import torch
from safetensors.torch import save_model, load_model
from torch import nn
from mipcandy.types import Device, AmbiguousShape
[docs]
def batch_int_multiply(f: float, *n: int) -> Generator[int, None, None]:
    """
    Lazily multiply each integer in ``n`` by ``f``, requiring every product to be a whole number.

    :param f: the scale factor.
    :param n: the integers to scale.
    :return: a generator yielding ``int(i * f)`` for each ``i`` in ``n``, in order.
    :raises ValueError: during iteration, as soon as some ``i * f`` is not a whole number.
    """
    for i in n:
        r = i * f
        if not r.is_integer():
            raise ValueError(f"Inequivalent conversion: {i} * {f} = {r} is not an integer")
        yield int(r)
[docs]
def batch_int_divide(f: float, *n: int) -> Generator[int, None, None]:
    """
    Lazily divide each integer in ``n`` by ``f``, requiring every quotient to be a whole number.

    Divides directly rather than multiplying by ``1 / f``: the reciprocal adds a second rounding
    step, e.g. ``49 * (1 / 49) == 0.9999999999999999``, which would wrongly reject exact quotients.

    :param f: the divisor.
    :param n: the integers to divide.
    :return: a generator yielding ``int(i / f)`` for each ``i`` in ``n``, in order.
    :raises ZeroDivisionError: immediately (not lazily) when ``f`` is zero.
    :raises ValueError: during iteration, as soon as some ``i / f`` is not a whole number.
    """
    if f == 0:
        # Checked eagerly so a zero divisor fails at call time, even when ``n`` is empty.
        raise ZeroDivisionError("division by zero")

    def _divide() -> Generator[int, None, None]:
        for i in n:
            r = i / f
            if not r.is_integer():
                raise ValueError(f"Inequivalent conversion: {i} / {f} = {r} is not an integer")
            yield int(r)

    return _divide()
[docs]
class LayerT(object):
def __init__(self, m: type[nn.Module], **kwargs) -> None:
self.m: type[nn.Module] = m
self.kwargs: dict[str, Any] = kwargs
[docs]
def update(self, *, must_exist: bool = True, inplace: bool = False, **kwargs) -> Self:
if not inplace:
return self.copy().update(must_exist=must_exist, inplace=True, **kwargs)
for k, v in kwargs.items():
if not must_exist or k in self.kwargs:
self.kwargs[k] = v
return self
[docs]
def assemble(self, *args, **kwargs) -> nn.Module:
self_kwargs = self.kwargs.copy()
for k, v in self_kwargs.items():
if isinstance(v, str) and v in kwargs:
self_kwargs[k] = kwargs.pop(v)
return self.m(*args, **self_kwargs, **kwargs)
[docs]
def copy(self) -> Self:
return self.__class__(self.m, **self.kwargs)
[docs]
class HasDevice:
    """
    Mixin that stores the device an object should place its tensors and modules on.
    """

    def __init__(self, device: Device) -> None:
        self._device: Device = device

    def device(self, *, device: Device | None = None) -> None | Device:
        """
        Get or set the stored device.

        Called without ``device``, return the current device. Called with ``device=...``,
        replace the stored device and return ``None``.
        """
        if device is not None:
            self._device = device
            return None
        return self._device
[docs]
def auto_device() -> Device:
    """
    Pick a default device string: a CUDA device if CUDA is available, else ``"mps"``, else ``"cpu"``.

    Among CUDA devices, the index with the largest ``memory_reserved - memory_allocated`` wins;
    ties resolve to the lowest index.
    """
    if not torch.cuda.is_available():
        return "mps" if torch.mps.is_available() else "cpu"

    def cached_unused(index: int) -> int:
        # NOTE(review): reserved - allocated only counts memory cached by THIS process's allocator.
        # It ignores other processes and is 0 everywhere before this process allocates, in which
        # case "cuda:0" is returned. Confirm this is the intended selection criterion.
        return torch.cuda.memory_reserved(index) - torch.cuda.memory_allocated(index)

    best_index = max(range(torch.cuda.device_count()), key=cached_unused)
    return f"cuda:{best_index}"
[docs]
class WithPaddingModule(HasDevice):
    """
    Mixin that lazily builds an optional padding module and a matching restoring module.

    Subclasses override :meth:`build_padding_module` and :meth:`build_restoring_module`; both default
    to ``None``. The builders run once, on the first call to either getter, and any non-``None``
    result is moved to the stored device at that time.
    NOTE(review): changing the device afterwards does not move already-built modules.
    """

    def __init__(self, device: Device) -> None:
        super().__init__(device)
        self._padding_module: nn.Module | None = None
        self._restoring_module: nn.Module | None = None
        # Separate flag so a legitimately-None build result is not rebuilt on every access.
        self._padding_module_built: bool = False

    def build_padding_module(self) -> nn.Module | None:
        """Override to return the padding module, or ``None`` (the default) for no padding."""
        return None

    def build_restoring_module(self, padding_module: nn.Module | None) -> nn.Module | None:
        """Override to return the module that restores the output of ``padding_module``, or ``None`` (the default)."""
        return None

    def _lazy_load_padding_module(self) -> None:
        """Run both builders once and move each non-``None`` result to the stored device."""
        if self._padding_module_built:
            return
        self._padding_module = self.build_padding_module()
        # Compare with None instead of relying on truthiness: container modules (nn.Sequential,
        # nn.ModuleList, ...) define __len__, so an empty one is falsy.
        if self._padding_module is not None:
            self._padding_module = self._padding_module.to(self._device)
        # The restoring builder receives the already device-placed padding module (possibly None).
        self._restoring_module = self.build_restoring_module(self._padding_module)
        if self._restoring_module is not None:
            self._restoring_module = self._restoring_module.to(self._device)
        self._padding_module_built = True

    def get_padding_module(self) -> nn.Module | None:
        """Return the padding module, building it on first access."""
        self._lazy_load_padding_module()
        return self._padding_module

    def get_restoring_module(self) -> nn.Module | None:
        """Return the restoring module, building it on first access."""
        self._lazy_load_padding_module()
        return self._restoring_module
[docs]
class WithCheckpoint(object, metaclass=ABCMeta):
[docs]
@abstractmethod
def load_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> nn.Module:
raise NotImplementedError
[docs]
@abstractmethod
def save_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> None:
raise NotImplementedError
[docs]
class WithNetwork(WithCheckpoint, HasDevice, metaclass=ABCMeta):
    """
    Base for objects that build a network, place it on their device, and save/load its weights
    with safetensors. Subclasses implement :meth:`build_network`.
    """

    def __init__(self, device: Device) -> None:
        super().__init__(device)

    @override
    def load_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> nn.Module:
        """Load safetensors weights from ``path`` into ``model`` and return that same ``model`` object."""
        # `load_model` here is the module-level safetensors function, not the method of the same name.
        load_model(model, path)
        return model

    @override
    def save_checkpoint(self, model: nn.Module, path: str | PathLike[str]) -> None:
        """
        Save ``model``'s weights to ``path`` with safetensors.

        If ``model`` has an ``_orig_mod`` attribute (as the wrapper produced by ``torch.compile`` does),
        that inner module is saved instead.
        """
        # `save_model` here is the module-level safetensors function, not the method of the same name.
        unwrapped = getattr(model, "_orig_mod", model)
        save_model(unwrapped, path)

    @abstractmethod
    def build_network(self, example_shape: AmbiguousShape) -> nn.Module:
        """Construct and return a new network for inputs described by ``example_shape``."""
        raise NotImplementedError

    @staticmethod
    def compile_model(model: nn.Module) -> nn.Module:
        """Return ``torch.compile(model)`` with default settings."""
        return torch.compile(model)

    def build_network_from_checkpoint(self, example_shape: AmbiguousShape, path: str | PathLike[str]) -> nn.Module:
        """
        Internally exposed hook meant for overriding. Callers should use `load_model()` instead.
        """
        network = self.build_network(example_shape)
        return self.load_checkpoint(network, path)

    def load_model(self, example_shape: AmbiguousShape, compile_model: bool, *,
                   path: str | PathLike[str] | None = None) -> nn.Module:
        """
        Build the network (from the checkpoint at ``path`` when ``path`` is truthy), move it to the
        stored device, and optionally compile it.
        """
        if path:
            network = self.build_network_from_checkpoint(example_shape, path)
        else:
            network = self.build_network(example_shape)
        network = network.to(self._device)
        if not compile_model:
            return network
        return self.compile_model(network)

    def save_model(self, model: nn.Module, path: str | PathLike[str]) -> None:
        """Persist ``model`` by delegating to :meth:`save_checkpoint`."""
        self.save_checkpoint(model, path)