# Module: robometric_frame.efficiency.inference_latency

"""Inference Latency metric for robotics policy evaluation.

Inference Latency measures the time required to generate actions from visual
observations and language instructions. This metric is crucial for real-time
applications where responsive behavior is essential for effective human-robot
interaction.

Reference:
    A. Brohan et al., "RT-1: Robotics transformer for real-world control at
    scale," arXiv:2212.06817, 2022.
"""

import time
from typing import Any, Optional

import torch

from robometric_frame.efficiency.base import EfficiencyMetric


class InferenceLatency(EfficiencyMetric):
    r"""Compute Inference Latency for robotics policy evaluation.

    Inference Latency is calculated as:

        IL = t_infer,end - t_infer,start

    This metric tracks the time elapsed during model inference operations,
    which is critical for real-time robotics applications. It accumulates
    timing measurements across multiple inference calls and provides
    statistics including mean, minimum, maximum, total latency, and
    configurable percentiles.

    The metric is designed to be used in two ways:

    1. Manual timing: Call start() before inference and stop() after
    2. Direct update: Call update() with pre-measured latency values

    Args:
        percentiles: List of percentile values to compute
            (e.g., [0.5, 0.95, 0.99]). Default: [0.5, 0.95, 0.99] for
            median, 95th, and 99th percentiles.
        **kwargs: Additional keyword arguments passed to the base Metric
            class.

    Example:
        >>> from robometric_frame.efficiency import InferenceLatency
        >>> import torch
        >>> import time
        >>> metric = InferenceLatency()
        >>> # Manual timing
        >>> metric.start()
        >>> # ... model inference ...
        >>> time.sleep(0.1)  # Simulate inference
        >>> metric.stop()
        >>> result = metric.compute()
        >>> result['mean'] > 0
        tensor(True)

    Example (direct update):
        >>> # Direct update with measured latency
        >>> metric = InferenceLatency()
        >>> latencies = torch.tensor([0.1, 0.15, 0.12, 0.11])  # seconds
        >>> metric.update(latencies)
        >>> result = metric.compute()
        >>> result['mean'].item()
        0.12

    Example (batched):
        >>> # Multiple inference measurements
        >>> metric = InferenceLatency()
        >>> for _ in range(10):
        ...     metric.start()
        ...     time.sleep(0.01)  # Simulate inference
        ...     metric.stop()
        >>> result = metric.compute()
        >>> result['count']
        tensor(10)

    Example (distributed):
        >>> # In distributed training, metrics are automatically synced
        >>> metric = InferenceLatency()
        >>> # On GPU 0
        >>> metric.update(torch.tensor([0.1, 0.12]))
        >>> # On GPU 1
        >>> metric.update(torch.tensor([0.11, 0.13]))
        >>> # Final result aggregates across all GPUs
        >>> result = metric.compute()
        >>> result['mean'].item()
        0.115

    Example (custom percentiles):
        >>> # Track specific percentiles for robustness analysis
        >>> metric = InferenceLatency(percentiles=[0.5, 0.9, 0.95, 0.99])
        >>> latencies = torch.tensor([0.1, 0.12, 0.15, 0.11, 0.13, 0.2, 0.25, 0.3])
        >>> metric.update(latencies)
        >>> result = metric.compute()
        >>> result['p50']  # median
        tensor(0.1350)
        >>> result['p95']  # 95th percentile
        tensor(0.2875)
    """

    # Wall-clock timestamp captured by start(); None when no timing is in
    # progress.
    _start_time: Optional[float]

    def __init__(
        self,
        percentiles: Optional[list[float]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the InferenceLatency metric.

        Args:
            percentiles: Percentile values to compute; passed through to the
                base class. None selects the base class default.
            **kwargs: Additional keyword arguments for the base Metric class.
        """
        super().__init__(percentiles=percentiles, **kwargs)
        self._start_time = None

    def _on_start(self) -> None:
        """Record the start time for latency measurement."""
        # Drain any queued CUDA work first so asynchronously launched kernels
        # from before start() are not charged to this measurement.
        if torch.cuda.is_available() and self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        self._start_time = time.perf_counter()

    def _on_stop(self) -> float:
        """Calculate and return the elapsed time since start.

        Returns:
            Elapsed time in seconds.

        Raises:
            RuntimeError: If called without a preceding call to start().
        """
        # Fail loudly instead of letting the subtraction below raise an
        # opaque TypeError on a None start time.
        if self._start_time is None:
            raise RuntimeError(
                "stop() called before start(); no start time recorded."
            )
        # CUDA kernel launches are asynchronous; synchronize so the wall-clock
        # reading covers the full inference, not just the launch.
        if torch.cuda.is_available() and self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        latency = time.perf_counter() - self._start_time
        self._start_time = None
        return latency

    def _get_measurement_unit(self) -> str:
        """Return empty string as latency is measured in seconds (base unit)."""
        return ""

    def reset(self) -> None:
        """Reset the metric state, discarding any in-progress timing."""
        super().reset()
        self._start_time = None