Source code for mipcandy.frontend.notion_fe

from datetime import datetime
from typing import override, Literal
from warnings import warn

from requests import get, post, patch, Response

from mipcandy.frontend.prototype import Frontend
from mipcandy.types import Settings


class NotionFrontend(Frontend):
    """
    Frontend that mirrors the state of an experiment into a page of a Notion database.

    A page is located through its "Experiment ID" title property. The page is created (or reused) when the
    experiment is created and patched on every later event. All requests go to the Notion REST API pinned to
    ``Notion-Version: 2022-06-28``.
    """

    def __init__(self, secrets: Settings) -> None:
        """
        :param secrets: settings from which ``notion_api_key`` and ``notion_database_id`` are read with
            ``require_nonempty_secret``
        """
        super().__init__(secrets)
        self._api_key: str = self.require_nonempty_secret("notion_api_key", required_type=str)
        self._database_id: str = self.require_nonempty_secret("notion_database_id", required_type=str)
        self._headers: dict[str, str] = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28"
        }
        # Denominators of the "Progress" and "Early Stop" ratios; overwritten in `on_experiment_created()`.
        self._num_epochs: int = 1
        self._early_stop_tolerance: int = -1
        # Set by `new_experiment()`; an empty page ID means no page has been created or selected yet.
        self._start_time: str = ""
        self._page_id: str = ""

    @staticmethod
    def _current_timestamp() -> str:
        """
        :return: the current local time including its UTC offset, formatted as sent in the "Time" property
        """
        return datetime.now().astimezone().strftime("%Y-%m-%dT%H:%M:%S.000%z")

    def retrieve_database(self) -> Response:
        """
        :return: the raw response of retrieving the configured database
        """
        return get(f"https://api.notion.com/v1/databases/{self._database_id}", headers=self._headers)

    def query_database(self, *, experiment_id: str | None = None) -> Response:
        """
        Query the pages of the configured database.

        :param experiment_id: when given (and non-empty), only pages whose "Experiment ID" title equals it are
            requested; otherwise the query is sent without a filter
        :return: the raw response of the query
        """
        json = {"filter": {"property": "Experiment ID", "title": {"equals": experiment_id}}} if experiment_id else None
        return post(f"https://api.notion.com/v1/databases/{self._database_id}/query", json=json,
                    headers=self._headers)

    def select_experiment(self, experiment_id: str) -> str:
        """
        Look up the page that belongs to an experiment.

        :param experiment_id: the experiment ID to search for
        :return: the ID of the only matching page, or an empty string if no page matches
        :raises RuntimeError: if the query does not return status 200 or more than one page matches
        """
        res = self.query_database(experiment_id=experiment_id)
        if res.status_code != 200:
            raise RuntimeError(f"Failed to query database: {res.json()}")
        pages = res.json()["results"]
        if len(pages) > 1:
            raise RuntimeError(f"Found multiple experiments with the same ID {experiment_id}")
        return pages[0]["id"] if pages else ""

    def new_experiment(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
                       num_params: float) -> Response:
        """
        Create the page of an experiment, or reset the existing page that carries the same experiment ID.

        The start time is recorded and the page ID is remembered for later calls of `update_experiment()`.

        :param experiment_id: written to the "Experiment ID" title
        :param trainer: written to the "Trainer" select property
        :param model: written to the "Model" select property
        :param note: written to the "Note" rich text property
        :param num_macs: written to "MACs (G)", rounded to one decimal place
        :param num_params: written to "Params (M)", rounded to one decimal place
        :return: the raw response of the patch (existing page) or create (new page) request
        :raises RuntimeError: propagated from `select_experiment()`
        """
        self._start_time = self._current_timestamp()
        properties = {
            "Experiment ID": {"title": [{"text": {"content": experiment_id}}]},
            "Status": {"status": {"name": "In Progress"}},
            "Progress": {"number": 0},
            "Early Stop": {"number": 1},
            "Trainer": {"select": {"name": trainer}},
            "Model": {"select": {"name": model}},
            "Time": {"date": {"start": self._start_time}},
            "Note": {"rich_text": [{"text": {"content": note}}]},
            "MACs (G)": {"number": round(num_macs, 1)},
            "Params (M)": {"number": round(num_params, 1)},
            "Epoch": {"number": 0},
            "Score": {"number": 0},
        }
        page_id = self.select_experiment(experiment_id)
        if page_id:
            self._page_id = page_id
            return patch(f"https://api.notion.com/v1/pages/{page_id}", json={"properties": properties},
                         headers=self._headers)
        res = post("https://api.notion.com/v1/pages", json={
            "parent": {"database_id": self._database_id},
            "icon": {"external": {"url": "https://www.notion.so/icons/science_gray.svg"}},
            "properties": properties
        }, headers=self._headers)
        # A failed request carries an error body without an "id"; hand the response back untouched so that the
        # caller can report the actual API error instead of hitting a `KeyError` here.
        if res.status_code == 200:
            self._page_id = res.json()["id"]
        return res

    def update_experiment(self, experiment_id: str, status: Literal["In Progress", "Completed", "Interrupted"], *,
                          epoch: int | None = None, score: float | None = None,
                          early_stop_tolerance: int | None = None, observation: str | None = None) -> Response:
        """
        Patch the page that was created or selected by `new_experiment()`.

        :param experiment_id: only used in the error message
        :param status: written to the "Status" property
        :param epoch: written to "Epoch"; also sets "Progress" to ``epoch / num_epochs``
        :param score: written to "Score", rounded to four decimal places
        :param early_stop_tolerance: remaining tolerance; "Early Stop" is set to its non-negative part divided by
            the total tolerance
        :param observation: written to the "Observation" rich text property
        :return: the raw response of the patch request
        :raises RuntimeError: if no page has been created or selected yet
        """
        if not self._page_id:
            raise RuntimeError(f"Experiment {experiment_id} has not been created")
        properties: dict[str, dict] = {"Status": {"status": {"name": status}}}
        if epoch is not None:
            # Without a positive epoch count the ratio is undefined; keep the previous "Progress" value.
            if self._num_epochs > 0:
                properties["Progress"] = {"number": epoch / self._num_epochs}
            properties["Epoch"] = {"number": epoch}
        # A non-positive total tolerance (the default is -1) would divide by zero or yield a negative ratio, so
        # "Early Stop" is left at its previous value in that case.
        if early_stop_tolerance is not None and self._early_stop_tolerance > 0:
            properties["Early Stop"] = {"number": max(early_stop_tolerance, 0) / self._early_stop_tolerance}
        if score is not None:
            properties["Score"] = {"number": round(score, 4)}
        if observation is not None:
            properties["Observation"] = {"rich_text": [{"text": {"content": observation}}]}
        if status == "Completed":
            # Completion always shows full progress and closes the time range opened by `new_experiment()`.
            properties["Progress"] = {"number": 1}
            properties["Time"] = {"date": {"start": self._start_time, "end": self._current_timestamp()}}
        return patch(f"https://api.notion.com/v1/pages/{self._page_id}", json={"properties": properties},
                     headers=self._headers)

    @override
    def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
                              num_params: float, num_epochs: int, early_stop_tolerance: int) -> None:
        """
        Remember the epoch count and total early stop tolerance, then create or reset the experiment page.

        ``num_macs`` is scaled by 1e-9 and ``num_params`` by 1e-6 before being sent.

        :raises RuntimeError: if the page request does not return status 200
        """
        self._num_epochs = num_epochs
        self._early_stop_tolerance = early_stop_tolerance
        res = self.new_experiment(experiment_id, trainer, model, note, num_macs * 1e-9, num_params * 1e-6)
        if res.status_code != 200:
            raise RuntimeError(f"Failed to create experiment: {res.json()}")

    @override
    def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]],
                              early_stop_tolerance: int) -> None:
        """
        Push the current epoch, the best "val score" so far and the remaining early stop tolerance.

        The update is best-effort: a missing page is reported as a warning rather than raised, and the score is
        omitted when ``metrics`` has no "val score" values yet.
        """
        scores = metrics.get("val score")
        best_score = max(scores) if scores else None
        try:
            self.update_experiment(experiment_id, "In Progress", epoch=epoch, score=best_score,
                                   early_stop_tolerance=early_stop_tolerance)
        except RuntimeError as e:
            # Keep the caller running, but do not hide the failure.
            warn(f"Skipped Notion update of experiment {experiment_id}: {e}", RuntimeWarning, stacklevel=2)

    @override
    def on_experiment_completed(self, experiment_id: str) -> None:
        """
        Mark the experiment page as completed.

        :raises RuntimeError: if no page exists or the patch request does not return status 200
        """
        res = self.update_experiment(experiment_id, "Completed")
        if res.status_code != 200:
            raise RuntimeError(f"Failed to update experiment: {res.json()}")

    @override
    def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None:
        """
        Mark the experiment page as interrupted and store ``repr(error)`` as its observation.

        :raises RuntimeError: if no page exists or the patch request does not return status 200
        """
        res = self.update_experiment(experiment_id, "Interrupted", observation=repr(error))
        if res.status_code != 200:
            raise RuntimeError(f"Failed to update experiment: {res.json()}")