diff --git a/isaaclab_arena/analysis/sensitivity/plotting.py b/isaaclab_arena/analysis/sensitivity/plotting.py index 5dd0ef2cb..e00f4d062 100644 --- a/isaaclab_arena/analysis/sensitivity/plotting.py +++ b/isaaclab_arena/analysis/sensitivity/plotting.py @@ -8,6 +8,7 @@ import math import matplotlib.pyplot as plt import numpy as np +import re from pathlib import Path from scipy.stats import gaussian_kde from typing import TYPE_CHECKING @@ -19,7 +20,7 @@ _CONTINUOUS_COLOR = "steelblue" _CATEGORICAL_COLOR = "steelblue" -_MEAN_COLOR = "firebrick" +_PRIOR_COLOR = "grey" def plot_marginals( @@ -32,7 +33,8 @@ def plot_marginals( A pure renderer: it draws already-sampled posterior draws and does not run inference. One panel per factor — a density curve for continuous factors, a probability bar chart - for categorical ones, wrapped into a grid. + for categorical ones, wrapped into a grid. Panels for components of the same vector + variation share a y-axis, so their densities compare directly. Args: samples: ``(num_samples, num_factors)`` posterior draws in the dataset's factor @@ -52,17 +54,30 @@ def plot_marginals( num_rows = math.ceil(len(factors) / num_columns) figure, axes = plt.subplots(num_rows, num_columns, figsize=(6.0 * num_columns, 4.5 * num_rows), squeeze=False) flat_axes = axes.flatten() + continuous_axes_by_variation: dict[str, list] = {} for axis_index, factor in enumerate(factors): ax = flat_axes[axis_index] factor_samples = samples[:, dataset.factor_columns[factor.name]].squeeze(-1) if factor.type == "continuous": _draw_continuous_marginal(ax, factor, factor_samples) + # Components of one vector variation (name[0], name[1], ...) share a scale. + variation_name = re.sub(r"\[\d+\]$", "", factor.name) + continuous_axes_by_variation.setdefault(variation_name, []).append(ax) else: _draw_categorical_marginal(ax, factor, factor_samples) ax.set_title(factor.name, fontsize=11) for unused_index in range(len(factors), len(flat_axes)): flat_axes[unused_index].axis("off") + # Give the components of a vector variation a common y-axis so their densities compare directly. + # A standalone scalar factor keeps its own scale, since unrelated factors can differ in magnitude. + for grouped_axes in continuous_axes_by_variation.values(): + if len(grouped_axes) < 2: + continue + shared_top = max(grouped_ax.get_ylim()[1] for grouped_ax in grouped_axes) + for grouped_ax in grouped_axes: + grouped_ax.set_ylim(0, shared_top) + observation_label = ", ".join( f"{name}={value:g}" for name, value in zip(dataset.outcome_names, observation.tolist()) ) @@ -80,21 +95,32 @@ def plot_marginals( def _draw_continuous_marginal(ax, factor: FactorSpec, factor_samples: np.ndarray) -> None: - """Smooth posterior density (filled KDE curve) of a continuous factor, with a mean line. + """Posterior density of a continuous factor over its swept range. - A KDE line over the posterior samples reads the shape of a continuous posterior better - than a binned histogram. Falls back to a single line at the mean when the samples have - no spread (KDE bandwidth is then undefined). + Draws the KDE of the posterior samples, the uniform prior as a flat reference, and shades + the central 5-95% of the posterior. Reading the posterior against the prior shows whether + conditioning on the outcome concentrated the factor, which a mean alone would miss for a + factor swept symmetrically around its nominal value. """ range_low, range_high = factor.range - sample_mean = float(np.mean(factor_samples)) + span = range_high - range_low + if float(np.std(factor_samples)) >= 1e-9: grid = np.linspace(range_low, range_high, 200) density = gaussian_kde(factor_samples)(grid) - ax.plot(grid, density, color=_CONTINUOUS_COLOR, linewidth=2) + ax.plot(grid, density, color=_CONTINUOUS_COLOR, linewidth=2, label="posterior") ax.fill_between(grid, 0, density, color=_CONTINUOUS_COLOR, alpha=0.2) ax.set_ylim(bottom=0) - ax.axvline(sample_mean, color=_MEAN_COLOR, linestyle="--", linewidth=2, label=f"mean = {sample_mean:.3g}") + low_percentile, high_percentile = np.percentile(factor_samples, [5, 95]) + ax.axvspan(low_percentile, high_percentile, color=_CONTINUOUS_COLOR, alpha=0.15, label="5-95%") + else: + ax.axvline(float(np.mean(factor_samples)), color=_CONTINUOUS_COLOR, linewidth=2, label="constant") + ax.set_ylim(bottom=0) + + if span > 0: + # The uniform prior is the "no effect" reference the posterior is read against. + ax.axhline(1.0 / span, color=_PRIOR_COLOR, linestyle="--", linewidth=1.5, label="prior (uniform)") + ax.set_xlim(range_low, range_high) ax.set_xlabel(factor.name) ax.set_ylabel("posterior density")