-
Notifications
You must be signed in to change notification settings - Fork 69
Per-episode sensitivity recording (eval + metrics) #781
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: cvolk/feature/sensitivity_analysis_mvp1
Are you sure you want to change the base?
Changes from 1 commit
0d63854
82c25f1
1d4124b
7303e1e
85446cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from isaaclab_arena.metrics.metrics_logger import metrics_to_plain_python_types | ||
|
|
||
| if TYPE_CHECKING: | ||
| from isaaclab_arena.evaluation.job_manager import Job | ||
|
|
||
|
|
||
def write_episode_summaries(env, job: Job, output_path: str | Path) -> int:
    """Append one JSONL row per recorded episode for the just-completed job.

    Each row has shape::

        {
            "job_name": "<job.name>",
            "episode_idx": <episode index in the recorded dataset>,
            "arena_env_args": <full job.arena_env_args_dict>,
            "outcomes": <per-episode metric values>
        }

    Per-episode metric values come from the env's ``MetricsManager`` (the same machinery
    that backs ``compute_metrics``), so all HDF5/metric access stays in the metrics layer.

    All rows are serialized before the output file is opened, so a value that cannot be
    JSON-encoded raises before anything is appended and never leaves a partially written
    batch behind.

    Args:
        env: The (possibly gym-wrapped) Arena env that just finished its rollout. Its
            ``MetricsManager`` provides the per-episode metric values.
        job: The Job that ran. Its ``arena_env_args_dict`` is logged verbatim under
            ``arena_env_args``.
        output_path: JSONL file to append to. Created (with parent dirs) if absent.

    Returns:
        Number of rows written. ``0`` when the env has no metrics configured, in which
        case the output file is not touched.
    """
    unwrapped_env = env.unwrapped
    # Envs without configured metrics have nothing to summarize.
    if getattr(unwrapped_env.cfg, "metrics", None) is None:
        return 0

    per_episode_metrics = unwrapped_env.metrics_manager.compute_per_episode()
    # Copy once so every row of this batch logs the same top-level mapping.
    arena_env_args_snapshot = dict(job.arena_env_args_dict)

    # Rows without any env args carry no factor values; surface that instead of
    # silently writing them.
    if per_episode_metrics and not arena_env_args_snapshot:
        print(
            f"[WARNING] Job '{job.name}' has an empty arena_env_args_dict; its episode"
            " summaries will be written with \"arena_env_args\": {}."
        )

    # Serialize up front: a json.dumps failure must not leave a half-written batch.
    serialized_rows = [
        json.dumps(
            {
                "job_name": job.name,
                "episode_idx": episode_idx,
                "arena_env_args": arena_env_args_snapshot,
                "outcomes": metrics_to_plain_python_types(episode_metrics),
            }
        )
        + "\n"
        for episode_idx, episode_metrics in enumerate(per_episode_metrics)
    ]

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "a", encoding="utf-8") as jsonl_output:
        jsonl_output.writelines(serialized_rows)

    return len(serialized_rows)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| from typing import TYPE_CHECKING | ||
|
|
||
| from isaaclab_arena.cli.isaaclab_arena_cli import get_isaaclab_arena_cli_parser | ||
| from isaaclab_arena.evaluation.episode_writer import write_episode_summaries | ||
| from isaaclab_arena.evaluation.eval_runner_cli import add_eval_runner_arguments | ||
| from isaaclab_arena.evaluation.job_manager import Job, JobManager, Status | ||
| from isaaclab_arena.evaluation.policy_runner import get_policy_cls, rollout_policy | ||
|
|
@@ -200,6 +201,15 @@ def main(): | |
| # Check if any job requires cameras and enable them if needed before starting simulation | ||
| enable_cameras_if_required(eval_jobs_config, args_cli) | ||
|
|
||
| # --episode_summary (opt-in): the writer logs the full arena_env_args per episode; | ||
| # the analyzer's factors.yaml decides which keys are factors (no eval-side knowledge). | ||
| episode_summary_enabled = args_cli.episode_summary is not None | ||
| if episode_summary_enabled: | ||
| print( | ||
| "[INFO] Episode summary recording enabled. Per-episode arena_env_args + outcomes" | ||
| f" → {args_cli.episode_summary}" | ||
| ) | ||
|
|
||
| with SimulationAppContext(args_cli): | ||
| job_manager = JobManager(eval_jobs_config["jobs"]) | ||
| metrics_logger = MetricsLogger() | ||
|
|
@@ -250,6 +260,10 @@ def main(): | |
| language_instruction=job.language_instruction, | ||
| ) | ||
|
|
||
| if episode_summary_enabled: | ||
| rows = write_episode_summaries(env, job, args_cli.episode_summary) | ||
| print(f"[INFO] Wrote {rows} episode summaries for job '{job.name}'") | ||
|
Comment on lines
+260
to
+262
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| job_manager.complete_job(job, metrics=metrics, status=Status.COMPLETED) | ||
|
|
||
| # users may not specify metrics for a task, although it's not recommended | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,3 +62,33 @@ def compute(self) -> dict[str, Any]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metrics_data[term_name] = term_cfg.compute_metric_func(recorded_metric_data, **term_cfg.params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metrics_data["num_episodes"] = get_num_episodes(dataset_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return metrics_data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def compute_per_episode(self) -> list[dict[str, Any]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Compute every registered metric separately for each recorded episode. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Where :meth:`compute` reduces across all episodes to one aggregate value per | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metric, this returns one ``{metric_name: value}`` dict per episode — each metric's | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compute func is fed that single episode's recorded array (a one-element list). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A list with one metric dict per episode, in recorded order. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dataset_path = get_metric_recorder_dataset_path(self._env) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_episodes = get_num_episodes(dataset_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Recorded data arrives grouped by metric (each term -> one array per episode). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Read it once here, then transpose into one metric dict per episode below. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| episode_arrays_by_term = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| term_name: get_recorded_metric_data(dataset_path, term_cfg.recorder_term_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for term_name, term_cfg in zip(self._term_names, self._term_cfgs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| per_episode_metrics: list[dict[str, Any]] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for episode_index in range(num_episodes): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+76
to
+87
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| episode_metrics: dict[str, Any] = {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for term_name, term_cfg in zip(self._term_names, self._term_cfgs): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # compute_metric_func reduces a list of per-episode arrays; give it just this one. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| episode_array = episode_arrays_by_term[term_name][episode_index] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+87
to
+91
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The fix is to sort the demo keys numerically before building the lists in |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| episode_metrics[term_name] = term_cfg.compute_metric_func([episode_array], **term_cfg.params) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| per_episode_metrics.append(episode_metrics) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return per_episode_metrics | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Empty `arena_env_args` when `Job` is constructed without `arena_env_args_dict`

`job.arena_env_args_dict` defaults to `{}` when a `Job` is constructed directly (i.e., not through `Job.from_dict()`). In that case the JSONL row is written as `"arena_env_args": {}` with no error or warning, producing rows that are silently useless to the sensitivity analyzer. At a minimum, adding a guard that logs a warning (and optionally skips writing) when the dict is empty would surface this misconfiguration before it silently corrupts an analysis run.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!