# Source code for mipcandy.sanity_check
from dataclasses import dataclass
from io import StringIO
from typing import Sequence, override
import torch
from ptflops import get_model_complexity_info
from torch import nn
from mipcandy.layer import auto_device
from mipcandy.types import Device
[docs]
def num_trainable_params(model: nn.Module) -> int:
    """
    Count the scalar values held by the parameters of ``model`` that require gradients.

    :param model: the module whose parameters are inspected
    :return: the summed element count of every parameter with ``requires_grad`` set
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters do not contribute to the count.
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
[docs]
def model_complexity_info(model: nn.Module, example_shape: Sequence[int]) -> tuple[float | None, float | None, str]:
    """
    Profile ``model`` with ptflops and capture the text report it writes.

    :param model: the module to profile
    :param example_shape: input shape handed to ptflops as a tuple (presumably without the batch dimension --
        confirm against the ptflops documentation)
    :return: ``(macs, params, report)``, where the first two are whatever ``get_model_complexity_info`` returns
        with ``as_strings=False`` and ``report`` is everything ptflops wrote to its output stream
    """
    # ptflops writes its report to `ost`; an in-memory buffer lets the caller receive it as a string.
    report_buffer = StringIO()
    shape = tuple(example_shape)
    complexity = get_model_complexity_info(model, shape, ost=report_buffer, as_strings=False)
    macs, params = complexity
    report = report_buffer.getvalue()
    return macs, params, report
[docs]
@dataclass
class SanityCheckResult(object):
num_macs: float
num_params: float
layer_stats: str
output: torch.Tensor
[docs]
@override
def __str__(self) -> str:
return f"MACs: {self.num_macs * 1e-9:.1f} G / Params: {self.num_params * 1e-6:.1f} M"
[docs]
def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult:
    """
    Profile ``model`` and run a single forward pass on random input to confirm that it executes.

    Side effect: ``model`` is moved to ``device`` and switched to evaluation mode, and it is left in that state
    when this function returns.

    :param model: the module to check
    :param input_shape: shape of one example; a leading dimension of size 1 is prepended for the forward pass
    :param device: device used for the forward pass; ``auto_device()`` is consulted when ``None``
    :return: the MAC count, parameter count, per-layer report, and the model output with dimension 0 squeezed;
        when the model returns a tuple or list, only its first element is kept
    :raise RuntimeError: if the complexity analysis reports no MAC count or no parameter count
    """
    if device is None:
        device = auto_device()
    with torch.no_grad():
        num_macs, num_params, layer_stats = model_complexity_info(model, input_shape)
        if num_macs is None or num_params is None:
            raise RuntimeError("Failed to validate model")
        outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device))
        # A model may return several outputs as a sequence; keep the first. Lists are unwrapped as well as
        # tuples, because calling `.squeeze` on a list itself would raise AttributeError.
        first_output = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        return SanityCheckResult(num_macs, num_params, layer_stats, first_output.squeeze(0))