"""Action Accuracy metrics for robotics policy evaluation.
Action Accuracy measures the precision of predicted actions against ground truth
trajectories using Mean Squared Error (MSE) and its variations. This provides
direct assessment of model performance in offline evaluation scenarios.
References:
[1] M. Dobiš et al., "Evaluation criteria for trajectories of robotic arms,"
Robotics, vol. 11, p. 29, 2022.
[2] K. K. A. Farag et al., "Mobile robot obstacle avoidance based on neural
network with a standardization technique," J. Robot., vol. 2021, 2021.
"""
from typing import Any, Optional
import torch
from torch import Tensor
from torchmetrics import Metric
class ActionAccuracy(Metric):
    r"""Compute Action Accuracy metrics (MSE, AMSE, NAMSE) for robotics policy evaluation.

    Every call to :meth:`update` is treated as one trajectory. The metric
    accumulates the per-trajectory MSE and reports:

    - MSE: Mean Squared Error per trajectory
    - AMSE: Average MSE across multiple trajectories
    - NAMSE: Normalized AMSE (scaled by action variance)

    Formulas:
        MSE = (1/T) * sum_{t=1}^{T} \|a_t - â_t\|_2^2
        AMSE = (1/K) * sum_{k=1}^{K} MSE_k
        NAMSE = AMSE / σ²_action

    where:
        - a_t is the ground truth action at timestep t
        - â_t is the predicted action at timestep t
        - T is the number of timesteps in a trajectory
        - K is the number of trajectories
        - σ²_action is the variance of ground truth actions

    Note:
        The ``'mse'`` entry returned by :meth:`compute` is the accumulated MSE
        divided by the number of trajectories, i.e. it has the same value as
        ``'amse'``. The two only carry distinct meaning conceptually; after a
        single ``update`` both equal that trajectory's MSE.

        When the variance is computed from data it is a single scalar pooled
        over every element of every target tensor (targets are flattened),
        using ``Var(X) = E[X²] - E[X]²``.

    Args:
        normalize: Whether to compute NAMSE. If True, action variance is computed
            from the data unless ``action_variance`` is given. If False, only MSE
            and AMSE are computed. Default: False.
        action_variance: Pre-computed action variance for normalization. If provided,
            this value is used instead of computing from data. It is only consulted
            when ``normalize`` is True. Default: None.
        **kwargs: Additional keyword arguments passed to the base Metric class.

    Example:
        >>> from robometric_frame import ActionAccuracy
        >>> import torch
        >>> metric = ActionAccuracy()
        >>> # Single trajectory
        >>> predictions = torch.randn(10, 4)  # 10 timesteps, 4-dim actions
        >>> targets = torch.randn(10, 4)
        >>> metric.update(predictions, targets)
        >>> results = metric.compute()
        >>> print(f"MSE: {results['mse']:.4f}, AMSE: {results['amse']:.4f}")
        >>> # With normalization
        >>> metric = ActionAccuracy(normalize=True)
        >>> metric.update(predictions, targets)
        >>> results = metric.compute()
        >>> print(f"NAMSE: {results['namse']:.4f}")

    Example (multiple trajectories):
        >>> metric = ActionAccuracy()
        >>> # Trajectory 1
        >>> metric.update(torch.randn(10, 4), torch.randn(10, 4))
        >>> # Trajectory 2
        >>> metric.update(torch.randn(15, 4), torch.randn(15, 4))
        >>> results = metric.compute()
        >>> # AMSE is averaged across both trajectories
    """

    full_state_update: bool = False

    # Dynamically added by add_state() in __init__.
    # The three variance-related states only exist when the variance has to be
    # estimated from data (normalize=True and no action_variance supplied).
    total_mse: Tensor
    total_trajectories: Tensor
    total_squared_actions: Tensor
    total_actions: Tensor
    total_action_count: Tensor

    def __init__(
        self,
        normalize: bool = False,
        action_variance: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the ActionAccuracy metric.

        Args:
            normalize: Whether :meth:`compute` should also return ``'namse'``.
            action_variance: Optional fixed variance used for normalization
                instead of the variance estimated from the targets.
            **kwargs: Forwarded unchanged to the base ``Metric`` constructor.
        """
        super().__init__(**kwargs)
        self.normalize = normalize
        self.action_variance = action_variance

        # States for MSE and AMSE computation
        self.add_state("total_mse", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total_trajectories", default=torch.tensor(0.0), dist_reduce_fx="sum")

        # States for action variance computation (needed for NAMSE).
        # Skipped when a fixed variance was supplied, since nothing has to be
        # estimated from the data in that case.
        if normalize and action_variance is None:
            self.add_state("total_squared_actions", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total_actions", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total_action_count", default=torch.tensor(0.0), dist_reduce_fx="sum")

    @staticmethod
    def _to_floating(actions: Tensor, dtype: torch.dtype) -> Tensor:
        """Return ``actions`` as-is if floating/complex, otherwise cast to ``dtype``."""
        if actions.is_floating_point() or actions.is_complex():
            return actions
        return actions.to(dtype)

    def update(self, predictions: Tensor, targets: Tensor) -> None:  # pylint: disable=arguments-differ
        """Update metric state with predicted and target actions.

        The squared error is summed over the last dimension and then averaged
        over all remaining dimensions, and the call is counted as one trajectory.

        Args:
            predictions: Predicted actions of shape (T, D) where T is the number of
                timesteps and D is the action dimension.
            targets: Ground truth actions of shape (T, D).

        Raises:
            ValueError: If predictions and targets have different shapes or are empty.
        """
        if predictions.shape != targets.shape:
            raise ValueError(
                f"Shape mismatch: predictions {predictions.shape} vs targets {targets.shape}"
            )
        if predictions.numel() == 0:
            raise ValueError("Input tensors are empty")

        # Tensor.mean() rejects integer/bool tensors, so discrete action
        # encodings are promoted to the accumulator dtype before any arithmetic.
        state_dtype = self.total_mse.dtype  # pylint: disable=no-member
        predictions = self._to_floating(predictions, state_dtype)
        targets = self._to_floating(targets, state_dtype)

        # Compute MSE for this trajectory: mean of squared L2 norms
        squared_errors = torch.sum((predictions - targets) ** 2, dim=-1)  # (T,)
        mse = squared_errors.mean()

        # Update MSE accumulator
        self.total_mse += mse  # pylint: disable=no-member
        self.total_trajectories += 1.0  # pylint: disable=no-member

        # Update action statistics for variance computation (if needed for NAMSE)
        if self.normalize and self.action_variance is None:
            # Flatten actions to compute overall statistics
            targets_flat = targets.reshape(-1)
            self.total_squared_actions += (targets_flat**2).sum()  # pylint: disable=no-member
            self.total_actions += targets_flat.sum()  # pylint: disable=no-member
            self.total_action_count += targets_flat.numel()  # pylint: disable=no-member

    def compute(self) -> dict[str, Tensor]:
        """Compute the final Action Accuracy metrics.

        Returns:
            Dictionary containing:
                - 'mse': Accumulated MSE divided by the number of trajectories
                  (same value as 'amse'; equals the trajectory's own MSE when
                  only one trajectory has been recorded)
                - 'amse': Average MSE across all trajectories
                - 'namse': Normalized AMSE (only if normalize=True)

        Raises:
            RuntimeError: If no trajectories have been recorded, or if
                normalization is enabled and the action variance is zero or
                negative.
        """
        if self.total_trajectories == 0:  # pylint: disable=no-member
            raise RuntimeError(
                "Cannot compute action accuracy: no trajectories have been recorded. "
                "Call update() with predictions and targets before compute()."
            )

        # Compute AMSE (average of MSEs across trajectories)
        amse = self.total_mse / self.total_trajectories  # pylint: disable=no-member
        results = {
            # Same value as AMSE; cloned so the two entries stay independent
            # tensors, as they were when each was computed separately.
            "mse": amse.clone(),
            "amse": amse,
        }

        # Compute NAMSE if normalization is enabled
        if self.normalize:
            if self.action_variance is not None:
                # Use provided variance
                variance = torch.tensor(self.action_variance, dtype=amse.dtype, device=amse.device)
            else:
                # Compute variance from accumulated statistics
                # Var(X) = E[X²] - E[X]²
                mean_squared = (
                    self.total_squared_actions / self.total_action_count  # pylint: disable=no-member
                )
                mean = self.total_actions / self.total_action_count  # pylint: disable=no-member
                variance = mean_squared - mean**2

            if variance <= 0:
                raise RuntimeError(
                    "Action variance is zero or negative. Cannot compute NAMSE. "
                    "Ensure target actions have non-zero variance."
                )
            results["namse"] = amse / variance

        return results