# Source code for robometric_frame.efficiency.base

"""Base class for efficiency metrics with start/stop interface.

This module provides an abstract base class for efficiency metrics that track
resource usage over time intervals using a start/stop pattern.
"""

from abc import abstractmethod
from typing import Any, Optional

import torch
from torch import Tensor
from torchmetrics import Metric


class EfficiencyMetric(Metric):
    """Abstract base class for efficiency metrics with start/stop interface.

    This base class provides common functionality for metrics that track
    resource usage (time, memory, etc.) over intervals. It implements:

    - start()/stop() interface for interval-based measurements
    - Percentile computation support
    - Common state management

    Subclasses must implement:

    - _on_start(): Called when measurement starts
    - _on_stop(): Called when measurement stops, should return measured value
    - _get_measurement_unit(): Returns the unit suffix for computed statistics

    Args:
        percentiles: List of percentile values to compute (e.g., [0.5, 0.95, 0.99]).
            Default: [0.5, 0.95, 0.99] for median, 95th, and 99th percentiles.
        **kwargs: Additional keyword arguments passed to the base Metric class.
    """

    # Metric attributes (torchmetrics configuration flags)
    full_state_update: bool = False
    is_differentiable: bool = False
    higher_is_better: bool = False

    # State attributes (dynamically registered via add_state in __init__)
    total_value: Tensor
    min_value: Tensor
    max_value: Tensor
    count: Tensor
    values: list[Tensor]
[docs] def __init__( self, percentiles: Optional[list[float]] = None, **kwargs: Any, ) -> None: """Initialize the efficiency metric.""" super().__init__(**kwargs) # Store percentiles to compute if percentiles is None: percentiles = [0.5, 0.95, 0.99] # median, 95th, 99th percentiles self.percentiles = percentiles # Validate percentiles for p in self.percentiles: if not 0 <= p <= 1: raise ValueError(f"Percentiles must be between 0 and 1, got {p}") # Add metric states for distributed computation self.add_state("total_value", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("min_value", default=torch.tensor(float("inf")), dist_reduce_fx="min") self.add_state("max_value", default=torch.tensor(0.0), dist_reduce_fx="max") self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") # Store all values for percentile calculation self.add_state("values", default=[], dist_reduce_fx="cat") # Internal tracking state self._is_tracking = False
    @abstractmethod
    def _on_start(self) -> None:
        """Called when measurement starts.

        Subclasses should implement any setup needed when tracking begins
        (e.g., recording start time, initializing memory tracking).
        """

    @abstractmethod
    def _on_stop(self) -> float:
        """Called when measurement stops.

        Subclasses should implement the actual measurement logic and return
        the measured value.

        Returns:
            The measured value (e.g., elapsed time, memory usage).
        """

    @abstractmethod
    def _get_measurement_unit(self) -> str:
        """Return the unit suffix for computed statistics.

        Returns:
            Unit suffix string (e.g., '' for seconds, '_mb' for megabytes).
        """
[docs] def start(self) -> None: """Start a measurement interval. This method begins tracking. Call stop() to end the interval and record the measurement. Raises: RuntimeError: If start() is called while already tracking. Example: >>> metric.start() >>> # ... perform operations ... >>> metric.stop() """ if self._is_tracking: raise RuntimeError("Already tracking. Call stop() before starting a new measurement.") self._is_tracking = True self._on_start()
[docs] def stop(self) -> None: """Stop tracking and record the measurement. This method ends the tracking interval and records the measured value. Raises: RuntimeError: If stop() is called without a preceding start(). Example: >>> metric.start() >>> # ... perform operations ... >>> metric.stop() """ if not self._is_tracking: raise RuntimeError("Not currently tracking. Call start() before stop().") self._is_tracking = False value = self._on_stop() # Update metric with measured value self.update(torch.tensor(value, device=self.device))
[docs] def update(self, value: Tensor) -> None: """Update metric state with measurements. Args: value: Measurement values. Can be: - Scalar tensor: Single measurement - 1D tensor: Batch of measurements All values must be non-negative. Raises: ValueError: If value contains negative values. Example: >>> metric.update(torch.tensor(0.1)) # Single measurement >>> metric.update(torch.tensor([0.11, 0.12])) # Batch """ # Flatten to 1D if needed value = value.flatten().float() # Validate inputs if (value < 0).any(): raise ValueError("Values must be non-negative") # Update states self.total_value += value.sum() self.min_value = torch.min(self.min_value, value.min()) self.max_value = torch.max(self.max_value, value.max()) self.count += value.numel() # Store values for percentile calculation self.values.append(value)
[docs] def compute(self) -> dict[str, Tensor]: """Compute statistics including percentiles. Returns: Dictionary containing: - 'mean{unit}': Mean value - 'min{unit}': Minimum value - 'max{unit}': Maximum value - 'total{unit}': Total accumulated value - 'count': Number of measurements - 'p{X}{unit}': Xth percentile (e.g., 'p50' for median) Raises: RuntimeError: If no measurements have been recorded. """ if self.count == 0: raise RuntimeError( "Cannot compute: no measurements have been recorded. " "Call update() or start()/stop() before compute()." ) unit = self._get_measurement_unit() # Base statistics stats = { f"mean{unit}": self.total_value / self.count, f"min{unit}": self.min_value, f"max{unit}": self.max_value, f"total{unit}": self.total_value, "count": self.count.float(), } # Compute percentiles if values are stored if self.values: all_values = torch.cat(self.values) for p in self.percentiles: # Format percentile key (e.g., 0.95 -> 'p95', 0.5 -> 'p50') key = f"p{int(p * 100)}{unit}" stats[key] = torch.quantile(all_values, p) return stats
[docs] def reset(self) -> None: """Reset the metric state. This method resets all metric states to their default values and clears any internal tracking state. """ super().reset() self._is_tracking = False