"""
Metric Extraction: Registry-based lookup for cycle evaluation metrics.
This module provides unified metric extraction from CycleEvaluation objects,
replacing repetitive if/elif chains with a registry-based approach.

Usage:

.. code-block:: python

    from quantammsim.runners.metric_extraction import extract_cycle_metric

    # Extract aggregated metrics from cycle evaluations
    value = extract_cycle_metric(cycle_evals, "mean_oos_sharpe")
    value = extract_cycle_metric(cycle_evals, "worst_wfe")
    value = extract_cycle_metric(cycle_evals, "neg_is_oos_gap")
"""
from typing import Any, Callable, Dict, Iterable, List, Optional

import numpy as np
#: Lookup table from metric token to ``CycleEvaluation`` attribute name.
#:
#: A metric spec string such as ``"mean_oos_sharpe"`` is an aggregator
#: prefix (``"mean"``) followed by one of these tokens (``"oos_sharpe"``).
#: Every token names its attribute directly except ``"wfe"``, which is
#: shorthand for ``walk_forward_efficiency``.
CYCLE_METRICS: Dict[str, str] = {
    token: ("walk_forward_efficiency" if token == "wfe" else token)
    for token in (
        "oos_sharpe",
        "is_sharpe",
        "wfe",
        "is_oos_gap",
        "adjusted_oos_sharpe",
        # Risk metrics
        "oos_calmar",
        "is_calmar",
        "oos_sterling",
        "is_sterling",
        "oos_ulcer",
        "is_ulcer",
        "oos_returns",
        "is_returns",
        "oos_returns_over_hodl",
        "is_returns_over_hodl",
        "oos_daily_log_sharpe",
        "is_daily_log_sharpe",
    )
}
# Aggregation functions
def _mean_agg(v: List[float]) -> float:
"""Arithmetic mean with inf/nan filtering.
Non-finite and ``None`` entries are silently dropped before averaging.
If *all* entries are invalid, returns ``-inf`` so that Optuna treats
the trial as the worst possible outcome.
Parameters
----------
v : List[float]
Per-cycle metric values (may contain ``nan``, ``inf``, or ``None``).
Returns
-------
float
Filtered mean, or ``-inf`` if no valid values remain.
"""
filtered = [x for x in v if x is not None and np.isfinite(x)]
if not filtered:
return float("-inf") # No valid values = worst possible result
return float(np.mean(filtered))
def _worst_agg(v: List[float]) -> float:
"""Worst-case (minimum) aggregator with inf/nan filtering.
Same filtering semantics as :func:`_mean_agg` but returns the minimum
of the valid values. Used for ``"worst_"`` metric prefixes (e.g.
``"worst_oos_sharpe"``).
Parameters
----------
v : List[float]
Per-cycle metric values.
Returns
-------
float
Minimum of valid values, or ``-inf`` if none remain.
"""
filtered = [x for x in v if x is not None and np.isfinite(x)]
if not filtered:
return float("-inf") # No valid values = worst possible result
return float(np.min(filtered))
#: Aggregator dispatch table, keyed by the prefix of a metric spec string.
#: For example ``"mean_oos_sharpe"`` is handled by ``AGGREGATORS["mean"]``.
AGGREGATORS: Dict[str, Callable[[List[float]], float]] = dict(
    mean=_mean_agg,
    worst=_worst_agg,
)
# Metric lookup on EvaluationResult objects
def get_metric_from_result(result: Any, metric_name: str) -> float:
    """
    Extract a metric from an EvaluationResult object.

    The metric is read as an attribute of ``result``. Two names get special
    treatment:

    * ``"neg_is_oos_gap"`` reads ``mean_is_oos_gap`` and returns its negation.
    * ``"adjusted_mean_oos_sharpe"`` falls back to ``mean_oos_sharpe`` when
      the adjusted value is missing or ``None``.

    Any other name is looked up verbatim.

    Parameters
    ----------
    result : EvaluationResult
        The evaluation result object
    metric_name : str
        Name of the metric to extract

    Returns
    -------
    float
        The metric value, or ``-inf`` when the attribute (and, for the
        adjusted Sharpe, its fallback) is missing or ``None``.
    """
    # Only the gap metric is stored under a different attribute name.
    attr = {"neg_is_oos_gap": "mean_is_oos_gap"}.get(metric_name, metric_name)
    value = getattr(result, attr, None)

    # The fallback must run before the ``None`` check below; placed after it,
    # it could never be reached.
    if value is None and metric_name == "adjusted_mean_oos_sharpe":
        value = getattr(result, "mean_oos_sharpe", None)

    if value is None:
        return float("-inf")

    # The gap is stored as a positive quantity; callers ask for its negation.
    if metric_name == "neg_is_oos_gap":
        return -value
    return value