Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions isaaclab_arena/analysis/sensitivity/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import math
import matplotlib.pyplot as plt
import numpy as np
import re
from pathlib import Path
from scipy.stats import gaussian_kde
from typing import TYPE_CHECKING
Expand All @@ -19,7 +20,7 @@

_CONTINUOUS_COLOR = "steelblue"
_CATEGORICAL_COLOR = "steelblue"
_MEAN_COLOR = "firebrick"
_PRIOR_COLOR = "grey"


def plot_marginals(
Expand All @@ -32,7 +33,8 @@ def plot_marginals(

A pure renderer: it draws already-sampled posterior draws and does not run inference.
One panel per factor — a density curve for continuous factors, a probability bar chart
for categorical ones, wrapped into a grid.
for categorical ones, wrapped into a grid. Panels for components of the same vector
variation share a y-axis, so their densities compare directly.

Args:
samples: ``(num_samples, num_factors)`` posterior draws in the dataset's factor
Expand All @@ -52,17 +54,30 @@ def plot_marginals(
num_rows = math.ceil(len(factors) / num_columns)
figure, axes = plt.subplots(num_rows, num_columns, figsize=(6.0 * num_columns, 4.5 * num_rows), squeeze=False)
flat_axes = axes.flatten()
continuous_axes_by_variation: dict[str, list] = {}
for axis_index, factor in enumerate(factors):
ax = flat_axes[axis_index]
factor_samples = samples[:, dataset.factor_columns[factor.name]].squeeze(-1)
if factor.type == "continuous":
_draw_continuous_marginal(ax, factor, factor_samples)
# Components of one vector variation (name[0], name[1], ...) share a scale.
variation_name = re.sub(r"\[\d+\]$", "", factor.name)
Comment thread
cvolkcvolk marked this conversation as resolved.
Comment thread
cvolkcvolk marked this conversation as resolved.
continuous_axes_by_variation.setdefault(variation_name, []).append(ax)
else:
_draw_categorical_marginal(ax, factor, factor_samples)
ax.set_title(factor.name, fontsize=11)
for unused_index in range(len(factors), len(flat_axes)):
flat_axes[unused_index].axis("off")

# Give the components of a vector variation a common y-axis so their densities compare directly.
# A standalone scalar factor keeps its own scale, since unrelated factors can differ in magnitude.
for grouped_axes in continuous_axes_by_variation.values():
if len(grouped_axes) < 2:
continue
shared_top = max(grouped_ax.get_ylim()[1] for grouped_ax in grouped_axes)
for grouped_ax in grouped_axes:
grouped_ax.set_ylim(0, shared_top)

observation_label = ", ".join(
f"{name}={value:g}" for name, value in zip(dataset.outcome_names, observation.tolist())
)
Expand All @@ -80,21 +95,32 @@ def plot_marginals(


def _draw_continuous_marginal(ax, factor: FactorSpec, factor_samples: np.ndarray) -> None:
"""Smooth posterior density (filled KDE curve) of a continuous factor, with a mean line.
"""Posterior density of a continuous factor over its swept range.

A KDE line over the posterior samples reads the shape of a continuous posterior better
than a binned histogram. Falls back to a single line at the mean when the samples have
no spread (KDE bandwidth is then undefined).
Draws the KDE of the posterior samples, the uniform prior as a flat reference, and shades
the central 5-95% of the posterior. Reading the posterior against the prior shows whether
conditioning on the outcome concentrated the factor, which a mean alone would miss for a
factor swept symmetrically around its nominal value.
"""
range_low, range_high = factor.range
sample_mean = float(np.mean(factor_samples))
span = range_high - range_low

if float(np.std(factor_samples)) >= 1e-9:
grid = np.linspace(range_low, range_high, 200)
density = gaussian_kde(factor_samples)(grid)
ax.plot(grid, density, color=_CONTINUOUS_COLOR, linewidth=2)
ax.plot(grid, density, color=_CONTINUOUS_COLOR, linewidth=2, label="posterior")
Comment thread
cvolkcvolk marked this conversation as resolved.
ax.fill_between(grid, 0, density, color=_CONTINUOUS_COLOR, alpha=0.2)
ax.set_ylim(bottom=0)
ax.axvline(sample_mean, color=_MEAN_COLOR, linestyle="--", linewidth=2, label=f"mean = {sample_mean:.3g}")
low_percentile, high_percentile = np.percentile(factor_samples, [5, 95])
ax.axvspan(low_percentile, high_percentile, color=_CONTINUOUS_COLOR, alpha=0.15, label="5-95%")
Comment thread
cvolkcvolk marked this conversation as resolved.
else:
ax.axvline(float(np.mean(factor_samples)), color=_CONTINUOUS_COLOR, linewidth=2, label="constant")
Comment thread
cvolkcvolk marked this conversation as resolved.
ax.set_ylim(bottom=0)

if span > 0:
Comment thread
cvolkcvolk marked this conversation as resolved.
# The uniform prior is the "no effect" reference the posterior is read against.
ax.axhline(1.0 / span, color=_PRIOR_COLOR, linestyle="--", linewidth=1.5, label="prior (uniform)")
Comment thread
cvolkcvolk marked this conversation as resolved.

ax.set_xlim(range_low, range_high)
ax.set_xlabel(factor.name)
ax.set_ylabel("posterior density")
Expand Down
Loading