Source code for mipcandy.sanity_check

from dataclasses import dataclass
from io import StringIO
from typing import Sequence, override

import torch
from ptflops import get_model_complexity_info
from torch import nn

from mipcandy.layer import auto_device
from mipcandy.types import Device


[docs] def num_trainable_params(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters() if p.requires_grad)
[docs] def model_complexity_info(model: nn.Module, example_shape: Sequence[int]) -> tuple[float | None, float | None, str]: layer_stats = StringIO() macs, params = get_model_complexity_info(model, tuple(example_shape), ost=layer_stats, as_strings=False) return macs, params, layer_stats.getvalue()
[docs] @dataclass class SanityCheckResult(object): num_macs: float num_params: float layer_stats: str output: torch.Tensor
[docs] @override def __str__(self) -> str: return f"MACs: {self.num_macs * 1e-9:.1f} G / Params: {self.num_params * 1e-6:.1f} M"
[docs] def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult: if device is None: device = auto_device() with torch.no_grad(): num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) if num_macs is None or num_params is None: raise RuntimeError("Failed to validate model") outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) return SanityCheckResult(num_macs, num_params, layer_stats, ( outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0))