From 5d214dd6b0b0e4f22c182ae398bc502cefc3a06e Mon Sep 17 00:00:00 2001 From: Clemens Volk Date: Thu, 25 Jun 2026 15:43:42 +0200 Subject: [PATCH] feat(sensitivity): interactive Streamlit explorer with posterior conditioning - app.py: fit-once, re-sample-on-change viewer; sidebar outcome toggle and a what-if panel that pins or averages factors and reads the conditioned marginal, with a live slice count - marginals.condition_mask: slice posterior draws to a pinned region (pure, testable) - plotting.plot_marginal / plot_joint: single-panel renderers the app draws through - Streamlit is already a dev dependency; run via streamlit run ... -- --episode_results Signed-off-by: Clemens Volk --- isaaclab_arena/analysis/sensitivity/app.py | 168 ++++++++++++++++++ .../analysis/sensitivity/marginals.py | 36 ++++ .../analysis/sensitivity/plotting.py | 93 ++++++++++ 3 files changed, 297 insertions(+) create mode 100644 isaaclab_arena/analysis/sensitivity/app.py diff --git a/isaaclab_arena/analysis/sensitivity/app.py b/isaaclab_arena/analysis/sensitivity/app.py new file mode 100644 index 000000000..3c0abc2dd --- /dev/null +++ b/isaaclab_arena/analysis/sensitivity/app.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Interactive sensitivity explorer: a Streamlit shell over the amortized posterior. + +The estimator is amortized, so re-conditioning it on a new outcome is a cheap re-sample with no +retraining. This app exposes that: the posterior is fit once (cached), and every control change +re-samples it and redraws the importance ranking, the per-factor marginals, and a chosen pairwise +joint. It is a development/exploration tool — Streamlit is a dev dependency, not a runtime one. + +Run (after ``pip install -e .[dev]`` for streamlit): + + streamlit run isaaclab_arena/analysis/sensitivity/app.py -- --episode_results path/to/episode_results.jsonl + +Everything after the ``--`` is passed to this script; the path can also be set in the sidebar. +""" + +from __future__ import annotations + +import argparse +import matplotlib.pyplot as plt +import torch + +import streamlit as st + +from isaaclab_arena.analysis.sensitivity.analyzer import SensitivityAnalyzer +from isaaclab_arena.analysis.sensitivity.dataset import SensitivityDataset +from isaaclab_arena.analysis.sensitivity.episode_results_reader import dataset_from_episode_results +from isaaclab_arena.analysis.sensitivity.marginals import condition_mask +from isaaclab_arena.analysis.sensitivity.plotting import plot_marginal + +# Posterior sampling is amortized (cheap), so we draw a large fixed pool: enough to keep conditioned +# slices (the what-if panel) populated without a user-facing knob. It only sets MC resolution. +_NUM_SAMPLES = 50000 +_THIN_SLICE_WARNING = 200 +"""Below this many draws in a conditioned slice, warn that the curve is unreliable.""" + + +def _parse_args() -> argparse.Namespace: + """Parse the CLI args Streamlit forwards after ``--`` (path is optional; sidebar can set it).""" + parser = argparse.ArgumentParser(description="Interactive sensitivity explorer.") + parser.add_argument("--episode_results", type=str, default="", help="Path to episode_results.jsonl.") + parser.add_argument("--outcome", type=str, nargs="+", default=["success"], help="Outcome field(s).") + # Streamlit injects its own argv; parse_known_args ignores anything that isn't ours. + args, _ = parser.parse_known_args() + return args + + +@st.cache_resource(show_spinner="Fitting posterior…") +def _load_and_fit( + episode_results_path: str, outcome_names: tuple[str, ...], seed: int +) -> tuple[SensitivityDataset, SensitivityAnalyzer]: + """Build the dataset and fit the analyzer once per (path, outcomes, seed). + + Cached as a resource: the fitted analyzer holds a torch model, so it is reused across reruns + and only refit when one of these inputs changes. + """ + torch.manual_seed(seed) + dataset = dataset_from_episode_results(episode_results_path, outcome_names) + analyzer = SensitivityAnalyzer(dataset) + analyzer.fit() + return dataset, analyzer + + +def _outcome_controls(dataset: SensitivityDataset) -> torch.Tensor: + """Render one sidebar control per outcome and return the observation vector to condition on. + + A binary outcome (values all 0/1) gets a success/failure toggle; any other outcome gets a + slider over its observed range — the continuous-conditioning the amortized posterior allows. + """ + st.sidebar.subheader("Condition on outcome") + values: list[float] = [] + for index, name in enumerate(dataset.outcome_names): + column = dataset.x[:, index] + is_binary = set(column.tolist()).issubset({0.0, 1.0}) + if is_binary: + choice = st.sidebar.radio( + name, options=[1.0, 0.0], format_func=lambda v: "success (1)" if v == 1.0 else "failure (0)" + ) + values.append(float(choice)) + else: + low, high = float(column.min()), float(column.max()) + values.append(st.sidebar.slider(name, min_value=low, max_value=high, value=high)) + return torch.tensor(values, dtype=torch.float32) + + +def _conditioning_panel(samples: torch.Tensor, dataset: SensitivityDataset, observation: torch.Tensor) -> None: + """What-if panel: pin some factors and view another's conditional posterior marginal. + + Picks a factor to view, lets every other factor be pinned (continuous → a range band, categorical + → a choice), slices the draws to that pinned region, and redraws the view factor's marginal from + the survivors. Pinned factors are conditioned on; unpinned ones are averaged over. A live count + surfaces when the slice is too thin to trust. + """ + factor_names = [factor.name for factor in dataset.factors] + if not factor_names: + return + + st.subheader("Conditioning (what-if)") + st.caption("Pin other factors to slice the posterior; unpinned factors are averaged over.") + view = st.selectbox("View factor", factor_names, index=0, key="condition_view") + + continuous_windows: dict[str, tuple[float, float]] = {} + categorical_choices: dict[str, int] = {} + for factor in dataset.factors: + if factor.name == view: + continue + if not st.checkbox(f"pin {factor.name}", value=False, key=f"pin_{factor.name}"): + continue + if factor.type == "continuous": + low, high = float(factor.range[0]), float(factor.range[1]) + span = high - low + # Default to a central band so pinning has a visible effect; the user widens/moves it. + window = st.slider( + factor.name, + min_value=low, + max_value=high, + value=(low + 0.4 * span, low + 0.6 * span), + key=f"window_{factor.name}", + ) + continuous_windows[factor.name] = window + else: + choice = st.selectbox(f"{factor.name} =", factor.choices, key=f"choice_{factor.name}") + categorical_choices[factor.name] = factor.choices.index(choice) + + mask = condition_mask(samples, dataset, continuous_windows, categorical_choices) + num_in_slice = int(mask.sum()) + st.caption(f"{num_in_slice} / {len(mask)} samples in slice") + if num_in_slice == 0: + st.warning("No samples in this slice — widen a window or unpin a factor.") + return + if num_in_slice < _THIN_SLICE_WARNING: + st.warning( + f"Thin slice ({num_in_slice} samples): the curve reflects the fitted model more than the " + "data here. Widen a window or unpin a factor." + ) + st.pyplot(plot_marginal(samples[torch.as_tensor(mask)], dataset, view, observation), use_container_width=False) + + +def main() -> None: + """Run the interactive explorer: fit once, then re-sample and redraw on every control change.""" + st.set_page_config(page_title="Sensitivity Explorer", layout="wide") + st.title("Sensitivity Explorer") + + args = _parse_args() + episode_results_path = st.sidebar.text_input("episode_results.jsonl", value=args.episode_results) + if not episode_results_path: + st.info("Set the path to an episode_results.jsonl in the sidebar to begin.") + return + + seed = st.sidebar.number_input("seed", value=0, step=1) + dataset, analyzer = _load_and_fit(episode_results_path, tuple(args.outcome), int(seed)) + + observation = _outcome_controls(dataset) + + # Re-seed before sampling so identical controls reproduce the same draws across reruns. + torch.manual_seed(int(seed)) + samples = analyzer.sample_posterior(observation, num_samples=_NUM_SAMPLES) + + _conditioning_panel(samples, dataset, observation) + + plt.close("all") + + +if __name__ == "__main__": + main() diff --git a/isaaclab_arena/analysis/sensitivity/marginals.py b/isaaclab_arena/analysis/sensitivity/marginals.py index 98d641f8d..2734b9d72 100644 --- a/isaaclab_arena/analysis/sensitivity/marginals.py +++ b/isaaclab_arena/analysis/sensitivity/marginals.py @@ -123,3 +123,39 @@ def factor_importances(samples: torch.Tensor, dataset: SensitivityDataset) -> li for factor in dataset.factors ] return sorted(scored, key=lambda name_score: name_score[1], reverse=True) + + +def condition_mask( + samples: torch.Tensor, + dataset: SensitivityDataset, + continuous_windows: dict[str, tuple[float, float]], + categorical_choices: dict[str, int], +) -> np.ndarray: + """Boolean mask over posterior draws that fall inside every pinned factor's constraint. + + Conditioning by slicing: keep the draws whose pinned continuous factors lie in their window and + whose pinned categorical factors equal their chosen code. With a uniform sampling prior the + surviving draws approximate the conditional posterior p(unpinned | outcome, pinned), which is + proportional to the conditional success surface. Factors named in neither dict are left free + (averaged over). Accuracy here is bounded by the number of draws *and* by how much real data + backs the slice — a thin slice over sparse data is unreliable however many draws land in it. + + Args: + samples: ``(num_samples, num_factors)`` posterior draws in the dataset's factor layout. + dataset: The dataset, for the column layout and factor types. + continuous_windows: factor name → (low, high) band to keep, for continuous factors. + categorical_choices: factor name → integer choice code to keep, for categorical factors. + + Returns: + A length-``num_samples`` boolean array, True where a draw satisfies all constraints. + """ + sample_array = samples.cpu().numpy() + columns = dataset.factor_columns + mask = np.ones(sample_array.shape[0], dtype=bool) + for name, (low, high) in continuous_windows.items(): + values = sample_array[:, columns[name]].squeeze(-1) + mask &= (values >= low) & (values <= high) + for name, code in categorical_choices.items(): + codes = np.round(sample_array[:, columns[name]].squeeze(-1)).astype(int) + mask &= codes == code + return mask diff --git a/isaaclab_arena/analysis/sensitivity/plotting.py b/isaaclab_arena/analysis/sensitivity/plotting.py index e24707ce7..fc95849ba 100644 --- a/isaaclab_arena/analysis/sensitivity/plotting.py +++ b/isaaclab_arena/analysis/sensitivity/plotting.py @@ -169,6 +169,99 @@ def _draw_categorical_marginal(ax, factor: FactorSpec, factor_samples: np.ndarra ax.set_ylabel("posterior probability") +def plot_marginal( + samples: torch.Tensor, + dataset: SensitivityDataset, + factor_name: str, + observation: torch.Tensor, + output_path: str | None = None, +): + """Posterior marginal of a single named factor, on its own figure. + + The one-panel counterpart to plot_marginals, for drawing one factor from an arbitrary set of + draws — e.g. a conditioned subset, where samples is already sliced to a pinned region. + + Args: + samples: ``(num_samples, num_factors)`` posterior draws in the dataset's factor layout + (may be a conditioned subset). + dataset: The dataset, for the factor schema and column layout. + factor_name: Name of the factor to draw. + observation: The outcome vector the samples were conditioned on (shown in the title). + output_path: If given, save the figure here; the format follows the path's extension. + + Returns: + The matplotlib Figure. + """ + sample_array = samples.cpu().numpy() + factor = {factor.name: factor for factor in dataset.factors}[factor_name] + factor_samples = sample_array[:, dataset.factor_columns[factor_name]].squeeze(-1) + + figure, ax = plt.subplots(figsize=(5.0, 3.5)) + if factor.type == "continuous": + _draw_continuous_marginal(ax, factor, factor_samples) + else: + _draw_categorical_marginal(ax, factor, factor_samples) + + observation_label = ", ".join( + f"{name}={value:g}" for name, value in zip(dataset.outcome_names, observation.tolist()) + ) + ax.set_title(f"{factor_name} (observed: {observation_label}; n={len(factor_samples)})", fontsize=11) + figure.tight_layout() + + if output_path is not None: + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + figure.savefig(output_path, dpi=150, bbox_inches="tight") + return figure + + +def plot_joint( + samples: torch.Tensor, + dataset: SensitivityDataset, + factor_x_name: str, + factor_y_name: str, + observation: torch.Tensor, + output_path: str | None = None, +): + """Single-pair joint posterior of two named factors, on its own figure. + + The one-cell counterpart to plot_corner, for picking an interaction to look at (e.g. the + interactive app's factor-pair selector) without rendering the whole grid. + + Args: + samples: ``(num_samples, num_factors)`` posterior draws in the dataset's factor layout. + dataset: The dataset, for the factor schema and column layout. + factor_x_name: Name of the factor on the horizontal axis. + factor_y_name: Name of the factor on the vertical axis. + observation: The outcome vector the samples were conditioned on (shown in the title). + output_path: If given, save the figure here; the format follows the path's extension. + + Returns: + The matplotlib Figure. + """ + sample_array = samples.cpu().numpy() + factors_by_name = {factor.name: factor for factor in dataset.factors} + factor_x, factor_y = factors_by_name[factor_x_name], factors_by_name[factor_y_name] + columns = dataset.factor_columns + samples_x = sample_array[:, columns[factor_x_name]].squeeze(-1) + samples_y = sample_array[:, columns[factor_y_name]].squeeze(-1) + + figure, ax = plt.subplots(figsize=(6.0, 5.0)) + _draw_joint(ax, factor_x, factor_y, samples_x, samples_y) + ax.set_xlabel(factor_x_name) + ax.set_ylabel(factor_y_name) + + observation_label = ", ".join( + f"{name}={value:g}" for name, value in zip(dataset.outcome_names, observation.tolist()) + ) + ax.set_title(f"Joint posterior (observed: {observation_label})", fontsize=12, fontweight="bold") + figure.tight_layout() + + if output_path is not None: + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + figure.savefig(output_path, dpi=150, bbox_inches="tight") + return figure + + def plot_corner( samples: torch.Tensor, dataset: SensitivityDataset,