From d7b6182bc608ee6ccfa7f810e487e0c4787292fe Mon Sep 17 00:00:00 2001 From: zhx06 Date: Thu, 25 Jun 2026 16:36:35 -0700 Subject: [PATCH 1/7] mesh-based non-collision constraints Signed-off-by: zhx06 --- isaaclab_arena/assets/dummy_object.py | 22 +- isaaclab_arena/assets/object.py | 14 +- isaaclab_arena/assets/object_base.py | 14 + isaaclab_arena/cli/isaaclab_arena_cli.py | 7 + .../environments/arena_env_builder.py | 24 +- .../isaaclab_arena_environment.py | 5 + .../environments/relation_solver_interface.py | 36 +- isaaclab_arena/relations/collision_mode.py | 16 + isaaclab_arena/relations/mesh_pair_cache.py | 88 +++ isaaclab_arena/relations/object_placer.py | 290 ++++++-- .../relations/object_placer_params.py | 4 +- isaaclab_arena/relations/placement_result.py | 3 +- .../relations/relation_loss_strategies.py | 48 +- isaaclab_arena/relations/relation_solver.py | 452 +++++++++++- .../relations/relation_solver_params.py | 7 + isaaclab_arena/relations/warp_mesh_manager.py | 212 ++++++ isaaclab_arena/relations/warp_sdf_kernels.py | 232 ++++++ isaaclab_arena/tests/test_mesh_collision.py | 696 ++++++++++++++++++ .../test_object_placer_reproducibility.py | 19 +- .../tests/test_relation_solver_interface.py | 10 +- .../tests/test_usd_scale_helpers.py | 279 +++++++ .../tests/test_validate_placement.py | 10 +- isaaclab_arena/utils/pose.py | 11 + isaaclab_arena/utils/usd_helpers.py | 69 ++ 24 files changed, 2402 insertions(+), 166 deletions(-) create mode 100644 isaaclab_arena/relations/collision_mode.py create mode 100644 isaaclab_arena/relations/mesh_pair_cache.py create mode 100644 isaaclab_arena/relations/warp_mesh_manager.py create mode 100644 isaaclab_arena/relations/warp_sdf_kernels.py create mode 100644 isaaclab_arena/tests/test_mesh_collision.py create mode 100644 isaaclab_arena/tests/test_usd_scale_helpers.py diff --git a/isaaclab_arena/assets/dummy_object.py b/isaaclab_arena/assets/dummy_object.py index aabe95bfe4..8c1349dfd5 100644 --- a/isaaclab_arena/assets/dummy_object.py +++ b/isaaclab_arena/assets/dummy_object.py @@ -2,17 +2,21 @@ # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import torch +from typing import TYPE_CHECKING -from isaaclab_arena.relations.relations import Relation, RelationBase, UnaryRelation +from isaaclab_arena.relations.relations import IsAnchor, Relation, RelationBase, UnaryRelation from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters from isaaclab_arena.utils.pose import Pose +if TYPE_CHECKING: + import trimesh + class DummyObject: - """ - Dummy object for testing purposes without Isaac Sim dependencies. - """ + """Dummy object for testing purposes without Isaac Sim dependencies.""" def __init__( self, @@ -20,6 +24,7 @@ def __init__( bounding_box: AxisAlignedBoundingBox, initial_pose: Pose | None = None, relations: list[RelationBase] = [], + collision_mesh: trimesh.Trimesh | None = None, **kwargs, ): self.name = name @@ -27,6 +32,7 @@ def __init__( self.bounding_box = bounding_box assert self.bounding_box is not None self.relations = list(relations) + self._collision_mesh = collision_mesh def add_relation(self, relation: RelationBase) -> None: self.relations.append(relation) @@ -63,3 +69,11 @@ def get_initial_pose(self) -> Pose | None: def is_initial_pose_set(self) -> bool: return self.initial_pose is not None + + @property + def is_anchor(self) -> bool: + return any(isinstance(r, IsAnchor) for r in self.relations) + + def get_collision_mesh(self) -> trimesh.Trimesh | None: + """Return the collision mesh, or None to fall back to AABB.""" + return self._collision_mesh diff --git a/isaaclab_arena/assets/object.py b/isaaclab_arena/assets/object.py index e2a92af368..600bfbbe7d 100644 --- a/isaaclab_arena/assets/object.py +++ b/isaaclab_arena/assets/object.py @@ -2,8 +2,13 @@ # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import torch -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + import trimesh from isaaclab.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg from isaaclab.sensors.contact_sensor.contact_sensor_cfg import ContactSensorCfg @@ -20,9 +25,7 @@ class Object(ObjectBase): - """ - Encapsulates the pick-up object config for a pick-and-place environment. - """ + """Pick-up object config for a pick-and-place environment.""" def __init__( self, @@ -74,6 +77,9 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox: self.bounding_box = compute_local_bounding_box_from_usd(self.usd_path, self.scale) return self.bounding_box + def get_collision_mesh(self) -> trimesh.Trimesh | None: + """Return None; collision mesh is not available at the object level.""" + def get_world_bounding_box(self) -> AxisAlignedBoundingBox: """Get bounding box in world coordinates (local bbox rotated and translated). diff --git a/isaaclab_arena/assets/object_base.py b/isaaclab_arena/assets/object_base.py index 075d8f1c91..2a157475fd 100644 --- a/isaaclab_arena/assets/object_base.py +++ b/isaaclab_arena/assets/object_base.py @@ -7,9 +7,13 @@ import torch from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import warp as wp from isaaclab.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg + +if TYPE_CHECKING: + import trimesh from isaaclab.envs import ManagerBasedEnv from isaaclab.managers import EventTermCfg, SceneEntityCfg from isaaclab.sensors.contact_sensor.contact_sensor_cfg import ContactSensorCfg @@ -73,6 +77,9 @@ def get_world_bounding_box(self) -> AxisAlignedBoundingBox: """Get bounding box in world coordinates (local bbox rotated and translated).""" ... + def get_collision_mesh(self) -> trimesh.Trimesh | None: + """Return collision mesh, or None to fall back to AABB overlap.""" + def _get_initial_pose_as_pose(self) -> Pose | None: """Return a single ``Pose`` suitable for *init_state* and bounding-box calculations. @@ -170,6 +177,13 @@ def get_relations(self) -> list[RelationBase]: """Get all relations for this object.""" return self.relations + @property + def is_anchor(self) -> bool: + """True if this object has an IsAnchor relation.""" + from isaaclab_arena.relations.relations import IsAnchor + + return any(isinstance(r, IsAnchor) for r in self.relations) + def get_spatial_relations(self) -> list[RelationBase]: """Get only spatial relations (On, NextTo, AtPosition, etc.), excluding markers like IsAnchor.""" return [r for r in self.relations if isinstance(r, (Relation, UnaryRelation))] diff --git a/isaaclab_arena/cli/isaaclab_arena_cli.py b/isaaclab_arena/cli/isaaclab_arena_cli.py index 50d0e2ec05..cb2846d9e6 100644 --- a/isaaclab_arena/cli/isaaclab_arena_cli.py +++ b/isaaclab_arena/cli/isaaclab_arena_cli.py @@ -92,6 +92,13 @@ def add_isaaclab_arena_cli_args(parser: argparse.ArgumentParser) -> None: default=False, help="Print Hydra-configurable variations for the selected environment and exit.", ) + arena_group.add_argument( + "--collision_mode", + type=str, + choices=["bbox", "mesh"], + default="bbox", + help="Collision detection mode: 'bbox' (AABB, default) or 'mesh' (sphere-to-SDF, requires Warp).", + ) def add_env_graph_spec_cli_args(parser: argparse.ArgumentParser) -> None: diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 6aa286fe85..4536031afa 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -36,7 +36,10 @@ ) from isaaclab_arena.recording.common_terms import CoreEpisodeRecorderTermCfg, VariationEpisodeRecorderTermCfg from isaaclab_arena.recording.episode_recorder_manager import EpisodeRecorderTermCfg +from isaaclab_arena.relations.collision_mode import CollisionMode +from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.placement_events import PLACEMENT_RESET_EVENT_NAME +from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.tasks.no_task import NoTask from isaaclab_arena.utils.configclass import combine_configclass_instances, make_configclass from isaaclab_arena.utils.isaaclab_utils.simulation_app import reapply_viewer_cfg @@ -82,12 +85,27 @@ def _solve_relations(self) -> None: events restore the same layout every time. """ objects_with_relations = self.arena_env.scene.get_objects_with_relations() + + # Prefer env-level placer_params; fall back to CLI-constructed defaults. + placer_params = self.arena_env.placer_params + if placer_params is None: + collision_mode_str = getattr(self.args, "collision_mode", "bbox") + mode = CollisionMode.MESH if collision_mode_str == "mesh" else CollisionMode.BBOX + placer_params = ObjectPlacerParams( + placement_seed=self.args.placement_seed, + random_yaw_init=self.args.random_yaw_init, + solver_params=RelationSolverParams( + collision_mode=mode, + save_position_history=False, + verbose=False, + ), + ) + if self.args.resolve_on_reset is not None: + placer_params.resolve_on_reset = self.args.resolve_on_reset self._placement_event_cfg = solve_and_apply_relation_placement( objects_with_relations, num_envs=self.args.num_envs, - placement_seed=self.args.placement_seed, - resolve_on_reset=self.args.resolve_on_reset, - random_yaw_init=self.args.random_yaw_init, + placer_params=placer_params, ) def get_all_variations(self) -> dict[str, list[VariationBase]]: diff --git a/isaaclab_arena/environments/isaaclab_arena_environment.py b/isaaclab_arena/environments/isaaclab_arena_environment.py index 7ef8be9fa6..708992eba0 100644 --- a/isaaclab_arena/environments/isaaclab_arena_environment.py +++ b/isaaclab_arena/environments/isaaclab_arena_environment.py @@ -13,6 +13,7 @@ from isaaclab_arena.embodiments.embodiment_base import EmbodimentBase from isaaclab_arena.environments.isaaclab_arena_manager_based_env_cfg import IsaacLabArenaManagerBasedRLEnvCfg from isaaclab_arena.recording.episode_recorder_manager import EpisodeRecorderTermCfg + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.scene.scene import Scene from isaaclab_arena.tasks.task_base import TaskBase @@ -31,6 +32,7 @@ def __init__( rl_framework_entry_point: str | None = None, rl_policy_cfg: str | None = None, episode_recorder_terms: dict[str, EpisodeRecorderTermCfg] | None = None, + placer_params: ObjectPlacerParams | None = None, ): """ Args: @@ -50,6 +52,8 @@ def __init__( ``"my_module:RLPolicyCfg"``. episode_recorder_terms: Additional per-episode recorder terms to record alongside the built-in ones, keyed by name. + placer_params: Object placement configuration. When set, used as-is + (CLI flags are ignored). When None, params are built from CLI flags. """ self.name = name self.scene = scene @@ -62,3 +66,4 @@ def __init__( self.rl_framework_entry_point = rl_framework_entry_point self.rl_policy_cfg = rl_policy_cfg self.episode_recorder_terms = episode_recorder_terms or {} + self.placer_params = placer_params diff --git a/isaaclab_arena/environments/relation_solver_interface.py b/isaaclab_arena/environments/relation_solver_interface.py index aad1019d36..3d903100aa 100644 --- a/isaaclab_arena/environments/relation_solver_interface.py +++ b/isaaclab_arena/environments/relation_solver_interface.py @@ -5,12 +5,12 @@ from __future__ import annotations +import copy from typing import TYPE_CHECKING from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.placement_events import get_rotation_xyzw, solve_and_place_objects from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer -from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relations import get_anchor_objects from isaaclab_arena.utils.pose import Pose, PosePerEnv, rotate_quat_by_yaw @@ -23,39 +23,19 @@ def solve_and_apply_relation_placement( objects: list[ObjectBase], num_envs: int, - placement_seed: int | None = None, - resolve_on_reset: bool | None = None, - random_yaw_init: bool = False, + placer_params: ObjectPlacerParams | None = None, ) -> EventTermCfg | None: - """Solve relation placement and apply the result to object reset/static state. - - Args: - objects: Objects with spatial predicates that should be relation-solved. - num_envs: Number of environments to prepare placements for. - placement_seed: Optional random seed for reproducible object placement. - resolve_on_reset: Optional override for whether to draw fresh layouts from - the placement pool on reset. When ``False``, fixed per-environment - initial poses are applied immediately. - random_yaw_init: If True, randomly rotates non-anchor objects around the vertical (Z) - axis at startup to add visual variety to the scene. - - Returns: - Reset event config to attach to the environment when placement should be - resolved on reset. Returns ``None`` when no reset event is needed. - """ + """Solve relation placement and return a reset EventTermCfg (or None if no objects).""" objects = list(objects) if not objects: print("No objects with relations found in scene. Skipping relation solving.") return None - placer_params = ObjectPlacerParams( - placement_seed=placement_seed, - apply_positions_to_objects=False, - solver_params=RelationSolverParams(save_position_history=False, verbose=False), - random_yaw_init=random_yaw_init, - ) - if resolve_on_reset is not None: - placer_params.resolve_on_reset = resolve_on_reset + if placer_params is None: + placer_params = ObjectPlacerParams() + else: + placer_params = copy.copy(placer_params) + placer_params.apply_positions_to_objects = False # TODO(xinjieyao, 2026-05-22): Add joint object/embodiment placement once task-dependent # reachability constraints are available. For now this always uses the object-only placer. diff --git a/isaaclab_arena/relations/collision_mode.py b/isaaclab_arena/relations/collision_mode.py new file mode 100644 index 0000000000..90c13a726b --- /dev/null +++ b/isaaclab_arena/relations/collision_mode.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from enum import Enum + + +class CollisionMode(Enum): + """Selects which collision detection method the solver uses for no-overlap constraints.""" + + BBOX = "bbox" + """Axis-aligned bounding box overlap volume (fast, conservative).""" + + MESH = "mesh" + """Sphere-to-SDF queries against actual mesh geometry (accurate, slower).""" diff --git a/isaaclab_arena/relations/mesh_pair_cache.py b/isaaclab_arena/relations/mesh_pair_cache.py new file mode 100644 index 0000000000..1ad82e1dff --- /dev/null +++ b/isaaclab_arena/relations/mesh_pair_cache.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Typed container for precomputed mesh-collision pair data.""" + +from __future__ import annotations + +import torch +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import warp as wp + + from isaaclab_arena.assets.object_base import ObjectBase + + +@dataclass(slots=True) +class MeshPairCache: + """Precomputed per-pair collision data for the vectorized multi-mesh kernel.""" + + all_centers_local: torch.Tensor + """(S, 3) sphere centers in each child's local frame, concatenated across pairs.""" + + all_radii: torch.Tensor + """(S,) sphere radii, concatenated across pairs.""" + + pair_child_objs: list[ObjectBase] + """Per-pair child object reference.""" + + pair_parent_objs: list[ObjectBase] + """Per-pair parent/target object reference.""" + + pair_is_anchor: list[bool] + """Per-pair flag: True if parent is a static anchor.""" + + pair_anchor_pos: list[torch.Tensor | None] + """Per-pair world position of anchors (None for non-anchor parents).""" + + pair_anchor_yaw: list[float] + """Per-pair anchor yaw in radians (0.0 for non-anchor parents).""" + + pair_c_bbox_min: torch.Tensor + """(P, B, 3) child bbox min corners for broadphase.""" + + pair_c_bbox_max: torch.Tensor + """(P, B, 3) child bbox max corners for broadphase.""" + + pair_p_bbox_min: torch.Tensor + """(P, B, 3) parent bbox min corners for broadphase.""" + + pair_p_bbox_max: torch.Tensor + """(P, B, 3) parent bbox max corners for broadphase.""" + + pair_max_r: torch.Tensor + """(P,) max sphere radius per pair (broadphase margin).""" + + sphere_pair_id: torch.Tensor + """(S,) maps each sphere to its pair index for segment reduction.""" + + sphere_mesh_idx: torch.Tensor + """(S,) per-sphere index into mesh_id_array.""" + + pair_sphere_count: torch.Tensor + """(P,) number of spheres per pair (for mean reduction).""" + + mesh_id_array: wp.array + """Warp uint64 array of mesh IDs for the multi-mesh kernel.""" + + num_pairs: int + """Total number of active object pairs.""" + + total_spheres: int + """Total number of sphere queries across all pairs.""" + + def __post_init__(self) -> None: + assert len(self.pair_child_objs) == self.num_pairs, "pair_child_objs length mismatch" + assert len(self.pair_parent_objs) == self.num_pairs, "pair_parent_objs length mismatch" + assert len(self.pair_is_anchor) == self.num_pairs, "pair_is_anchor length mismatch" + assert self.all_centers_local.shape[0] == self.total_spheres, "all_centers_local size mismatch" + assert self.all_radii.shape[0] == self.total_spheres, "all_radii size mismatch" + assert self.sphere_pair_id.shape[0] == self.total_spheres, "sphere_pair_id size mismatch" + assert self.sphere_mesh_idx.shape[0] == self.total_spheres, "sphere_mesh_idx size mismatch" + assert int(self.pair_sphere_count.sum().item()) == self.total_spheres, "pair_sphere_count sum mismatch" + for i, (is_anchor, pos) in enumerate(zip(self.pair_is_anchor, self.pair_anchor_pos)): + assert not is_anchor or pos is not None, f"pair {i}: is_anchor=True but anchor_pos is None" diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 5355a8f51d..3c834c3dfa 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -5,11 +5,13 @@ from __future__ import annotations +import math import torch from dataclasses import dataclass, field from typing import TYPE_CHECKING from isaaclab_arena.relations.bounding_box_helpers import assign_variants_for_envs, build_per_env_bounding_boxes +from isaaclab_arena.relations.collision_mode import CollisionMode from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.placement_result import PlacementResult from isaaclab_arena.relations.placement_validation import PlacementCheck, PlacementValidationResults @@ -24,8 +26,10 @@ RotateAroundSolution, get_anchor_objects, ) +from isaaclab_arena.relations.warp_mesh_manager import WarpMeshAndSphereCache +from isaaclab_arena.relations.warp_sdf_kernels import has_sdf_sentinel, mesh_sdf from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox -from isaaclab_arena.utils.pose import Pose, PosePerEnv, rotate_quat_by_yaw, wrap_angle_to_pi +from isaaclab_arena.utils.pose import Pose, PosePerEnv, rotate_quat_by_yaw, wrap_angle_to_pi, yaw_from_quat_xyzw from isaaclab_arena.utils.random import get_random_rotation if TYPE_CHECKING: @@ -46,7 +50,7 @@ class PlacementCandidate: """Per-check validation results for this candidate's layout.""" orientations: dict[ObjectBase, float] = field(default_factory=dict) - """Per-object yaw (radians about Z) sampled for this candidate. Empty when unrotated.""" + """Total applied Z-yaw (marker + sampled) per object. Empty when unrotated.""" @property def is_valid(self) -> bool: @@ -76,6 +80,7 @@ class ObjectPlacer: def __init__(self, params: ObjectPlacerParams | None = None): self.params = params or ObjectPlacerParams() self._solver = RelationSolver(params=self.params.solver_params) + self._cpu_mesh_manager = None def place( self, @@ -222,16 +227,19 @@ def _place_ranked( self._generate_initial_orientations(objects, anchor_objects_set, generator) ) - # Bake each candidate's yaw into a conservative enclosing bbox; no-op when yaw is disabled. - # The solver and validation then treat the rotated object as an axis-aligned box. + # Bake each candidate's total yaw into a conservative enclosing bbox (AABB broadphase). candidate_bboxes = self._rotate_candidate_bboxes(objects, candidate_bboxes, orientations_per_candidate) - all_positions = self._solver.solve(objects, initial_positions, env_bboxes=candidate_bboxes) + all_positions = self._solver.solve( + objects, initial_positions, env_bboxes=candidate_bboxes, orientations=orientations_per_candidate + ) assert self._solver.last_loss_per_env is not None all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() all_validations = [ self._validate_placement( - positions, self._get_bounding_boxes_for_candidate_index(candidate_bboxes, candidate_idx) + positions, + self._get_bounding_boxes_for_candidate_index(candidate_bboxes, candidate_idx), + orientations_per_candidate[candidate_idx], ) for candidate_idx, positions in enumerate(all_positions) ] @@ -358,17 +366,23 @@ def _generate_initial_orientations( anchor_objects: set[ObjectBase], generator: torch.Generator | None = None, ) -> dict[ObjectBase, float]: - """Sample a fixed yaw (radians about Z) per non-anchor object. + """Compute the total applied Z-yaw (marker + sampled) for each non-anchor object. - Empty dict (no RNG consumed) when random_yaw_init is off; anchors are never rotated. + Empty dict when random_yaw_init is off. """ if not self.params.random_yaw_init: return {} orientations: dict[ObjectBase, float] = {} for obj in objects: + marker_yaw = self._get_yaw_from_rotate_around_solution(obj) if obj in anchor_objects: - continue - orientations[obj] = get_random_rotation(generator) + assert marker_yaw == 0.0, ( + f"Anchor '{obj.name}' has a RotateAroundSolution (yaw={marker_yaw:.3f}). " + "Anchors are not repositioned by the placer, so any marker rotation must " + "already be baked into the anchor's initial_pose before calling place()." + ) + else: + orientations[obj] = wrap_angle_to_pi(get_random_rotation(generator) + marker_yaw) return orientations @staticmethod @@ -379,8 +393,8 @@ def _rotate_candidate_bboxes( ) -> dict[ObjectBase, AxisAlignedBoundingBox]: """Replace each candidate's bbox with the enclosing box of its yaw-rotated object. - candidate_bboxes hold one row per candidate (num_candidates, 3); each row is rotated by - its own yaw. Returns the input unchanged when no yaw is set, keeping the no-yaw path exact. + orientations_per_candidate carries total applied yaw (marker + sampled). + Returns the input unchanged when no yaw is set, keeping the no-yaw path exact. """ if not any(orientations for orientations in orientations_per_candidate): return candidate_bboxes @@ -388,14 +402,8 @@ def _rotate_candidate_bboxes( rotated: dict[ObjectBase, AxisAlignedBoundingBox] = {} for obj in objects: bbox = candidate_bboxes[obj] - # Only objects that receive a sampled yaw are rotated; anchors never appear here. if any(obj in orientations for orientations in orientations_per_candidate): - # Enclose marker_yaw + sampled yaw (the applied pose); both are pure-Z. - marker_yaw = ObjectPlacer._get_yaw_from_rotate_around_solution(obj) - yaws = [ - wrap_angle_to_pi(orientations_per_candidate[c].get(obj, 0.0) + marker_yaw) - for c in range(num_candidates) - ] + yaws = [orientations_per_candidate[c].get(obj, 0.0) for c in range(num_candidates)] if any(yaw != 0.0 for yaw in yaws): yaw_tensor = torch.tensor(yaws, dtype=torch.float32, device=bbox.min_point.device) bbox = bbox.rotated_around_z(yaw_tensor) @@ -590,38 +598,30 @@ def _validate_on_relations( return False return True - def _validate_no_overlap( - self, + @staticmethod + def _collect_skip_pairs( positions: dict[ObjectBase, tuple[float, float, float]], - env_bboxes: dict[ObjectBase, AxisAlignedBoundingBox], - ) -> bool: - """Validate that no two objects overlap in 3D (axis-aligned bbox with margin). - - Pairs linked by an On relation and anchor-anchor pairs are skipped. - The margin is derived from the solver's clearance_m parameter (with a - small float tolerance subtracted to avoid rejecting solutions that are - within solver residual). - - Args: - positions: Solved positions for each object. - env_bboxes: Per-object bboxes for the current env, each with shape (1, 3). - """ + ) -> tuple[set[tuple], set[int]]: + """Build On-pair skip set and anchor ID set from positioned objects.""" on_pairs: set[tuple] = set() anchor_ids: set[int] = set() for obj in positions: for rel in obj.get_relations(): if isinstance(rel, On) and rel.parent in positions: - # The lookup below sees pairs in object-list order, so store - # both directions for symmetric On-pair skipping. on_pairs.add((id(obj), id(rel.parent))) on_pairs.add((id(rel.parent), id(obj))) - if any(isinstance(r, IsAnchor) for r in obj.get_relations()): + if obj.is_anchor: anchor_ids.add(id(obj)) + return on_pairs, anchor_ids - clearance_m = self.params.solver_params.clearance_m - # Allow tiny residuals from the differentiable solver around the clearance boundary. - margin = max(0.0, clearance_m - 1e-6) - + def _non_skip_pairs( + self, + positions: dict[ObjectBase, tuple[float, float, float]], + skip_mesh_pairs: bool = False, + ): + """Yield (a, b) pairs for overlap checks, skipping On/anchor-anchor/mesh pairs as configured.""" + on_pairs, anchor_ids = self._collect_skip_pairs(positions) + mesh_manager = self._get_cpu_mesh_manager() if skip_mesh_pairs else None objects = list(positions.keys()) for i in range(len(objects)): for j in range(i + 1, len(objects)): @@ -630,16 +630,34 @@ def _validate_no_overlap( continue if (id(a), id(b)) in on_pairs: continue + if ( + mesh_manager is not None + and mesh_manager.get_collision_mesh(a) is not None + and mesh_manager.get_collision_mesh(b) is not None + ): + continue + yield a, b - a_bbox = env_bboxes[a] - b_bbox = env_bboxes[b] - a_world = a_bbox.translated(positions[a]) - b_world = b_bbox.translated(positions[b]) + def _validate_no_overlap( + self, + positions: dict[ObjectBase, tuple[float, float, float]], + env_bboxes: dict[ObjectBase, AxisAlignedBoundingBox], + skip_mesh_pairs: bool = False, + ) -> bool: + """AABB overlap check. Skips On-pairs, anchor-anchor, and (optionally) mesh-validated pairs.""" + clearance_m = self.params.solver_params.clearance_m + margin = max(0.0, clearance_m - 1e-6) - if a_world.overlaps(b_world, margin=margin).item(): - if self.params.verbose: - print(f" Overlap between '{a.name}' and '{b.name}'") - return False + for a, b in self._non_skip_pairs(positions, skip_mesh_pairs=skip_mesh_pairs): + a_bbox = env_bboxes[a] + b_bbox = env_bboxes[b] + a_world = a_bbox.translated(positions[a]) + b_world = b_bbox.translated(positions[b]) + + if a_world.overlaps(b_world, margin=margin).item(): + if self.params.verbose: + print(f" Overlap between '{a.name}' and '{b.name}'") + return False return True def _validate_next_to_relations( @@ -723,21 +741,176 @@ def _not_next_to_margin(self, relation: NotNextTo) -> float: strategy = self._solver.params.strategies[type(relation)] return strategy.margin_m + def _get_cpu_mesh_manager(self): + """Lazy-init CPU WarpMeshAndSphereCache.""" + if self._cpu_mesh_manager is None: + self._cpu_mesh_manager = WarpMeshAndSphereCache( + num_spheres=self.params.solver_params.num_spheres, + device="cpu", + ) + return self._cpu_mesh_manager + + def _validate_no_overlap_mesh( + self, + positions: dict[ObjectBase, tuple[float, float, float]], + orientations: dict[ObjectBase, float] | None = None, + ) -> bool: + """Sphere-to-SDF overlap check. Mesh-less pairs fall back to AABB.""" + clearance_m = self.params.solver_params.clearance_m + tolerance = max(0.0, clearance_m - 1e-6) + mesh_manager = self._get_cpu_mesh_manager() + mesh_manager.reset_sentinel_warning() + warned_no_mesh: set[str] = set() + + for a, b in self._non_skip_pairs(positions): + a_mesh = mesh_manager.get_collision_mesh(a) + b_mesh = mesh_manager.get_collision_mesh(b) + if a_mesh is None or b_mesh is None: + for obj, mesh in [(a, a_mesh), (b, b_mesh)]: + if mesh is None and obj.name not in warned_no_mesh: + warned_no_mesh.add(obj.name) + print( + f" [NoCollision] MESH mode: '{obj.name}' has no collision mesh," + " falling back to AABB validation for this pair" + ) + a_pos = torch.tensor(positions[a], dtype=torch.float32) + b_pos = torch.tensor(positions[b], dtype=torch.float32) + if self._pair_aabb_overlaps(a, b, a_pos, b_pos, orientations, tolerance): + if self.params.verbose: + print(f" AABB overlap between '{a.name}' and '{b.name}' (mesh unavailable)") + return False + continue + + a_pos = torch.tensor(positions[a], dtype=torch.float32) + b_pos = torch.tensor(positions[b], dtype=torch.float32) + + if self._spheres_penetrate_mesh(a, a_mesh, a_pos, b, b_mesh, b_pos, mesh_manager, tolerance, orientations): + return False + if self._spheres_penetrate_mesh(b, b_mesh, b_pos, a, a_mesh, a_pos, mesh_manager, tolerance, orientations): + return False + + return True + + def _spheres_penetrate_mesh( + self, + source, + source_mesh, + source_pos, + target, + target_mesh, + target_pos, + mesh_manager, + tolerance, + orientations, + ) -> bool: + """True if source's spheres penetrate target's mesh.""" + spheres = mesh_manager.get_query_spheres(source_mesh, obj=source) + warp_mesh = mesh_manager.get_warp_mesh(target_mesh, obj=target) + centers = self._centers_in_target_frame(spheres[:, :3], source, target, source_pos, target_pos, orientations) + sdf = mesh_sdf(centers, warp_mesh) + mesh_manager.warn_sdf_sentinel(sdf) + if has_sdf_sentinel(sdf): + return True + if (sdf < spheres[:, 3] + tolerance).any(): + if self.params.verbose: + print(f" Mesh overlap between '{source.name}' and '{target.name}'") + return True + return False + + @staticmethod + def _pair_aabb_overlaps( + a: ObjectBase, + b: ObjectBase, + a_pos: torch.Tensor, + b_pos: torch.Tensor, + orientations: dict[ObjectBase, float] | None, + margin: float, + ) -> bool: + """Return True if the yaw-rotated AABBs of a and b overlap.""" + for obj in (a, b): + if obj.is_anchor: + pose = obj.get_initial_pose() + if isinstance(pose, Pose): + qx, qy = pose.rotation_xyzw[0], pose.rotation_xyzw[1] + assert abs(qx) < 1e-6 and abs(qy) < 1e-6, ( + f"AABB fallback requires anchor '{obj.name}' to have pure-Z rotation, " + f"got rotation_xyzw={pose.rotation_xyzw}" + ) + a_bbox = a.get_bounding_box() + b_bbox = b.get_bounding_box() + a_yaw = ObjectPlacer._effective_yaw(a, orientations) + b_yaw = ObjectPlacer._effective_yaw(b, orientations) + if a_yaw != 0.0: + a_bbox = a_bbox.rotated_around_z(a_yaw) + if b_yaw != 0.0: + b_bbox = b_bbox.rotated_around_z(b_yaw) + a_world = a_bbox.translated(a_pos) + b_world = b_bbox.translated(b_pos) + return a_world.overlaps(b_world, margin=margin).item() + + @staticmethod + def _effective_yaw(obj: ObjectBase, orientations: dict[ObjectBase, float] | None) -> float: + """Resolve effective Z-yaw: orientations dict, else initial_pose for anchors.""" + if orientations is not None and obj in orientations: + return orientations[obj] + if not obj.is_anchor: + return 0.0 + pose = obj.get_initial_pose() + if not isinstance(pose, Pose): + return 0.0 + return yaw_from_quat_xyzw(pose.rotation_xyzw) + + @staticmethod + def _centers_in_target_frame( + centers_local: torch.Tensor, + source_obj: ObjectBase, + target_obj: ObjectBase, + source_pos: torch.Tensor, + target_pos: torch.Tensor, + orientations: dict[ObjectBase, float] | None, + ) -> torch.Tensor: + """Transform source sphere centers into the target's local frame (Z-yaw only).""" + src_yaw = ObjectPlacer._effective_yaw(source_obj, orientations) + tgt_yaw = ObjectPlacer._effective_yaw(target_obj, orientations) + + if src_yaw == 0.0 and tgt_yaw == 0.0: + return centers_local + source_pos - target_pos + + net_yaw = src_yaw - tgt_yaw + cos_n = math.cos(net_yaw) + sin_n = math.sin(net_yaw) + rx = centers_local[:, 0] * cos_n - centers_local[:, 1] * sin_n + ry = centers_local[:, 0] * sin_n + centers_local[:, 1] * cos_n + rotated_centers = torch.stack([rx, ry, centers_local[:, 2]], dim=-1) + + offset = source_pos - target_pos + cos_t = math.cos(-tgt_yaw) + sin_t = math.sin(-tgt_yaw) + ox = offset[0] * cos_t - offset[1] * sin_t + oy = offset[0] * sin_t + offset[1] * cos_t + rotated_offset = torch.tensor([ox, oy, offset[2].item()], dtype=centers_local.dtype) + return rotated_centers + rotated_offset + def _validate_placement( self, positions: dict[ObjectBase, tuple[float, float, float]], env_bboxes: dict[ObjectBase, AxisAlignedBoundingBox], + orientations: dict[ObjectBase, float] | None = None, ) -> PlacementValidationResults: """Validate that no two objects overlap in 3D and On / NextTo / NotNextTo relations are satisfied. Args: positions: Dictionary mapping objects to their solved (x, y, z) positions. env_bboxes: Per-object bboxes for the current env, each with shape (1, 3). + orientations: Optional per-object yaw (radians about Z). Returns: PlacementValidationResults with the overlap and relation checks. """ - no_overlap = self._validate_no_overlap(positions, env_bboxes) + use_mesh = self.params.solver_params.collision_mode == CollisionMode.MESH + no_overlap = self._validate_no_overlap(positions, env_bboxes, skip_mesh_pairs=use_mesh) + if no_overlap and use_mesh: + no_overlap = self._validate_no_overlap_mesh(positions, orientations) on_relation = self._validate_on_relations(positions, env_bboxes) next_to = self._validate_next_to_relations(positions, env_bboxes) not_next_to = self._validate_not_next_to_relations(positions, env_bboxes) @@ -763,13 +936,9 @@ def _apply_poses( anchor_objects: set[ObjectBase], orientations_per_env: list[dict[ObjectBase, float]], ) -> None: - """Apply solved positions and sampled yaw to objects (skipping anchors). + """Apply solved positions and orientations to non-anchor objects. - Handles both single-env and multi-env placement: - - Single-env: sets a fixed Pose or PoseRange (with RandomAroundSolution). - - Multi-env: sets a PosePerEnv with one Pose per environment. - - Rotation is the RotateAroundSolution marker (or identity) with the sampled yaw composed on top. + orientations_per_env carries total Z-yaw; marker yaw is subtracted to get the sampled delta. """ num_envs = len(positions_per_env) objects = list(positions_per_env[0]) @@ -779,10 +948,15 @@ def _apply_poses( rotate_marker = self._get_rotate_around_solution(obj) base_rotation = rotate_marker.get_rotation_xyzw() if rotate_marker else (0.0, 0.0, 0.0, 1.0) + marker_yaw = self._get_yaw_from_rotate_around_solution(obj) + + def _sampled_yaw_delta(env_idx: int) -> float: + """Total yaw minus marker yaw = solver-sampled delta.""" + return orientations_per_env[env_idx].get(obj, marker_yaw) - marker_yaw if num_envs == 1: pos = positions_per_env[0][obj] - rotation_xyzw = rotate_quat_by_yaw(base_rotation, orientations_per_env[0].get(obj, 0.0)) + rotation_xyzw = rotate_quat_by_yaw(base_rotation, _sampled_yaw_delta(0)) random_marker = self._get_random_around_solution(obj) if random_marker is not None: obj.set_initial_pose(random_marker.to_pose_range_centered_at(pos, rotation_xyzw=rotation_xyzw)) @@ -792,7 +966,7 @@ def _apply_poses( poses = [ Pose( position_xyz=positions_per_env[env_idx][obj], - rotation_xyzw=rotate_quat_by_yaw(base_rotation, orientations_per_env[env_idx].get(obj, 0.0)), + rotation_xyzw=rotate_quat_by_yaw(base_rotation, _sampled_yaw_delta(env_idx)), ) for env_idx in range(num_envs) ] diff --git a/isaaclab_arena/relations/object_placer_params.py b/isaaclab_arena/relations/object_placer_params.py index 353e4844ca..8185968502 100644 --- a/isaaclab_arena/relations/object_placer_params.py +++ b/isaaclab_arena/relations/object_placer_params.py @@ -5,7 +5,9 @@ from dataclasses import dataclass, field -from isaaclab_arena.relations.relation_solver_params import RelationSolverParams +from isaaclab_arena.relations.relation_solver_params import CollisionMode, RelationSolverParams + +__all__ = ["CollisionMode", "ObjectPlacerParams"] @dataclass diff --git a/isaaclab_arena/relations/placement_result.py b/isaaclab_arena/relations/placement_result.py index 13cd9b252f..1b30aa2d1d 100644 --- a/isaaclab_arena/relations/placement_result.py +++ b/isaaclab_arena/relations/placement_result.py @@ -30,8 +30,7 @@ class PlacementResult: """Number of attempts made.""" orientations: dict[ObjectBase, float] = field(default_factory=dict) - """Per-object yaw (radians) about the world up (Z) axis, composed on top of each object's - base rotation. Keyed by object, like positions. Empty when unrotated.""" + """Total applied Z-yaw (radians) per object (marker + sampled). Empty when unrotated.""" @property def success(self) -> bool: diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index 508500a870..e622d11db0 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import torch from abc import ABC, abstractmethod from dataclasses import dataclass @@ -15,13 +17,12 @@ single_boundary_linear_loss, single_point_linear_loss, ) +from isaaclab_arena.relations.relations import Side from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox if TYPE_CHECKING: from isaaclab_arena.relations.relations import AtPosition, NextTo, NotNextTo, On, PositionLimits, Relation -from isaaclab_arena.relations.relations import Side - class Axis(IntEnum): """Spatial axis indices for tensor indexing.""" @@ -158,7 +159,7 @@ class UnaryRelationLossStrategy(ABC): @abstractmethod def compute_loss( self, - relation: "Relation", + relation: Relation, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: @@ -182,7 +183,7 @@ class RelationLossStrategy(ABC): @abstractmethod def compute_loss( self, - relation: "Relation", + relation: Relation, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -223,7 +224,7 @@ def __init__(self, slope: float = 10.0, debug: bool = False): def compute_loss( self, - relation: "NextTo", + relation: NextTo, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -304,7 +305,7 @@ def __init__(self, slope: float = 10.0, debug: bool = False): def compute_loss( self, - relation: "On", + relation: On, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -413,7 +414,7 @@ def __init__(self, slope: float = 10.0, margin_m: float = 0.1, debug: bool = Fal def compute_loss( self, - relation: "NotNextTo", + relation: NotNextTo, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -448,25 +449,20 @@ def compute_loss( class NoCollisionLossStrategy: - """Loss strategy for no-overlap constraints between objects. - - Computes loss based on: - 1. X overlap: zero when child and parent are separated along X; else overlap length - 2. Y overlap: zero when separated along Y; else overlap length - 3. Z overlap: zero when separated along Z; else overlap length - 4. Volume loss: slope * (overlap_x * overlap_y * overlap_z) + """AABB no-overlap loss between object pairs (built-in solver behavior, not a user relation).""" - This is a standalone strategy (not a RelationLossStrategy) because no-overlap - is a built-in solver behavior, not a user-specified relation. - """ - - def __init__(self, slope: float = 10.0): + def __init__( + self, + slope: float = 10.0, + debug: bool = False, + ): """ Args: - slope: Gradient magnitude for overlap volume loss (default: 10.0). - Loss scales with slope times overlap volume. + slope: Gradient magnitude for overlap loss. + debug: If True, print detailed AABB loss component breakdown. """ self.slope = slope + self.debug = debug def compute_loss_batched( self, @@ -478,8 +474,6 @@ def compute_loss_batched( ) -> torch.Tensor: """Overlap-volume no-overlap loss for boxes already reduced to world-space extents. - The subject box carries gradient; it is pushed off the obstacle box (expanded by clearance). - Args: clearance_m: Minimum clearance between boxes in meters. subject_min: World-space min extent of the subject box, shape (num_pairs, batch_size, 3). @@ -522,7 +516,7 @@ def __init__(self, slope: float = 10.0): def compute_loss( self, - relation: "AtPosition", + relation: AtPosition, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: @@ -542,17 +536,14 @@ def compute_loss( total_loss = torch.zeros(child_pos.shape[0], dtype=child_pos.dtype, device=child_pos.device) - # X position constraint if relation.x is not None: x_loss = single_point_linear_loss(child_pos[:, 0], relation.x, slope=self.slope) total_loss = total_loss + x_loss - # Y position constraint if relation.y is not None: y_loss = single_point_linear_loss(child_pos[:, 1], relation.y, slope=self.slope) total_loss = total_loss + y_loss - # Z position constraint if relation.z is not None: z_loss = single_point_linear_loss(child_pos[:, 2], relation.z, slope=self.slope) total_loss = total_loss + z_loss @@ -578,7 +569,7 @@ def __init__(self, slope: float = 100.0): def compute_loss( self, - relation: "PositionLimits", + relation: PositionLimits, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: @@ -598,7 +589,6 @@ def compute_loss( total_loss = torch.zeros(child_pos.shape[0], dtype=child_pos.dtype, device=child_pos.device) - # Iterate over X (0), Y (1), Z (2) with their optional bounds axis_bounds = [ (relation.x_min, relation.x_max), (relation.y_min, relation.y_max), diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index b393bfd317..4f50fd6f21 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -5,11 +5,17 @@ from __future__ import annotations +import numpy as np import time import torch from dataclasses import dataclass from typing import TYPE_CHECKING, cast +import warp as wp +from isaaclab.utils.math import quat_apply, quat_apply_inverse + +from isaaclab_arena.relations.collision_mode import CollisionMode +from isaaclab_arena.relations.mesh_pair_cache import MeshPairCache from isaaclab_arena.relations.relation_loss_strategies import ( NoCollisionLossStrategy, RelationLossStrategy, @@ -18,7 +24,10 @@ from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relation_solver_state import RelationSolverState from isaaclab_arena.relations.relations import On, Relation, RelationBase, UnaryRelation +from isaaclab_arena.relations.warp_mesh_manager import WarpMeshAndSphereCache +from isaaclab_arena.relations.warp_sdf_kernels import clamp_sdf_sentinel, multi_mesh_sdf from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +from isaaclab_arena.utils.pose import Pose, yaw_from_quat_xyzw if TYPE_CHECKING: from isaaclab_arena.assets.object_base import ObjectBase @@ -38,7 +47,7 @@ class NoOverlapPair: class RelationSolver: - """Differentiable solver for 3D spatial relations of IsaacLab Arena Objects + """Differentiable solver for 3D spatial relations of IsaacLab Arena Objects. Uses the Strategy pattern for loss computation: each Relation type has a corresponding RelationLossStrategy that handles the actual loss calculation. @@ -62,18 +71,17 @@ def __init__( self._last_position_history: list = [] self._last_loss_per_env: torch.Tensor | None = None self._last_no_overlap_pair_count: int = 0 + self._mesh_orientations: list[dict[ObjectBase, float]] | None = None + self._warned_no_mesh: set[str] = set() + self._mesh_manager: WarpMeshAndSphereCache | None = None + self._mesh_cache_fwd: MeshPairCache | None = None + self._mesh_cache_rev: MeshPairCache | None = None def _get_strategy(self, relation: RelationBase) -> RelationLossStrategy | UnaryRelationLossStrategy: - """Look up the appropriate strategy for a relation type. + """Look up the loss strategy for a relation type; raises ValueError if none registered. Args: relation: The relation to find a strategy for. - - Returns: - The RelationLossStrategy or UnaryRelationLossStrategy for this relation type. - - Raises: - ValueError: If no strategy is registered for this relation type. """ strategy = self.params.strategies.get(type(relation)) if strategy is None: @@ -102,14 +110,12 @@ def _compute_total_loss( device = state.device total_loss = torch.zeros(batch_size, device=device, dtype=torch.float32) - # Compute loss from all spatial relations using strategies for obj in state.optimizable_objects: for relation in obj.get_spatial_relations(): child_pos = state.get_position(obj) strategy = self._get_strategy(relation) child_bbox = state.get_bbox(obj) - # Handle unary relations (no parent) if isinstance(relation, UnaryRelation): unary_strategy = cast(UnaryRelationLossStrategy, strategy) loss = unary_strategy.compute_loss( @@ -119,7 +125,7 @@ def _compute_total_loss( ) if debug: _print_unary_relation_debug(obj, relation, child_pos[0], loss.mean()) - # Handle binary relations (with parent) like On, NextTo + # Binary relation (On, NextTo, etc.) elif isinstance(relation, Relation): relation_strategy = cast(RelationLossStrategy, strategy) parent = relation.parent @@ -154,18 +160,27 @@ def _compute_no_overlap_loss( state: RelationSolverState, debug: bool = False, ) -> torch.Tensor: - """Compute pairwise no-overlap loss, skipping On-linked pairs. + """Compute pairwise no-overlap loss, skipping On-linked pairs.""" + if self.params.collision_mode == CollisionMode.MESH: + mesh_loss = self._compute_no_overlap_loss_mesh(state, debug) + aabb_loss = self._compute_no_overlap_loss_aabb(state, debug, skip_mesh_pairs=True) + return mesh_loss + aabb_loss + else: + return self._compute_no_overlap_loss_aabb(state, debug) + + def _compute_no_overlap_loss_aabb( + self, + state: RelationSolverState, + debug: bool, + skip_mesh_pairs: bool = False, + ) -> torch.Tensor: + """Per-pair AABB collision loss. - Non-anchor vs anchor: gradient flows to the non-anchor only. - Non-anchor vs non-anchor: both objects accumulate gradient (two directed passes). - Args: - state: Current optimization state with object positions and - optional per-env bounding boxes. - debug: If True, print detailed loss breakdown. - - Returns: - Per-environment loss tensor of shape (batch_size,). + When skip_mesh_pairs=True (used as AABB fallback in MESH mode), only + processes pairs where at least one object lacks a collision mesh. """ device = state.device batch_size = state.batch_size @@ -174,9 +189,6 @@ def _compute_no_overlap_loss( non_anchor_objects = state.optimizable_objects anchor_objects = list(state.anchor_objects) - # Skip no-overlap for On pairs: the On loss already pushes the child - # onto the parent surface, so penalizing bbox overlap between them - # would fight that constraint and cause oscillation. on_pairs: set[tuple[int, int]] = set() for obj in [*non_anchor_objects, *anchor_objects]: for rel in obj.get_relations(): @@ -207,6 +219,13 @@ def _compute_no_overlap_loss( for anchor in anchor_objects: if (id(child), id(anchor)) in on_pairs: continue + if ( + skip_mesh_pairs + and self._mesh_manager is not None + and self._mesh_manager.get_collision_mesh(child) is not None + and self._mesh_manager.get_collision_mesh(anchor) is not None + ): + continue anchor_min, anchor_max = extents[anchor] pairs.append(NoOverlapPair(child_min, child_max, anchor_min, anchor_max)) pair_names.append((child.name, anchor.name)) @@ -218,6 +237,13 @@ def _compute_no_overlap_loss( other = non_anchor_objects[j] if (id(child), id(other)) in on_pairs: continue + if ( + skip_mesh_pairs + and self._mesh_manager is not None + and self._mesh_manager.get_collision_mesh(child) is not None + and self._mesh_manager.get_collision_mesh(other) is not None + ): + continue other_min, other_max = extents[other] pairs.append(NoOverlapPair(child_min, child_max, other_min.detach(), other_max.detach())) pair_names.append((child.name, other.name)) @@ -244,11 +270,366 @@ def _compute_no_overlap_loss( return pair_loss.sum(dim=0) + def _prepare_mesh_collision_cache( + self, + state: RelationSolverState, + on_pairs: set[tuple[int, int]], + ) -> None: + """Precompute static per-pair mesh collision data (called once per solve).""" + device = state.device + device_str = str(device) + if self._mesh_manager is None or self._mesh_manager.device != device_str: + self._mesh_manager = WarpMeshAndSphereCache(num_spheres=self.params.num_spheres, device=device_str) + manager = self._mesh_manager + + non_anchor_objects = state.optimizable_objects + anchor_objects = list(state.anchor_objects) + + self._mesh_cache_fwd = self._build_vectorized_cache( + state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction="fwd" + ) + self._mesh_cache_rev = self._build_vectorized_cache( + state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction="rev" + ) + + def _build_vectorized_cache( + self, state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction: str + ) -> MeshPairCache | None: + """Build vectorized pair cache for one direction. + + Returns None if no valid pairs exist for this direction. + """ + centers_list: list[torch.Tensor] = [] + radii_list: list[torch.Tensor] = [] + pair_child_objs: list = [] + pair_parent_objs: list = [] + pair_is_anchor: list[bool] = [] + pair_anchor_pos: list[torch.Tensor | None] = [] + pair_anchor_yaw: list[float] = [] + pair_c_bbox_min: list[torch.Tensor] = [] + pair_c_bbox_max: list[torch.Tensor] = [] + pair_p_bbox_min: list[torch.Tensor] = [] + pair_p_bbox_max: list[torch.Tensor] = [] + pair_max_r: list[float] = [] + mesh_id_map: dict[int, int] = {} + mesh_id_values: list[int] = [] + mesh_idx_per_sphere: list[int] = [] + pair_slices: list[tuple[int, int]] = [] + offset = 0 + + for i, child in enumerate(non_anchor_objects): + child_mesh = manager.get_collision_mesh(child) + if child_mesh is None: + if child.name not in self._warned_no_mesh: + self._warned_no_mesh.add(child.name) + print(f"[NoCollision] '{child.name}' has no collision mesh; pair will use AABB fallback.") + continue + child_spheres = manager.get_query_spheres(child_mesh, obj=child).to(device) + child_centers_local = child_spheres[:, :3] + child_radii = child_spheres[:, 3] + child_bbox = state.get_bbox(child) + c_bbox_min = child_bbox.min_point.to(device) + c_bbox_max = child_bbox.max_point.to(device) + + if direction == "fwd": + for anchor in anchor_objects: + if (id(child), id(anchor)) in on_pairs: + continue + parent_mesh = manager.get_collision_mesh(anchor) + if parent_mesh is None: + if anchor.name not in self._warned_no_mesh: + self._warned_no_mesh.add(anchor.name) + print(f"[NoCollision] '{anchor.name}' has no collision mesh; pair will use AABB fallback.") + continue + warp_mesh = manager.get_warp_mesh(parent_mesh, obj=anchor) + parent_bbox = state.get_bbox(anchor) + p_bbox_min = parent_bbox.min_point.to(device) + p_bbox_max = parent_bbox.max_point.to(device) + pose = anchor.get_initial_pose() + assert pose is not None and isinstance( + pose, Pose + ), f"MESH collision requires anchor '{anchor.name}' to have a fixed Pose initial_pose" + assert abs(pose.rotation_xyzw[0]) < 1e-6 and abs(pose.rotation_xyzw[1]) < 1e-6, ( + f"MESH collision requires anchor '{anchor.name}' to have identity or " + f"pure-Z rotation, got rotation_xyzw={pose.rotation_xyzw}. " + "Roll/pitch anchors are not supported in MESH mode." + ) + anchor_pos = torch.tensor(pose.position_xyz, dtype=torch.float32, device=device) + anchor_yaw = yaw_from_quat_xyzw(pose.rotation_xyzw) + + n_spheres = child_centers_local.shape[0] + mesh_key = id(warp_mesh) + if mesh_key not in mesh_id_map: + mesh_id_map[mesh_key] = len(mesh_id_values) + mesh_id_values.append(warp_mesh.id) + mesh_idx = mesh_id_map[mesh_key] + + centers_list.append(child_centers_local) + radii_list.append(child_radii) + pair_child_objs.append(child) + pair_parent_objs.append(anchor) + pair_is_anchor.append(True) + pair_anchor_pos.append(anchor_pos) + pair_anchor_yaw.append(anchor_yaw) + pair_c_bbox_min.append(c_bbox_min) + pair_c_bbox_max.append(c_bbox_max) + pair_p_bbox_min.append(p_bbox_min) + pair_p_bbox_max.append(p_bbox_max) + pair_max_r.append(child_radii.max().item()) + mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) + pair_slices.append((offset, offset + n_spheres)) + offset += n_spheres + + for j in range(i + 1, len(non_anchor_objects)): + other = non_anchor_objects[j] + if (id(child), id(other)) in on_pairs: + continue + other_mesh = manager.get_collision_mesh(other) + if other_mesh is None: + if other.name not in self._warned_no_mesh: + self._warned_no_mesh.add(other.name) + print(f"[NoCollision] '{other.name}' has no collision mesh; pair will use AABB fallback.") + continue + warp_mesh = manager.get_warp_mesh(other_mesh, obj=other) + other_bbox = state.get_bbox(other) + p_bbox_min = other_bbox.min_point.to(device) + p_bbox_max = other_bbox.max_point.to(device) + + n_spheres = child_centers_local.shape[0] + mesh_key = id(warp_mesh) + if mesh_key not in mesh_id_map: + mesh_id_map[mesh_key] = len(mesh_id_values) + mesh_id_values.append(warp_mesh.id) + mesh_idx = mesh_id_map[mesh_key] + + centers_list.append(child_centers_local) + radii_list.append(child_radii) + pair_child_objs.append(child) + pair_parent_objs.append(other) + pair_is_anchor.append(False) + pair_anchor_pos.append(None) + pair_anchor_yaw.append(0.0) + pair_c_bbox_min.append(c_bbox_min) + pair_c_bbox_max.append(c_bbox_max) + pair_p_bbox_min.append(p_bbox_min) + pair_p_bbox_max.append(p_bbox_max) + pair_max_r.append(child_radii.max().item()) + mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) + pair_slices.append((offset, offset + n_spheres)) + offset += n_spheres + + else: # direction == "rev" + for j in range(i + 1, len(non_anchor_objects)): + other = non_anchor_objects[j] + if (id(child), id(other)) in on_pairs: + continue + other_mesh = manager.get_collision_mesh(other) + if other_mesh is None: + if other.name not in self._warned_no_mesh: + self._warned_no_mesh.add(other.name) + print(f"[NoCollision] '{other.name}' has no collision mesh; pair will use AABB fallback.") + continue + other_spheres = manager.get_query_spheres(other_mesh, obj=other).to(device) + other_centers_local = other_spheres[:, :3] + other_radii = other_spheres[:, 3] + warp_mesh = manager.get_warp_mesh(child_mesh, obj=child) + other_bbox = state.get_bbox(other) + o_bbox_min = other_bbox.min_point.to(device) + o_bbox_max = other_bbox.max_point.to(device) + + n_spheres = other_centers_local.shape[0] + mesh_key = id(warp_mesh) + if mesh_key not in mesh_id_map: + mesh_id_map[mesh_key] = len(mesh_id_values) + mesh_id_values.append(warp_mesh.id) + mesh_idx = mesh_id_map[mesh_key] + + centers_list.append(other_centers_local) + radii_list.append(other_radii) + pair_child_objs.append(other) + pair_parent_objs.append(child) + pair_is_anchor.append(False) + pair_anchor_pos.append(None) + pair_anchor_yaw.append(0.0) + pair_c_bbox_min.append(o_bbox_min) + pair_c_bbox_max.append(o_bbox_max) + pair_p_bbox_min.append(c_bbox_min) + pair_p_bbox_max.append(c_bbox_max) + pair_max_r.append(other_radii.max().item()) + mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) + pair_slices.append((offset, offset + n_spheres)) + offset += n_spheres + + if not centers_list: + return None + + wp_device = str(device) + pair_sphere_count = torch.tensor([e - s for s, e in pair_slices], dtype=torch.float32, device=device) + sphere_pair_id = torch.repeat_interleave( + torch.arange(len(pair_slices), device=device), pair_sphere_count.long() + ) + + return MeshPairCache( + all_centers_local=torch.cat(centers_list, dim=0), + all_radii=torch.cat(radii_list, dim=0), + pair_child_objs=pair_child_objs, + pair_parent_objs=pair_parent_objs, + pair_is_anchor=pair_is_anchor, + pair_anchor_pos=pair_anchor_pos, + pair_anchor_yaw=pair_anchor_yaw, + pair_c_bbox_min=torch.stack(pair_c_bbox_min), + pair_c_bbox_max=torch.stack(pair_c_bbox_max), + pair_p_bbox_min=torch.stack(pair_p_bbox_min), + pair_p_bbox_max=torch.stack(pair_p_bbox_max), + pair_max_r=torch.tensor(pair_max_r, device=device), + sphere_pair_id=sphere_pair_id, + sphere_mesh_idx=torch.tensor(mesh_idx_per_sphere, dtype=torch.int32, device=device), + pair_sphere_count=pair_sphere_count, + mesh_id_array=wp.array(np.array(mesh_id_values, dtype=np.uint64), dtype=wp.uint64, device=wp_device), + num_pairs=len(pair_slices), + total_spheres=offset, + ) + + def _compute_no_overlap_loss_mesh( + self, + state: RelationSolverState, + debug: bool, + ) -> torch.Tensor: + """Sphere-to-SDF penetration loss using the vectorized multi-mesh kernel. + + Uses precomputed pair cache (centers, radii, mesh indices) to batch all + sphere queries into a single Warp kernel call per iteration. + """ + device = state.device + total_loss = torch.zeros(state.batch_size, device=device, dtype=torch.float32) + clearance_m = self.params.clearance_m + slope = self._no_collision_strategy.slope + + for b in range(state.batch_size): + for cache in (self._mesh_cache_fwd, self._mesh_cache_rev): + if cache is None: + continue + + num_pairs = cache.num_pairs + + child_positions = torch.stack( + [state.get_position(cache.pair_child_objs[p])[b] for p in range(num_pairs)] + ) + parent_positions = torch.stack([ + ( + cache.pair_anchor_pos[p] + if cache.pair_is_anchor[p] + else state.get_position(cache.pair_parent_objs[p])[b].detach() + ) + for p in range(num_pairs) + ]) + + anchor_yaws = cache.pair_anchor_yaw + has_any_yaw = self._mesh_orientations is not None or any(y != 0.0 for y in anchor_yaws) + if has_any_yaw: + ori_b = self._mesh_orientations[b] if self._mesh_orientations is not None else {} + child_yaws = torch.tensor( + [ori_b.get(cache.pair_child_objs[p], 0.0) for p in range(num_pairs)], + dtype=torch.float32, + device=device, + ) + parent_yaws = torch.tensor( + [ori_b.get(cache.pair_parent_objs[p], anchor_yaws[p]) for p in range(num_pairs)], + dtype=torch.float32, + device=device, + ) + + # AABB broadphase (yaw-aware): skip separated pairs. + margins = cache.pair_max_r + clearance_m + batch_idx = min(b, cache.pair_c_bbox_min.shape[1] - 1) + c_bbox_min = cache.pair_c_bbox_min[:, batch_idx, :] + c_bbox_max = cache.pair_c_bbox_max[:, batch_idx, :] + p_bbox_min = cache.pair_p_bbox_min[:, batch_idx, :] + p_bbox_max = cache.pair_p_bbox_max[:, batch_idx, :] + + if has_any_yaw: + c_bbox_min, c_bbox_max = self._rotate_bbox_extents(c_bbox_min, c_bbox_max, child_yaws) + p_bbox_min, p_bbox_max = self._rotate_bbox_extents(p_bbox_min, p_bbox_max, parent_yaws) + + child_min = child_positions + c_bbox_min + child_max = child_positions + c_bbox_max + parent_min = parent_positions + p_bbox_min + parent_max = parent_positions + p_bbox_max + + sep_child = (child_min - margins.unsqueeze(1)) > parent_max + sep_parent = (parent_min - margins.unsqueeze(1)) > child_max + separated = sep_child.any(dim=1) | sep_parent.any(dim=1) + active_pair = ~separated + + if not active_pair.any(): + continue + + offsets = child_positions - parent_positions + sphere_active_mask = active_pair[cache.sphere_pair_id] + active_idx = sphere_active_mask.nonzero(as_tuple=True)[0] + + active_sphere_pair_id = cache.sphere_pair_id[active_idx] + local_centers = cache.all_centers_local[active_idx] + + # R(child_yaw - parent_yaw) · local + R(-parent_yaw) · offset + if has_any_yaw: + net_yaws = (child_yaws - parent_yaws)[active_sphere_pair_id] + half_net = net_yaws / 2.0 + q_net_z = torch.zeros(len(half_net), 4, device=device, dtype=local_centers.dtype) + q_net_z[:, 2] = torch.sin(half_net) + q_net_z[:, 3] = torch.cos(half_net) + local_centers = quat_apply(q_net_z, local_centers) + + pair_offsets = offsets[active_sphere_pair_id] + p_yaws = parent_yaws[active_sphere_pair_id] + half_p = p_yaws / 2.0 + q_parent_z = torch.zeros(len(half_p), 4, device=device, dtype=local_centers.dtype) + q_parent_z[:, 2] = torch.sin(half_p) + q_parent_z[:, 3] = torch.cos(half_p) + rotated_offsets = quat_apply_inverse(q_parent_z, pair_offsets) + active_centers = local_centers + rotated_offsets + else: + active_centers = local_centers + offsets[active_sphere_pair_id] + active_radii = cache.all_radii[active_idx] + active_mesh_idx = cache.sphere_mesh_idx[active_idx].contiguous() + + active_mesh_indices_wp = wp.from_torch(active_mesh_idx, dtype=wp.int32) + sdf_values = multi_mesh_sdf(active_centers, cache.mesh_id_array, active_mesh_indices_wp) + self._mesh_manager.warn_sdf_sentinel(sdf_values) + sdf_values = clamp_sdf_sentinel(sdf_values) + penetration = torch.relu(active_radii + clearance_m - sdf_values) + + pair_sum = torch.zeros(num_pairs, device=device, dtype=penetration.dtype) + pair_sum.index_add_(0, active_sphere_pair_id, penetration) + pair_mean = pair_sum / cache.pair_sphere_count + active_pair_idx = active_pair.nonzero(as_tuple=True)[0] + total_loss[b] = total_loss[b] + slope * pair_mean[active_pair_idx].sum() + + if debug: + print(f" [NoOverlap MESH] total_loss={total_loss.tolist()}") + + return total_loss + + @staticmethod + def _rotate_bbox_extents( + bbox_min: torch.Tensor, bbox_max: torch.Tensor, yaws: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Return the AABB enclosing a Z-rotated bbox (bbox_min/max: (N,3), yaws: (N,)).""" + cos_y = torch.cos(yaws).abs().unsqueeze(1) + sin_y = torch.sin(yaws).abs().unsqueeze(1) + half = (bbox_max - bbox_min) / 2.0 + center = (bbox_max + bbox_min) / 2.0 + new_hx = half[:, 0:1] * cos_y + half[:, 1:2] * sin_y + new_hy = half[:, 0:1] * sin_y + half[:, 1:2] * cos_y + new_half = torch.cat([new_hx, new_hy, half[:, 2:3]], dim=1) + return center - new_half, center + new_half + def solve( self, objects: list[ObjectBase], initial_positions: list[dict[ObjectBase, tuple[float, float, float]]], env_bboxes: dict[ObjectBase, AxisAlignedBoundingBox] | None = None, + orientations: list[dict[ObjectBase, float]] | None = None, ) -> list[dict[ObjectBase, tuple[float, float, float]]]: """Solve for optimal positions of all objects. @@ -261,6 +642,8 @@ def solve( ObjectPlacer always supplies these, with each AxisAlignedBoundingBox shaped (batch, 3). Direct solver calls may omit them to use each object's default get_bounding_box(). + orientations: Optional per-env yaw angles (radians about Z) per object. + Used in MESH mode to rotate sphere centers before collision queries. Returns: List of dicts (one per env) mapping objects to their solved (x, y, z) positions. @@ -288,6 +671,22 @@ def solve( torch.cuda.synchronize() solve_start = time.perf_counter() + # Precompute mesh collision cache (once per solve, before opt loop) + if self.params.collision_mode == CollisionMode.MESH: + non_anchor_objects = state.optimizable_objects + anchor_objects = list(state.anchor_objects) + on_pairs: set[tuple[int, int]] = set() + for obj in [*non_anchor_objects, *anchor_objects]: + for rel in obj.get_relations(): + if isinstance(rel, On): + on_pairs.add((id(obj), id(rel.parent))) + on_pairs.add((id(rel.parent), id(obj))) + self._mesh_orientations = orientations + self._prepare_mesh_collision_cache(state, on_pairs) + + if self.params.collision_mode == CollisionMode.MESH: + self._mesh_manager.reset_sentinel_warning() + # Setup optimizer (only for optimizable positions) optimizer = torch.optim.Adam([state.optimizable_positions], lr=self.params.lr) @@ -305,13 +704,13 @@ def solve( if self.params.save_position_history and iter % self.POSITION_HISTORY_SAVE_INTERVAL == 0: position_history.append(state.get_all_positions_snapshot()) - # Compute total loss loss = self._compute_total_loss(state) loss_history.append(loss.item()) - # Backprop and update (only optimizable positions will update) - loss.backward() - optimizer.step() + # Constant-zero loss has no grad_fn — skip backward when broadphase culls all pairs. + if loss.grad_fn is not None: + loss.backward() + optimizer.step() if self.params.verbose and iter % 100 == 0: print(f"Iter {iter}: loss = {loss.item():.6f}") @@ -381,7 +780,6 @@ def debug_losses(self, objects: list[ObjectBase]) -> None: print("No position history available. Run solve() first.") return - # Build positions dict from final position history final_positions = {obj: (pos[0], pos[1], pos[2]) for obj, pos in zip(objects, final_positions_list)} state = RelationSolverState(objects, [final_positions]) diff --git a/isaaclab_arena/relations/relation_solver_params.py b/isaaclab_arena/relations/relation_solver_params.py index bcc99aa0bc..26b125663a 100644 --- a/isaaclab_arena/relations/relation_solver_params.py +++ b/isaaclab_arena/relations/relation_solver_params.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field +from isaaclab_arena.relations.collision_mode import CollisionMode from isaaclab_arena.relations.relation_loss_strategies import ( AtPositionLossStrategy, NextToLossStrategy, @@ -50,6 +51,12 @@ class RelationSolverParams: save_position_history: bool = True """Save position snapshots during optimization for visualization/debugging. Disable to reduce memory.""" + collision_mode: CollisionMode = CollisionMode.BBOX + """Which collision detection method to use for no-overlap constraints.""" + + num_spheres: int = 30 + """Number of bounding spheres per object for MESH mode. Higher = more accurate but slower.""" + clearance_m: float = 0.01 """Minimum clearance (meters) enforced between every pair of non-anchor objects. The solver adds a no-overlap loss for all pairs automatically. Set to 0.0 to only diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py new file mode 100644 index 0000000000..5a5192df2f --- /dev/null +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -0,0 +1,212 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Warp mesh management and greedy sphere decomposition for mesh-based collision.""" + +from __future__ import annotations + +import numpy as np +import torch +import trimesh +from collections import defaultdict +from heapq import heappop, heappush +from typing import TYPE_CHECKING + +import warp as wp + +from isaaclab_arena.relations.warp_sdf_kernels import has_sdf_sentinel, sdf_sentinel_count + +if TYPE_CHECKING: + from isaaclab_arena.assets.object_base import ObjectBase + + +def _mesh_content_hash(mesh: trimesh.Trimesh) -> int: + """Content-based hash for a trimesh. Safe across GC cycles unlike id().""" + return hash((mesh.vertices.tobytes(), mesh.faces.tobytes())) + + +def greedy_sphere_decomposition( + mesh: trimesh.Trimesh, + num_spheres: int = 30, + sphere_radius: float = 0.01, + n_candidates: int = 200, + n_surface: int = 1000, + seed: int = 42, +) -> np.ndarray: + """Decompose a mesh into bounding spheres via greedy set-cover. + + Args: + mesh: Input trimesh (must be watertight or convex-hull-repairable). + num_spheres: Maximum number of output spheres. + sphere_radius: Inflation added to tangent sphere radii (safety margin). + n_candidates: Number of candidate sphere centers sampled. + n_surface: Number of surface points for coverage tracking. + seed: RNG seed for reproducible surface sampling. + + Returns: + (K, 4) array of [cx, cy, cz, radius] in mesh-local frame. K <= num_spheres. + """ + n_candidates = max(num_spheres, n_candidates) + n_surface = max(n_candidates, n_surface) + + rng = np.random.default_rng(seed) + points = trimesh.sample.sample_surface(mesh, n_surface, seed=rng)[0] + cloud = trimesh.PointCloud(points) + + work_mesh = mesh if mesh.is_watertight else mesh.convex_hull + candidates = points[:n_candidates] + try: + centers, radii = trimesh.proximity.max_tangent_sphere(work_mesh, candidates) + except (IndexError, ValueError) as e: + print(f" [SphereDecomp] max_tangent_sphere failed ({e}), using uniform fallback — coverage may be poor") + centers = candidates[:num_spheres] + radii = np.full(len(centers), sphere_radius) + return np.column_stack([centers, radii]) + + radii = radii + sphere_radius + + max_radius = np.linalg.norm(mesh.extents) / 2 + valid = (radii <= max_radius) & np.isfinite(radii) + centers, radii = centers[valid], radii[valid] + + if len(centers) == 0: + print(" [SphereDecomp] All tangent spheres filtered (degenerate mesh?) — using uniform fallback") + pts = points[:num_spheres] + return np.column_stack([pts, np.full(len(pts), sphere_radius)]) + + outgoing: dict[int, set[int]] = defaultdict(set) + incoming: dict[int, set[int]] = defaultdict(set) + for idx, (center, radius) in enumerate(zip(centers, radii)): + covered = cloud.kdtree.query_ball_point(center, r=radius, eps=1e-6) + for pt_idx in covered: + outgoing[idx].add(pt_idx) + incoming[pt_idx].add(idx) + + selected: list[int] = [] + queue: list[tuple[int, int]] = [] + for idx in outgoing: + heappush(queue, (-len(outgoing[idx]), idx)) + + while queue and len(selected) < num_spheres: + neg_count, idx = heappop(queue) + if len(outgoing[idx]) != -neg_count: + heappush(queue, (-len(outgoing[idx]), idx)) + continue + if neg_count == 0: + break + for pt_idx in list(outgoing[idx]): + for other_idx in incoming[pt_idx]: + outgoing[other_idx].discard(pt_idx) + selected.append(idx) + + if not selected: + print(" [SphereDecomp] Set-cover selected 0 spheres — using uniform fallback") + pts = points[:num_spheres] + return np.column_stack([pts, np.full(len(pts), sphere_radius)]) + + return np.column_stack([centers[selected], radii[selected]]) + + +class WarpMeshAndSphereCache: + """Cache for Warp BVH meshes and sphere decompositions used in mesh-based collision queries.""" + + def __init__( + self, + num_spheres: int = 30, + sphere_radius: float = 0.01, + device: str = "cuda:0", + ): + self._num_spheres = num_spheres + self._sphere_radius = sphere_radius + self._device = device + self._warp_mesh_cache: dict[tuple, wp.Mesh] = {} + self._sphere_cache: dict[tuple, torch.Tensor] = {} + self._trimesh_cache: dict[tuple, trimesh.Trimesh | None] = {} + self._sentinel_warned: bool = False + + def reset_sentinel_warning(self) -> None: + """Re-arm for a new solve/validation pass.""" + self._sentinel_warned = False + + def warn_sdf_sentinel(self, sdf_values: torch.Tensor) -> None: + """Warn (once per pass) if any query hit the no-face sentinel.""" + if self._sentinel_warned: + return + if has_sdf_sentinel(sdf_values): + self._sentinel_warned = True + n_bad = sdf_sentinel_count(sdf_values) + print( + f" [MeshSDF] WARNING: {n_bad}/{len(sdf_values)} sphere queries returned sentinel SDF " + "(no mesh face found). Collision detection may be incomplete for these points." + ) + + def get_collision_mesh(self, obj: ObjectBase) -> trimesh.Trimesh | None: + """Extract or retrieve cached collision mesh for an object.""" + usd_path = getattr(obj, "usd_path", None) + if usd_path is None: + return obj.get_collision_mesh() + scale = tuple(getattr(obj, "scale", (1.0, 1.0, 1.0))) + key = (usd_path, scale) + if key not in self._trimesh_cache: + from isaaclab_arena.utils.usd_helpers import extract_trimesh_from_usd + + try: + self._trimesh_cache[key] = extract_trimesh_from_usd(usd_path, scale) + except ValueError as e: + # Permanent: bad USD content, cache None to avoid re-parsing. + print(f" [WarpMeshAndSphereCache] Could not extract mesh for '{obj.name}': {e}") + self._trimesh_cache[key] = None + except OSError as e: + # Transient: file I/O failure, don't cache so next call retries. + print(f" [WarpMeshAndSphereCache] Could not extract mesh for '{obj.name}': {e}") + return None + return self._trimesh_cache[key] + + @property + def device(self) -> str: + """Target Warp device string (e.g. 'cuda:0', 'cpu').""" + return self._device + + def _cache_key(self, mesh: trimesh.Trimesh, obj: ObjectBase | None = None) -> tuple: + """Compute cache key. Uses (usd_path, scale) for USD objects, content hash otherwise.""" + usd_path = getattr(obj, "usd_path", None) if obj is not None else None + if usd_path is not None: + scale = tuple(getattr(obj, "scale", (1.0, 1.0, 1.0))) + return (usd_path, scale, self._num_spheres, self._sphere_radius) + return (_mesh_content_hash(mesh), self._num_spheres, self._sphere_radius) + + def get_warp_mesh(self, mesh: trimesh.Trimesh, obj: ObjectBase | None = None) -> wp.Mesh: + """Get or create a Warp BVH mesh for SDF queries. + + Non-watertight meshes are replaced by their convex hull to ensure + correct inside/outside signs. + """ + key = self._cache_key(mesh, obj) + if key not in self._warp_mesh_cache: + if not mesh.is_watertight: + name = obj.name if obj is not None else repr(mesh) + print( + f" [WarpMeshAndSphereCache] '{name}' mesh is not watertight — using convex hull (concavities will" + " be filled)" + ) + work_mesh = mesh if mesh.is_watertight else mesh.convex_hull + vertices = wp.array(np.asarray(work_mesh.vertices, dtype=np.float32), dtype=wp.vec3, device=self._device) + indices = wp.array( + np.asarray(work_mesh.faces, dtype=np.int32).flatten(), dtype=wp.int32, device=self._device + ) + self._warp_mesh_cache[key] = wp.Mesh(points=vertices, indices=indices) + return self._warp_mesh_cache[key] + + def get_query_spheres(self, mesh: trimesh.Trimesh, obj: ObjectBase | None = None) -> torch.Tensor: + """Get or compute sphere decomposition as (K, 4) tensor [cx, cy, cz, radius].""" + key = self._cache_key(mesh, obj) + if key not in self._sphere_cache: + spheres_np = greedy_sphere_decomposition( + mesh, + num_spheres=self._num_spheres, + sphere_radius=self._sphere_radius, + ) + self._sphere_cache[key] = torch.from_numpy(spheres_np).float() + return self._sphere_cache[key] diff --git a/isaaclab_arena/relations/warp_sdf_kernels.py b/isaaclab_arena/relations/warp_sdf_kernels.py new file mode 100644 index 0000000000..1594b6e459 --- /dev/null +++ b/isaaclab_arena/relations/warp_sdf_kernels.py @@ -0,0 +1,232 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Warp SDF kernels and PyTorch autograd bridge for mesh-based collision loss.""" + +from __future__ import annotations + +import torch + +import warp as wp + + +@wp.kernel +def _sdf_query_kernel( + mesh_id: wp.uint64, + query_points: wp.array(dtype=wp.vec3), + sdf_out: wp.array(dtype=wp.float32), + grad_out: wp.array(dtype=wp.vec3), +): + """Query signed distance and gradient for each point against a Warp mesh. + + Points must be in mesh-local frame. Sign convention: negative = inside mesh. + Points with no enclosing face write a large sentinel value (~1e6). + """ + tid = wp.tid() + p = query_points[tid] + + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + + found = wp.mesh_query_point_sign_normal(mesh_id, p, 1.0e6, sign, face_index, face_u, face_v) + + if found: + closest = wp.mesh_eval_position(mesh_id, face_index, face_u, face_v) + diff = p - closest + dist = wp.length(diff) + sdf_out[tid] = sign * dist + if dist > 1.0e-8: + grad_out[tid] = sign * wp.normalize(diff) + else: + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + else: + sdf_out[tid] = float(1.0e6) + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + + +class _MeshSDFFunction(torch.autograd.Function): + """Autograd bridge for single-mesh SDF queries.""" + + @staticmethod + def forward(ctx, points: torch.Tensor, mesh: wp.Mesh) -> torch.Tensor: + """Compute SDF values for query points against a Warp mesh. + + Args: + points: (N, 3) float32 tensor of query positions (must be on same device as mesh). + mesh: Warp Mesh with BVH built. + + Returns: + (N,) tensor of signed distance values (negative = inside). + """ + device = points.device + n = points.shape[0] + wp_device = str(device) + + wp_points = wp.from_torch(points.contiguous(), dtype=wp.vec3) + sdf_wp = wp.zeros(n, dtype=wp.float32, device=wp_device) + grad_wp = wp.zeros(n, dtype=wp.vec3, device=wp_device) + + wp.launch( + kernel=_sdf_query_kernel, + dim=n, + inputs=[mesh.id, wp_points, sdf_wp, grad_wp], + device=wp_device, + ) + + sdf_torch = wp.to_torch(sdf_wp) + grad_torch = wp.to_torch(grad_wp).reshape(n, 3) + + ctx.save_for_backward(grad_torch) + return sdf_torch + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + (grad_sdf,) = ctx.saved_tensors + grad_points = grad_output.unsqueeze(-1) * grad_sdf + return grad_points, None + + +def mesh_sdf(points: torch.Tensor, warp_mesh: wp.Mesh) -> torch.Tensor: + """Differentiable signed distance query. + + Args: + points: (N, 3) query points on the same device as the mesh. + warp_mesh: Warp Mesh object with BVH. + + Returns: + (N,) signed distance values. Negative = penetrating. + """ + return _MeshSDFFunction.apply(points, warp_mesh) + + +# Warp returns ~1e6 when a BVH query finds no enclosing face; 1e5 catches these +# while staying safely below any realistic SDF magnitude. +_SDF_SENTINEL = 1.0e5 + + +def has_sdf_sentinel(sdf_values: torch.Tensor) -> bool: + """True when any query hit the no-face sentinel, so the collision result is unreliable.""" + return bool((sdf_values >= _SDF_SENTINEL).any()) + + +def sdf_sentinel_count(sdf_values: torch.Tensor) -> int: + """Number of queries that hit the no-face sentinel.""" + return int((sdf_values >= _SDF_SENTINEL).sum().item()) + + +def clamp_sdf_sentinel(sdf_values: torch.Tensor) -> torch.Tensor: + """Replace sentinel SDF values with 0 (treat as "on surface") so they produce gradient.""" + return torch.where(sdf_values >= _SDF_SENTINEL, torch.zeros_like(sdf_values), sdf_values) + + +# --------------------------------------------------------------------------- +# Multi-mesh kernel: query multiple meshes in a single launch +# --------------------------------------------------------------------------- + + +@wp.kernel +def _multi_mesh_sdf_kernel( + mesh_ids: wp.array(dtype=wp.uint64), + mesh_indices: wp.array(dtype=wp.int32), + query_points: wp.array(dtype=wp.vec3), + sdf_out: wp.array(dtype=wp.float32), + grad_out: wp.array(dtype=wp.vec3), +): + """Query signed distance per point against its assigned mesh (indexed by mesh_indices). + + Points must be in mesh-local frame. Sign convention: negative = inside mesh. + Points with no enclosing face write a large sentinel value (~1e6). + """ + tid = wp.tid() + p = query_points[tid] + mesh_id = mesh_ids[mesh_indices[tid]] + + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + + found = wp.mesh_query_point_sign_normal(mesh_id, p, 1.0e6, sign, face_index, face_u, face_v) + + if found: + closest = wp.mesh_eval_position(mesh_id, face_index, face_u, face_v) + diff = p - closest + dist = wp.length(diff) + sdf_out[tid] = sign * dist + if dist > 1.0e-8: + grad_out[tid] = sign * wp.normalize(diff) + else: + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + else: + sdf_out[tid] = float(1.0e6) + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + + +class _MultiMeshSDFFunction(torch.autograd.Function): + """Autograd bridge for multi-mesh SDF queries (each point addressed to its own mesh).""" + + @staticmethod + def forward( + ctx, + points: torch.Tensor, + mesh_id_array: wp.array, + mesh_indices: wp.array, + ) -> torch.Tensor: + """Compute SDF values for query points, each against its own target mesh. + + Args: + points: (N, 3) float32 tensor of query positions. + mesh_id_array: Warp array of uint64 mesh IDs. + mesh_indices: Warp array of int32 indices into mesh_id_array (one per point). + + Returns: + (N,) tensor of signed distance values. + """ + device = points.device + n = points.shape[0] + wp_device = str(device) + + wp_points = wp.from_torch(points.contiguous(), dtype=wp.vec3) + sdf_wp = wp.zeros(n, dtype=wp.float32, device=wp_device) + grad_wp = wp.zeros(n, dtype=wp.vec3, device=wp_device) + + wp.launch( + kernel=_multi_mesh_sdf_kernel, + dim=n, + inputs=[mesh_id_array, mesh_indices, wp_points, sdf_wp, grad_wp], + device=wp_device, + ) + + sdf_torch = wp.to_torch(sdf_wp) + grad_torch = wp.to_torch(grad_wp).reshape(n, 3) + + ctx.save_for_backward(grad_torch) + return sdf_torch + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + (grad_sdf,) = ctx.saved_tensors + grad_points = grad_output.unsqueeze(-1) * grad_sdf + return grad_points, None, None + + +def multi_mesh_sdf( + points: torch.Tensor, + mesh_id_array: wp.array, + mesh_indices: wp.array, +) -> torch.Tensor: + """Differentiable multi-mesh SDF query. Single kernel launch for all points. + + Args: + points: (N, 3) query points. + mesh_id_array: Warp uint64 array of mesh IDs (one per unique target mesh). + mesh_indices: Warp int32 array (N,) mapping each point to its target mesh index. + + Returns: + (N,) signed distance values. Negative = penetrating. + """ + return _MultiMeshSDFFunction.apply(points, mesh_id_array, mesh_indices) diff --git a/isaaclab_arena/tests/test_mesh_collision.py b/isaaclab_arena/tests/test_mesh_collision.py new file mode 100644 index 0000000000..178133016e --- /dev/null +++ b/isaaclab_arena/tests/test_mesh_collision.py @@ -0,0 +1,696 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for mesh-based collision detection: sphere decomposition, dispatch, and end-to-end solver.""" + +from __future__ import annotations + +import math +import numpy as np +import torch +import trimesh + +import pytest + +from isaaclab_arena.assets.dummy_object import DummyObject +from isaaclab_arena.relations.relation_loss_strategies import NoCollisionLossStrategy +from isaaclab_arena.relations.relation_solver import RelationSolver +from isaaclab_arena.relations.relation_solver_params import CollisionMode, RelationSolverParams +from isaaclab_arena.relations.relations import IsAnchor, On +from isaaclab_arena.relations.warp_mesh_manager import WarpMeshAndSphereCache, greedy_sphere_decomposition +from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +from isaaclab_arena.utils.pose import Pose + +try: + import warp as wp + + wp.init() + _WARP_AVAILABLE = True +except Exception: + _WARP_AVAILABLE = False + +requires_warp = pytest.mark.skipif(not _WARP_AVAILABLE, reason="Warp not available") + + +# Unit tests + + +def _make_cylinder(name: str, radius: float = 0.033, height: float = 0.1) -> DummyObject: + mesh = trimesh.creation.cylinder(radius=radius, height=height, sections=32) + return DummyObject( + name=name, + bounding_box=AxisAlignedBoundingBox( + min_point=(-radius, -radius, -height / 2), + max_point=(radius, radius, height / 2), + ), + collision_mesh=mesh, + ) + + +def _make_box_obj(name: str, sx: float, sy: float, sz: float) -> DummyObject: + mesh = trimesh.creation.box(extents=(sx, sy, sz)) + return DummyObject( + name=name, + bounding_box=AxisAlignedBoundingBox( + min_point=(-sx / 2, -sy / 2, -sz / 2), + max_point=(sx / 2, sy / 2, sz / 2), + ), + collision_mesh=mesh, + ) + + +def _make_table() -> DummyObject: + mesh = trimesh.creation.box(extents=(1.0, 1.0, 0.05)) + table = DummyObject( + name="table", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.5, -0.5, -0.025), max_point=(0.5, 0.5, 0.025)), + collision_mesh=mesh, + ) + table.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + table.add_relation(IsAnchor()) + return table + + +def test_sphere_decomposition_covers_surface(): + """Sphere decomposition should cover >80% of surface sample points.""" + mesh = trimesh.creation.cylinder(radius=0.05, height=0.1, sections=32) + spheres = greedy_sphere_decomposition(mesh, num_spheres=20, n_surface=500) + assert spheres.shape[1] == 4 + assert spheres.shape[0] <= 20 + + # Check coverage: what fraction of surface points are within radius of some sphere? + surface_pts = trimesh.sample.sample_surface(mesh, 200)[0] + centers = spheres[:, :3] + radii = spheres[:, 3] + covered = 0 + for pt in surface_pts: + dists = np.linalg.norm(centers - pt, axis=1) + if (dists < radii).any(): + covered += 1 + coverage = covered / len(surface_pts) + assert coverage > 0.8, f"Coverage only {coverage:.1%}" + + +@requires_warp +def test_warp_mesh_caching(): + """Same mesh object should return identical Warp mesh from cache.""" + mesh = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) + manager = WarpMeshAndSphereCache(num_spheres=10) + m1 = manager.get_warp_mesh(mesh) + m2 = manager.get_warp_mesh(mesh) + assert m1 is m2 + + +def test_aabb_zero_loss_well_separated(): + """AABB loss is zero when objects are well separated.""" + strategy = NoCollisionLossStrategy(slope=10.0) + + obj_a = _make_cylinder("a") + obj_b = _make_cylinder("b") + + loss = strategy.compute_loss( + clearance_m=0.0, + child_pos=torch.tensor([0.0, 0.0, 0.0]), + child_bbox=obj_a.get_bounding_box(), + parent_world_bbox=obj_b.get_bounding_box().translated((0.5, 0.0, 0.0)), + ) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) + + +def test_aabb_positive_loss_fully_overlapping(): + """AABB loss is positive when two objects fully overlap.""" + strategy = NoCollisionLossStrategy(slope=10000.0) + a = _make_cylinder("a") + b = _make_cylinder("b") + + loss = strategy.compute_loss( + clearance_m=0.0, + child_pos=torch.tensor([0.0, 0.0, 0.0]), + child_bbox=a.get_bounding_box(), + parent_world_bbox=b.get_bounding_box().translated((0.0, 0.0, 0.0)), + ) + assert loss.item() > 0.0 + + +def test_aabb_clearance_m_increases_loss(): + """Near-miss cylinders should have positive AABB loss when clearance_m > 0.""" + strategy = NoCollisionLossStrategy(slope=10000.0) + a = _make_cylinder("a", radius=0.03) + b = _make_cylinder("b", radius=0.03) + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = b.get_bounding_box().translated((0.07, 0.0, 0.0)) + + # Separated (0.07 > 0.03+0.03): zero AABB loss without clearance + loss_no_clearance = strategy.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + ) + # With clearance: boxes expand, loss should be positive + loss_with_clearance = strategy.compute_loss( + clearance_m=0.05, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + ) + assert loss_with_clearance.item() > loss_no_clearance.item() + + +# Integration tests + + +@requires_warp +def test_solver_separates_overlapping_cylinders_mesh_mode(): + """RelationSolver with MESH mode should push overlapping cylinders apart.""" + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + objects = [table, a, b] + initial = [{table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.03), b: (0.01, 0.0, 0.03)}] + + solver = RelationSolver( + params=RelationSolverParams( + collision_mode=CollisionMode.MESH, max_iters=200, convergence_threshold=1e-4, verbose=False + ) + ) + result = solver.solve(objects, initial)[0] + + pos_a = np.array(result[a]) + pos_b = np.array(result[b]) + dist = np.linalg.norm(pos_a[:2] - pos_b[:2]) + # Must be separated by at least sum of radii (0.033 + 0.033 = 0.066) + assert dist > 0.066, f"Cylinders not separated: dist={dist:.4f}, need > 0.066" + + +@requires_warp +def test_on_pairs_skipped_in_mesh_mode(): + """On-linked pairs should not be penalized in mesh mode (same as AABB).""" + table = _make_table() + obj = _make_cylinder("can") + obj.add_relation(On(table)) + + objects = [table, obj] + # Place can directly on table surface -- should converge without fighting On + initial = [{table: (0.0, 0.0, 0.0), obj: (0.0, 0.0, 0.03)}] + + solver = RelationSolver( + params=RelationSolverParams( + collision_mode=CollisionMode.MESH, max_iters=100, convergence_threshold=1e-4, verbose=False + ) + ) + result = solver.solve(objects, initial)[0] + + # Should stay near table surface, not be pushed far away + z = result[obj][2] + assert 0.0 < z < 0.15, f"Object pushed too far: z={z}" + + +# Guard tests + + +@requires_warp +def test_random_yaw_mesh_mode_places_successfully(): + """random_yaw_init=True + CollisionMode.MESH should place objects without error.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + params = ObjectPlacerParams( + solver_params=RelationSolverParams( + collision_mode=CollisionMode.MESH, + max_iters=50, + ), + random_yaw_init=True, + ) + placer = ObjectPlacer(params=params) + + table = _make_table() + cyl_a = _make_cylinder("a") + cyl_b = _make_cylinder("b") + cyl_a.add_relation(On(table)) + cyl_b.add_relation(On(table)) + + results = placer.place([table, cyl_a, cyl_b]) + assert results[0].success + + +@requires_warp +def test_anchor_with_rotate_around_solution_rejected(): + """Anchor + RotateAroundSolution must fail loudly (not silently mismatch).""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + from isaaclab_arena.relations.relations import RotateAroundSolution + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=5), + random_yaw_init=True, + ) + placer = ObjectPlacer(params=params) + + table = _make_table() + table.add_relation(RotateAroundSolution(yaw_rad=0.5)) + child = _make_cylinder("child") + child.add_relation(On(table)) + + with pytest.raises(AssertionError, match="Anchor.*RotateAroundSolution"): + placer.place([table, child]) + + +@requires_warp +def test_centers_in_target_frame_applies_both_yaws(): + """Net yaw = source - target; equal yaws cancel out.""" + + from isaaclab_arena.relations.object_placer import ObjectPlacer + + src = DummyObject( + "src", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.1, -0.1, -0.1), max_point=(0.1, 0.1, 0.1)), + collision_mesh=trimesh.creation.box(extents=(0.2, 0.2, 0.2)), + ) + tgt = DummyObject( + "tgt", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.1, -0.1, -0.1), max_point=(0.1, 0.1, 0.1)), + collision_mesh=trimesh.creation.box(extents=(0.2, 0.2, 0.2)), + ) + centers = torch.tensor([[0.10, 0.0, 0.0]]) + src_pos = torch.tensor([0.0, 0.0, 0.0]) + tgt_pos = torch.tensor([0.0, 0.0, 0.0]) + + # No orientations: pass-through + result = ObjectPlacer._centers_in_target_frame(centers, src, tgt, src_pos, tgt_pos, None) + assert torch.allclose(result, centers, atol=1e-6) + + # Source yaw=pi/2, target yaw=0: net rotation = pi/2 + result = ObjectPlacer._centers_in_target_frame(centers, src, tgt, src_pos, tgt_pos, {src: math.pi / 2}) + assert abs(result[0, 0].item()) < 1e-5 + assert abs(result[0, 1].item() - 0.10) < 1e-5 + + # Both at same yaw: net rotation = 0, centers unchanged (offset is zero here) + result = ObjectPlacer._centers_in_target_frame( + centers, src, tgt, src_pos, tgt_pos, {src: math.pi / 2, tgt: math.pi / 2} + ) + assert torch.allclose(result, centers, atol=1e-5) + + +@requires_warp +def test_object_placer_mesh_mode_end_to_end(): + """ObjectPlacer.place() with CollisionMode.MESH returns a valid result.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=300, verbose=False), + max_placement_attempts=5, + verbose=False, + ) + placer = ObjectPlacer(params=params) + results = placer.place([table, a, b]) + assert results[0].success, f"Placement failed with loss={results[0].final_loss}" + + +@requires_warp +def test_validate_no_overlap_mesh_catches_overlap(): + """Direct test of _validate_no_overlap_mesh: overlapping cylinders should fail validation.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, verbose=False), + verbose=False, + ) + placer = ObjectPlacer(params=params) + + # Overlapping positions + positions = {table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.0, 0.0, 0.05)} + assert not placer._validate_no_overlap_mesh(positions) + + # Separated positions + positions_sep = {table: (0.0, 0.0, 0.0), a: (0.2, 0.0, 0.05), b: (-0.2, 0.0, 0.05)} + assert placer._validate_no_overlap_mesh(positions_sep) + + +@requires_warp +def test_validate_no_overlap_mesh_sentinel_fails(monkeypatch): + """A sentinel SDF (no resolvable face) must fail validation, not certify collision-free.""" + from isaaclab_arena.relations import warp_sdf_kernels + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, verbose=False), + verbose=False, + ) + placer = ObjectPlacer(params=params) + positions = {table: (0.0, 0.0, 0.0), a: (0.2, 0.0, 0.05), b: (-0.2, 0.0, 0.05)} + assert placer._validate_no_overlap_mesh(positions) + + # Force every query to hit the sentinel; the same separated layout must now fail. + from isaaclab_arena.relations import object_placer as _op_mod + + real_mesh_sdf = warp_sdf_kernels.mesh_sdf + + def fake_sdf(points, mesh): + return torch.full_like(real_mesh_sdf(points, mesh), 1.0e6) + + monkeypatch.setattr(warp_sdf_kernels, "mesh_sdf", fake_sdf) + monkeypatch.setattr(_op_mod, "mesh_sdf", fake_sdf) + assert not placer._validate_no_overlap_mesh(positions) + + +@requires_warp +def test_validate_no_overlap_mesh_respects_anchor_yaw(): + """Validator must use anchor's initial_pose yaw (not identity) when checking overlap.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + # Long thin anchor rotated 90° about Z + anchor_mesh = trimesh.creation.box(extents=(0.2, 0.02, 0.05)) + anchor = DummyObject( + "anchor", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.1, -0.01, -0.025), max_point=(0.1, 0.01, 0.025)), + collision_mesh=anchor_mesh, + ) + sz = math.sin(math.pi / 4) + cz = math.cos(math.pi / 4) + anchor.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.05), rotation_xyzw=(0.0, 0.0, sz, cz))) + anchor.add_relation(IsAnchor()) + + child = _make_cylinder("child", radius=0.012) + child.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, verbose=False), + verbose=False, + ) + placer = ObjectPlacer(params=params) + + # Child at Y=0.02: outside unrotated anchor (half-width=0.01), + # but inside rotated anchor (half-length=0.1 now spans Y). + positions = {table: (0.0, 0.0, 0.0), anchor: (0.0, 0.0, 0.05), child: (0.0, 0.02, 0.05)} + assert not placer._validate_no_overlap_mesh(positions), "Validator should detect overlap with yawed anchor" + + +@requires_warp +def test_mesh_sdf_backward_gradient(): + """mesh_sdf backward should produce non-zero gradients pointing outward for interior points.""" + from isaaclab_arena.relations.warp_sdf_kernels import mesh_sdf + + mesh = trimesh.creation.box(extents=(0.2, 0.2, 0.2)) + manager = WarpMeshAndSphereCache(num_spheres=10, device="cpu") + warp_mesh = manager.get_warp_mesh(mesh) + + # Off-center point inside the box, closer to +X face + points = torch.tensor([[0.08, 0.0, 0.0]], dtype=torch.float32, requires_grad=True) + sdf = mesh_sdf(points, warp_mesh) + assert sdf.item() < 0.0, "Point inside box should have negative SDF" + + sdf.backward() + assert points.grad is not None + grad = points.grad[0] + # Gradient should point toward the nearest face (+X direction) + assert grad[0].item() > 0.0, f"Gradient X should be positive (toward +X face), got {grad[0].item()}" + assert abs(grad[0].item()) > abs(grad[1].item()), "X component should dominate (closest to +X face)" + + +@requires_warp +def test_solver_mesh_batch_size_two(): + """Solver MESH mode handles batch_size > 1 (both envs solved independently).""" + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + # Env 0: overlapping, Env 1: separated + initial = [ + {table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.01, 0.0, 0.05)}, + {table: (0.0, 0.0, 0.0), a: (-0.2, 0.0, 0.05), b: (0.2, 0.0, 0.05)}, + ] + + solver = RelationSolver( + params=RelationSolverParams( + collision_mode=CollisionMode.MESH, max_iters=200, convergence_threshold=1e-4, verbose=False + ) + ) + results = solver.solve([table, a, b], initial) + assert len(results) == 2 + + # Env 0: should have moved objects apart + pos_a_0 = np.array(results[0][a]) + pos_b_0 = np.array(results[0][b]) + dist_0 = np.linalg.norm(pos_a_0[:2] - pos_b_0[:2]) + assert dist_0 > 0.06, f"Env 0: objects not separated, dist={dist_0:.4f}" + + # Env 1: already separated, should stay roughly in place + pos_a_1 = np.array(results[1][a]) + pos_b_1 = np.array(results[1][b]) + dist_1 = np.linalg.norm(pos_a_1[:2] - pos_b_1[:2]) + assert dist_1 > 0.3, f"Env 1: separated objects moved too much, dist={dist_1:.4f}" + + +@requires_warp +def test_broadphase_skips_separated_pairs(): + """Well-separated objects produce zero mesh loss from the solver path.""" + table = _make_table() + a = _make_cylinder("a", radius=0.03) + b = _make_cylinder("b", radius=0.03) + a.add_relation(On(table)) + b.add_relation(On(table)) + + # Objects far apart — broadphase should filter them out + initial = [{table: (0.0, 0.0, 0.0), a: (-0.4, 0.0, 0.05), b: (0.4, 0.0, 0.05)}] + + solver = RelationSolver(params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False)) + solver.solve([table, a, b], initial) + # With max_iters=0, loss is from initial positions. + # Objects are well separated, so collision loss should be minimal + # (only On-relation losses contribute). + loss = solver.last_loss_per_env[0].item() + + # Compare with an overlapping case to confirm broadphase actually filters + initial_overlap = [{table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.01, 0.0, 0.05)}] + solver2 = RelationSolver(params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False)) + solver2.solve([table, a, b], initial_overlap) + loss_overlap = solver2.last_loss_per_env[0].item() + + assert loss_overlap > loss, "Overlapping case should have higher loss than separated" + + +@requires_warp +def test_broadphase_does_not_skip_overlapping_pairs(): + """Overlapping objects must produce nonzero mesh loss from the solver path.""" + table = _make_table() + a = _make_cylinder("a", radius=0.03) + b = _make_cylinder("b", radius=0.03) + a.add_relation(On(table)) + b.add_relation(On(table)) + + # Overlapping: at same position + initial = [{table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.0, 0.0, 0.05)}] + + solver = RelationSolver(params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False)) + solver.solve([table, a, b], initial) + loss = solver.last_loss_per_env[0].item() + assert loss > 0.0, "Overlapping objects should produce nonzero loss" + + +@requires_warp +def test_multi_mesh_sdf_distinct_meshes(): + """Verify mesh_indices routes queries to different meshes (not stuck at index 0).""" + from isaaclab_arena.relations.warp_sdf_kernels import multi_mesh_sdf + + # Tall cylinder vs flat box — maximally different SDF at the query point. + cylinder = trimesh.creation.cylinder(radius=0.05, height=0.3, sections=32) + box = trimesh.creation.box(extents=(0.4, 0.4, 0.02)) + mgr = WarpMeshAndSphereCache(num_spheres=10, device="cpu") + + warp_cyl = mgr.get_warp_mesh(cylinder) + warp_box = mgr.get_warp_mesh(box) + + mesh_id_array = wp.array([warp_cyl.id, warp_box.id], dtype=wp.uint64, device="cpu") + + # Point at origin: inside cylinder (depth ~0.05), inside box (depth ~0.01) + p = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32) + + idx_cyl = wp.array([0], dtype=wp.int32, device="cpu") + sdf_cyl = multi_mesh_sdf(p, mesh_id_array, idx_cyl) + + idx_box = wp.array([1], dtype=wp.int32, device="cpu") + sdf_box = multi_mesh_sdf(p, mesh_id_array, idx_box) + + # Both inside (negative), but cylinder is much deeper + assert sdf_cyl.item() < -0.03, f"Expected deep inside cylinder, got {sdf_cyl.item()}" + assert sdf_box.item() < 0.0 + assert sdf_cyl.item() < sdf_box.item(), "Cylinder should be deeper than flat box at origin" + + +@requires_warp +def test_multi_mesh_sdf_backward(): + """Backward through multi_mesh_sdf produces correct gradient direction.""" + from isaaclab_arena.relations.warp_sdf_kernels import multi_mesh_sdf + + mesh = trimesh.creation.cylinder(radius=0.05, height=0.1, sections=32) + mgr = WarpMeshAndSphereCache(num_spheres=10, device="cpu") + warp_mesh = mgr.get_warp_mesh(mesh) + + mesh_id_array = wp.array([warp_mesh.id], dtype=wp.uint64, device="cpu") + mesh_indices = wp.array([0], dtype=wp.int32, device="cpu") + + # Point at (0.02, 0, 0) inside cylinder (r=0.05): SDF gradient should point outward (+x) + points = torch.tensor([[0.02, 0.0, 0.0]], dtype=torch.float32, requires_grad=True) + sdf = multi_mesh_sdf(points, mesh_id_array, mesh_indices) + sdf.backward() + + assert points.grad is not None + assert torch.isfinite(points.grad).all() + # SDF gradient at (0.02, 0, 0) should point radially outward: positive x, near-zero y/z + assert points.grad[0, 0].item() > 0.1, f"Expected +x gradient, got {points.grad[0].tolist()}" + + +@requires_warp +def test_solver_target_only_yaw(): + """Target-only yaw must affect collision detection (catches missing parent rotation).""" + table = _make_table() + # Long thin box: if rotated 90° around Z, its collision footprint changes axis + target = _make_box_obj("target", sx=0.2, sy=0.02, sz=0.05) + target.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.05))) + target.add_relation(IsAnchor()) + + child = _make_cylinder("child", radius=0.015) + child.add_relation(On(table)) + + # Place child next to target's long axis (Y=0.03). Without target rotation, + # this is well outside the 0.02/2 half-width in Y → no collision. + # With target rotated 90°, the 0.2/2=0.1 half-extent now spans Y → collision. + initial = [{table: (0.0, 0.0, 0.0), target: (0.0, 0.0, 0.05), child: (0.0, 0.03, 0.05)}] + + # No rotation: should be low/zero mesh collision + solver_no_rot = RelationSolver( + params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False) + ) + solver_no_rot.solve([table, target, child], initial, orientations=None) + loss_no_rot = solver_no_rot.last_loss_per_env[0].item() + + # Target rotated 90° around Z: child is now inside target's mesh + orientations_rotated = [{target: math.pi / 2, child: 0.0}] + solver_rot = RelationSolver( + params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False) + ) + solver_rot.solve([table, target, child], initial, orientations=orientations_rotated) + loss_rot = solver_rot.last_loss_per_env[0].item() + + assert ( + loss_rot > loss_no_rot + 1.0 + ), f"Target yaw=90° should dramatically increase collision loss (got {loss_rot:.2f} vs {loss_no_rot:.2f})" + + +@requires_warp +def test_anchor_initial_pose_yaw_affects_collision(): + """Anchor Z-yaw baked in initial_pose (no orientations dict) must affect SDF queries.""" + table = _make_table() + # Long thin anchor: 0.2 x 0.02 — rotation changes which axis is long + target_mesh = trimesh.creation.box(extents=(0.2, 0.02, 0.05)) + target = DummyObject( + "target", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.1, -0.01, -0.025), max_point=(0.1, 0.01, 0.025)), + collision_mesh=target_mesh, + ) + # Bake 90° Z-yaw into initial_pose (not via orientations dict) + sz = math.sin(math.pi / 4) + cz = math.cos(math.pi / 4) + target.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.05), rotation_xyzw=(0.0, 0.0, sz, cz))) + target.add_relation(IsAnchor()) + + child = _make_cylinder("child", radius=0.012) + child.add_relation(On(table)) + + # Child at Y=0.02: outside unrotated target (half-width=0.01), inside rotated (half-length=0.1) + initial = [{table: (0.0, 0.0, 0.0), target: (0.0, 0.0, 0.05), child: (0.0, 0.02, 0.05)}] + + solver = RelationSolver(params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False)) + solver.solve([table, target, child], initial, orientations=None) + loss_yawed = solver.last_loss_per_env[0].item() + + # Same geometry with identity anchor — should have lower loss + target_id = DummyObject( + "target_id", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.1, -0.01, -0.025), max_point=(0.1, 0.01, 0.025)), + collision_mesh=target_mesh, + ) + target_id.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.05))) + target_id.add_relation(IsAnchor()) + initial_id = [{table: (0.0, 0.0, 0.0), target_id: (0.0, 0.0, 0.05), child: (0.0, 0.02, 0.05)}] + + solver_id = RelationSolver( + params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False) + ) + solver_id.solve([table, target_id, child], initial_id, orientations=None) + loss_identity = solver_id.last_loss_per_env[0].item() + + assert loss_yawed > loss_identity + 1.0, ( + "Yawed anchor (from initial_pose) should produce higher collision " + f"(got {loss_yawed:.2f} vs identity {loss_identity:.2f})" + ) + + +@requires_warp +def test_aabb_gate_does_not_reject_diagonal_cylinders(): + """Regression: MESH-mode validator accepts cylinders whose AABBs overlap but meshes don't.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + # r=0.05, b at (0.09, 0.09): AABB overlap (0.09 < 2*0.05=0.10) but geometric + # distance = 0.127 > sum-of-radii 0.10 + sphere_radius 0.01, plenty of margin. + a = _make_cylinder("a", radius=0.05, height=0.1) + b = _make_cylinder("b", radius=0.05, height=0.1) + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, verbose=False), + verbose=False, + ) + placer = ObjectPlacer(params=params) + + positions = {table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.09, 0.09, 0.05)} + + # Sanity: AABB check without skip_mesh_pairs REJECTS this layout + env_bboxes = {obj: obj.get_bounding_box() for obj in positions} + assert not placer._validate_no_overlap( + positions, env_bboxes, skip_mesh_pairs=False + ), "Sanity check failed: AABB should reject diagonal cylinders" + + # With skip_mesh_pairs=True (MESH mode), AABB validator skips this pair + assert placer._validate_no_overlap( + positions, env_bboxes, skip_mesh_pairs=True + ), "AABB validator with skip_mesh_pairs should accept this pair" + + # Mesh validator accepts (cylinders don't actually overlap) + assert placer._validate_no_overlap_mesh( + positions + ), "Mesh validator should accept diagonal cylinders that don't geometrically overlap" diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 10122fb0f1..5b77028681 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -306,7 +306,7 @@ def test_random_yaw_init_applied_yaw_matches_selected_candidate(): def test_random_yaw_init_composes_marker_yaw(): - """A yaw RotateAroundSolution marker composes with the sampled yaw: applied == marker + sampled.""" + """orientations dict carries total yaw (marker + sampled); applied pose matches it.""" marker_yaw = math.pi / 6 solver_params = RelationSolverParams(max_iters=10, verbose=False) desk, box1, box2 = _create_test_objects() @@ -316,7 +316,8 @@ def test_random_yaw_init_composes_marker_yaw(): ) (result,) = placer.place([desk, box1, box2], num_envs=1) applied = _yaw_rad_from_quat(box1.get_initial_pose().rotation_xyzw) - assert abs(wrap_angle_to_pi(applied - (marker_yaw + result.orientations[box1]))) < 1e-5 + # result.orientations now carries total yaw = marker + sampled + assert abs(wrap_angle_to_pi(applied - result.orientations[box1])) < 1e-5 def test_random_yaw_init_rejects_roll_pitch_marker(): @@ -331,6 +332,20 @@ def test_random_yaw_init_rejects_roll_pitch_marker(): placer.place([desk, box1, box2], num_envs=1) +def test_marker_yaw_applied_without_random_yaw_init(): + """RotateAroundSolution marker must be applied even when random_yaw_init=False.""" + marker_yaw = math.pi / 4 + solver_params = RelationSolverParams(max_iters=5, verbose=False) + desk, box1, box2 = _create_test_objects() + box1.add_relation(RotateAroundSolution(yaw_rad=marker_yaw)) + placer = ObjectPlacer( + params=ObjectPlacerParams(placement_seed=1, solver_params=solver_params, random_yaw_init=False) + ) + placer.place([desk, box1, box2], num_envs=1) + applied = _yaw_rad_from_quat(box1.get_initial_pose().rotation_xyzw) + assert abs(wrap_angle_to_pi(applied - marker_yaw)) < 1e-5, f"Marker yaw {marker_yaw} must be applied; got {applied}" + + def _positions_by_name(result: PlacementResult) -> dict[str, tuple[float, float, float]]: return {obj.name: pos for obj, pos in result.positions.items()} diff --git a/isaaclab_arena/tests/test_relation_solver_interface.py b/isaaclab_arena/tests/test_relation_solver_interface.py index 27f4a3fc19..f957064d00 100644 --- a/isaaclab_arena/tests/test_relation_solver_interface.py +++ b/isaaclab_arena/tests/test_relation_solver_interface.py @@ -62,12 +62,13 @@ def test_solve_and_apply_relation_placement_with_no_objects_returns_empty_result def test_solve_and_apply_relation_placement_with_only_anchors_returns_no_reset_event(): from isaaclab_arena.environments.relation_solver_interface import solve_and_apply_relation_placement + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + params = ObjectPlacerParams(placement_seed=11, resolve_on_reset=False) placement_event_cfg = solve_and_apply_relation_placement( [_make_desk()], num_envs=3, - placement_seed=11, - resolve_on_reset=False, + placer_params=params, ) assert placement_event_cfg is None @@ -75,6 +76,7 @@ def test_solve_and_apply_relation_placement_with_only_anchors_returns_no_reset_e def test_static_solve_and_apply_relation_placement_reuses_object_only_placement(): from isaaclab_arena.environments.relation_solver_interface import solve_and_apply_relation_placement + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.relations import On from isaaclab_arena.utils.pose import PosePerEnv @@ -82,11 +84,11 @@ def test_static_solve_and_apply_relation_placement_reuses_object_only_placement( box = _make_box() box.add_relation(On(desk, clearance_m=0.01)) + params = ObjectPlacerParams(placement_seed=7, resolve_on_reset=False) placement_event_cfg = solve_and_apply_relation_placement( [desk, box], num_envs=2, - placement_seed=7, - resolve_on_reset=False, + placer_params=params, ) assert placement_event_cfg is None diff --git a/isaaclab_arena/tests/test_usd_scale_helpers.py b/isaaclab_arena/tests/test_usd_scale_helpers.py new file mode 100644 index 0000000000..66548ef7fd --- /dev/null +++ b/isaaclab_arena/tests/test_usd_scale_helpers.py @@ -0,0 +1,279 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Regression tests for USD scale handling in extract_trimesh_from_usd and compute_local_bounding_box_from_usd. + +Verifies that spawn-scale is applied in the local frame (R·(S·v)+t) rather than world +frame (S·(R·v+t)), which matters for translated/rotated child prims under non-uniform scale. +""" + +import numpy as np + +from isaaclab_arena.tests.utils.subprocess import run_simulation_app_function + +HEADLESS = True + + +def _test_extract_trimesh_translated_child_nonuniform_scale(simulation_app): + """extract_trimesh_from_usd must scale in local frame, not world frame. + + Setup: unit cube under a child Xform translated +1.0 in X, scale=(2,1,1). + Correct (local scale): verts ±0.5 → ±1.0 in local X, then translate +1 → world X [0.0, 2.0]. + Bug (world scale): verts ±0.5, translate +1 → world [0.5, 1.5], then *2 → [1.0, 3.0]. + """ + import tempfile + + from pxr import Gf, Usd, UsdGeom + + from isaaclab_arena.utils.usd_helpers import extract_trimesh_from_usd + + stage = Usd.Stage.CreateInMemory() + root = stage.DefinePrim("/root", "Xform") + stage.SetDefaultPrim(root) + + child_xform = UsdGeom.Xform.Define(stage, "/root/child") + child_xform.AddTranslateOp().Set(Gf.Vec3d(1.0, 0.0, 0.0)) + + mesh_prim = UsdGeom.Mesh.Define(stage, "/root/child/cube") + points = [ + Gf.Vec3f(-0.5, -0.5, -0.5), + Gf.Vec3f(0.5, -0.5, -0.5), + Gf.Vec3f(0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, -0.5, 0.5), + Gf.Vec3f(0.5, -0.5, 0.5), + Gf.Vec3f(0.5, 0.5, 0.5), + Gf.Vec3f(-0.5, 0.5, 0.5), + ] + face_vertex_counts = [4, 4, 4, 4, 4, 4] + face_vertex_indices = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 0, + 1, + 5, + 4, + 2, + 3, + 7, + 6, + 0, + 3, + 7, + 4, + 1, + 2, + 6, + 5, + ] + mesh_prim.GetPointsAttr().Set(points) + mesh_prim.GetFaceVertexCountsAttr().Set(face_vertex_counts) + mesh_prim.GetFaceVertexIndicesAttr().Set(face_vertex_indices) + + with tempfile.NamedTemporaryFile(suffix=".usda", delete=False) as f: + usd_path = f.name + stage.Export(usd_path) + + scale = (2.0, 1.0, 1.0) + tri = extract_trimesh_from_usd(usd_path, scale=scale) + verts = tri.vertices + + # Local scale: ±0.5*2=±1.0 then +1 translate → world X [0.0, 2.0] + assert np.isclose(verts[:, 0].min(), 0.0, atol=1e-5), f"got {verts[:, 0].min():.4f}" + assert np.isclose(verts[:, 0].max(), 2.0, atol=1e-5), f"got {verts[:, 0].max():.4f}" + + assert np.isclose(verts[:, 1].min(), -0.5, atol=1e-5) + assert np.isclose(verts[:, 1].max(), 0.5, atol=1e-5) + assert np.isclose(verts[:, 2].min(), -0.5, atol=1e-5) + assert np.isclose(verts[:, 2].max(), 0.5, atol=1e-5) + + return True + + +def _test_bbox_translated_child_nonuniform_scale(simulation_app): + """BBox uses ComputeLocalBound * scale (post-transform scale on root-local extents). + + For a child translated +1 with verts ±0.5: root-local bound X=[0.5, 1.5], * scale_x=2 → [1.0, 3.0]. + Note: mesh path scales per-prim verts first → X=[0.0, 2.0]. These differ for translated children + under non-uniform scale. AABB is conservative (larger), which is safe for collision checks. + """ + import tempfile + + from pxr import Gf, Usd, UsdGeom + + from isaaclab_arena.utils.usd_helpers import compute_local_bounding_box_from_usd + + stage = Usd.Stage.CreateInMemory() + root = stage.DefinePrim("/root", "Xform") + stage.SetDefaultPrim(root) + + child_xform = UsdGeom.Xform.Define(stage, "/root/child") + child_xform.AddTranslateOp().Set(Gf.Vec3d(1.0, 0.0, 0.0)) + + mesh_prim = UsdGeom.Mesh.Define(stage, "/root/child/cube") + points = [ + Gf.Vec3f(-0.5, -0.5, -0.5), + Gf.Vec3f(0.5, -0.5, -0.5), + Gf.Vec3f(0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, -0.5, 0.5), + Gf.Vec3f(0.5, -0.5, 0.5), + Gf.Vec3f(0.5, 0.5, 0.5), + Gf.Vec3f(-0.5, 0.5, 0.5), + ] + face_vertex_counts = [4, 4, 4, 4, 4, 4] + face_vertex_indices = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 0, + 1, + 5, + 4, + 2, + 3, + 7, + 6, + 0, + 3, + 7, + 4, + 1, + 2, + 6, + 5, + ] + mesh_prim.GetPointsAttr().Set(points) + mesh_prim.GetFaceVertexCountsAttr().Set(face_vertex_counts) + mesh_prim.GetFaceVertexIndicesAttr().Set(face_vertex_indices) + + with tempfile.NamedTemporaryFile(suffix=".usda", delete=False) as f: + usd_path = f.name + stage.Export(usd_path) + + scale = (2.0, 1.0, 1.0) + bbox = compute_local_bounding_box_from_usd(usd_path, scale=scale) + + # ComputeLocalBound gives [0.5,1.5] * scale_x=2 → [1.0, 3.0] + min_pt = bbox.min_point[0] # (3,) tensor + max_pt = bbox.max_point[0] # (3,) tensor + assert np.isclose(min_pt[0].item(), 1.0, atol=1e-5), f"got {min_pt[0].item():.4f}" + assert np.isclose(max_pt[0].item(), 3.0, atol=1e-5), f"got {max_pt[0].item():.4f}" + assert np.isclose(min_pt[1].item(), -0.5, atol=1e-5) + assert np.isclose(max_pt[1].item(), 0.5, atol=1e-5) + assert np.isclose(min_pt[2].item(), -0.5, atol=1e-5) + assert np.isclose(max_pt[2].item(), 0.5, atol=1e-5) + + return True + + +def _test_both_paths_agree_origin_prim(simulation_app): + """For an origin-centered single prim, mesh and bbox agree exactly.""" + import tempfile + + from pxr import Gf, Usd, UsdGeom + + from isaaclab_arena.utils.usd_helpers import compute_local_bounding_box_from_usd, extract_trimesh_from_usd + + stage = Usd.Stage.CreateInMemory() + root = stage.DefinePrim("/root", "Xform") + stage.SetDefaultPrim(root) + + mesh_prim = UsdGeom.Mesh.Define(stage, "/root/cube") + points = [ + Gf.Vec3f(-0.5, -0.5, -0.5), + Gf.Vec3f(0.5, -0.5, -0.5), + Gf.Vec3f(0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, 0.5, -0.5), + Gf.Vec3f(-0.5, -0.5, 0.5), + Gf.Vec3f(0.5, -0.5, 0.5), + Gf.Vec3f(0.5, 0.5, 0.5), + Gf.Vec3f(-0.5, 0.5, 0.5), + ] + face_vertex_counts = [4, 4, 4, 4, 4, 4] + face_vertex_indices = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 0, + 1, + 5, + 4, + 2, + 3, + 7, + 6, + 0, + 3, + 7, + 4, + 1, + 2, + 6, + 5, + ] + mesh_prim.GetPointsAttr().Set(points) + mesh_prim.GetFaceVertexCountsAttr().Set(face_vertex_counts) + mesh_prim.GetFaceVertexIndicesAttr().Set(face_vertex_indices) + + with tempfile.NamedTemporaryFile(suffix=".usda", delete=False) as f: + usd_path = f.name + stage.Export(usd_path) + + scale = (2.0, 3.0, 0.5) + tri = extract_trimesh_from_usd(usd_path, scale=scale) + bbox = compute_local_bounding_box_from_usd(usd_path, scale=scale) + + verts = tri.vertices + # ±0.5 * (2, 3, 0.5) = (±1.0, ±1.5, ±0.25) + assert np.isclose(verts[:, 0].min(), -1.0, atol=1e-5) + assert np.isclose(verts[:, 0].max(), 1.0, atol=1e-5) + assert np.isclose(verts[:, 1].min(), -1.5, atol=1e-5) + assert np.isclose(verts[:, 1].max(), 1.5, atol=1e-5) + assert np.isclose(verts[:, 2].min(), -0.25, atol=1e-5) + assert np.isclose(verts[:, 2].max(), 0.25, atol=1e-5) + + # BBox must match mesh extents exactly for origin-centered single prim. + min_pt = bbox.min_point[0] # (3,) tensor + max_pt = bbox.max_point[0] # (3,) tensor + assert np.isclose(min_pt[0].item(), -1.0, atol=1e-5) + assert np.isclose(max_pt[0].item(), 1.0, atol=1e-5) + assert np.isclose(min_pt[1].item(), -1.5, atol=1e-5) + assert np.isclose(max_pt[1].item(), 1.5, atol=1e-5) + assert np.isclose(min_pt[2].item(), -0.25, atol=1e-5) + assert np.isclose(max_pt[2].item(), 0.25, atol=1e-5) + + return True + + +def test_extract_trimesh_translated_child_nonuniform_scale(): + result = run_simulation_app_function(_test_extract_trimesh_translated_child_nonuniform_scale, headless=HEADLESS) + assert result + + +def test_bbox_translated_child_nonuniform_scale(): + result = run_simulation_app_function(_test_bbox_translated_child_nonuniform_scale, headless=HEADLESS) + assert result + + +def test_both_paths_agree_origin_prim(): + result = run_simulation_app_function(_test_both_paths_agree_origin_prim, headless=HEADLESS) + assert result diff --git a/isaaclab_arena/tests/test_validate_placement.py b/isaaclab_arena/tests/test_validate_placement.py index c2c56aea9c..2e0459fcb5 100644 --- a/isaaclab_arena/tests/test_validate_placement.py +++ b/isaaclab_arena/tests/test_validate_placement.py @@ -150,17 +150,19 @@ def test_candidate_bbox_aligns_with_candidate_yaw(): def test_rotate_candidate_bboxes_encloses_marker_plus_sampled_yaw(): - """_rotate_candidate_bboxes folds the marker yaw into the box, not just the sampled yaw.""" + """_rotate_candidate_bboxes applies the total yaw (marker + sampled) passed by the caller.""" box = _make_long_box("box") marker_yaw, sampled_yaw = math.pi / 6, math.pi / 3 + total_yaw = marker_yaw + sampled_yaw box.add_relation(RotateAroundSolution(yaw_rad=marker_yaw)) - rotated = ObjectPlacer._rotate_candidate_bboxes([box], {box: box.get_bounding_box()}, [{box: sampled_yaw}]) + # In production, _generate_initial_orientations computes total_yaw before calling this method. + rotated = ObjectPlacer._rotate_candidate_bboxes([box], {box: box.get_bounding_box()}, [{box: total_yaw}]) - expected = box.get_bounding_box().rotated_around_z(marker_yaw + sampled_yaw) + expected = box.get_bounding_box().rotated_around_z(total_yaw) torch.testing.assert_close(rotated[box].min_point, expected.min_point, atol=1e-6, rtol=0) torch.testing.assert_close(rotated[box].max_point, expected.max_point, atol=1e-6, rtol=0) - # Dropping the marker (sampled yaw only) would enclose an undersized, misaligned footprint. + # Passing only sampled_yaw (without marker) would enclose an undersized, misaligned footprint. sampled_only = box.get_bounding_box().rotated_around_z(sampled_yaw) assert not torch.allclose(rotated[box].max_point, sampled_only.max_point, atol=1e-6) diff --git a/isaaclab_arena/utils/pose.py b/isaaclab_arena/utils/pose.py index ecc9289d64..95b0199400 100644 --- a/isaaclab_arena/utils/pose.py +++ b/isaaclab_arena/utils/pose.py @@ -58,6 +58,17 @@ def wrap_angle_to_pi(angle_rad: float) -> float: return (angle_rad + math.pi) % (2.0 * math.pi) - math.pi +def yaw_from_quat_xyzw(quat_xyzw: tuple[float, float, float, float]) -> float: + """Extract Z-axis yaw (radians) from an (x, y, z, w) quaternion. + + Returns 0.0 if the quaternion has non-trivial roll or pitch (|qx| or |qy| > 1e-6). + """ + qx, qy, qz, qw = quat_xyzw + if abs(qx) > 1e-6 or abs(qy) > 1e-6: + return 0.0 + return 2.0 * math.atan2(qz, qw) + + def rotate_quat_by_yaw( base_xyzw: tuple[float, float, float, float], yaw_rad: float ) -> tuple[float, float, float, float]: diff --git a/isaaclab_arena/utils/usd_helpers.py b/isaaclab_arena/utils/usd_helpers.py index 5e96c66a94..c7fcb7a00a 100644 --- a/isaaclab_arena/utils/usd_helpers.py +++ b/isaaclab_arena/utils/usd_helpers.py @@ -3,6 +3,10 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import numpy as np +import trimesh from contextlib import contextmanager from pxr import Gf, Usd, UsdGeom, UsdLux, UsdPhysics @@ -206,3 +210,68 @@ def compute_local_bounding_box_from_prim( min_point=(local_min[0], local_min[1], local_min[2]), max_point=(local_max[0], local_max[1], local_max[2]), ) + + +def extract_trimesh_from_usd( + usd_path: str, + scale: tuple[float, float, float] = (1.0, 1.0, 1.0), +) -> trimesh.Trimesh: + """Extract all mesh prims from a USD into a single trimesh. + + Scale is applied per-vertex in local frame before the prim-to-world transform. + All scale components must be positive (negative flips winding/SDF sign). + + Args: + usd_path: Path to the .usd/.usda/.usdc file. + scale: (sx, sy, sz) per-axis scale factors applied in local frame. + + Returns: + Combined trimesh with per-prim world transforms baked in. + """ + assert all( + s > 0 for s in scale + ), f"All scale components must be positive (negative scale flips winding/SDF sign), got {scale}" + + stage = Usd.Stage.Open(usd_path) + if stage is None: + raise ValueError(f"Failed to open USD: {usd_path}") + + all_verts: list[np.ndarray] = [] + all_faces: list[list[int]] = [] + offset = 0 + + for prim in stage.Traverse(): + if not prim.IsA(UsdGeom.Mesh): + continue + mesh_prim = UsdGeom.Mesh(prim) + points = mesh_prim.GetPointsAttr().Get() + face_vertex_counts = mesh_prim.GetFaceVertexCountsAttr().Get() + face_vertex_indices = mesh_prim.GetFaceVertexIndicesAttr().Get() + if points is None or face_vertex_counts is None or face_vertex_indices is None: + continue + + xform = UsdGeom.Xformable(prim) + world_tf = np.array(xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default())) + + verts = np.asarray(points, dtype=np.float64) + verts_scaled = verts * np.array(scale, dtype=np.float64) + verts_h = np.hstack([verts_scaled, np.ones((len(verts_scaled), 1))]) + verts_world = (verts_h @ world_tf)[:, :3] + + # Fan-triangulate faces + idx = 0 + for count in face_vertex_counts: + for k in range(1, count - 1): + all_faces.append([ + face_vertex_indices[idx] + offset, + face_vertex_indices[idx + k] + offset, + face_vertex_indices[idx + k + 1] + offset, + ]) + idx += count + + all_verts.append(verts_world) + offset += len(verts_world) + + if not all_verts: + raise ValueError(f"No mesh geometry found in {usd_path}") + return trimesh.Trimesh(vertices=np.vstack(all_verts), faces=np.array(all_faces, dtype=np.int32)) From 86acbc8720be43890d76385f7746e291ca52f782 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Thu, 25 Jun 2026 16:59:24 -0700 Subject: [PATCH 2/7] improve docstring Signed-off-by: zhx06 --- isaaclab_arena/relations/object_placer.py | 5 ++--- isaaclab_arena/relations/relation_solver.py | 11 ++--------- isaaclab_arena/relations/warp_mesh_manager.py | 2 +- 3 files changed, 5 insertions(+), 13 deletions(-) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 3c834c3dfa..93e9dfba78 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -18,7 +18,6 @@ from isaaclab_arena.relations.relation_loss_strategies import SIDE_CONFIGS, next_to_violations, not_next_to_violations from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relations import ( - IsAnchor, NextTo, NotNextTo, On, @@ -742,7 +741,7 @@ def _not_next_to_margin(self, relation: NotNextTo) -> float: return strategy.margin_m def _get_cpu_mesh_manager(self): - """Lazy-init CPU WarpMeshAndSphereCache.""" + """Return the CPU-device mesh manager, creating it on first call.""" if self._cpu_mesh_manager is None: self._cpu_mesh_manager = WarpMeshAndSphereCache( num_spheres=self.params.solver_params.num_spheres, @@ -803,7 +802,7 @@ def _spheres_penetrate_mesh( tolerance, orientations, ) -> bool: - """True if source's spheres penetrate target's mesh.""" + """True if source's spheres penetrate target's mesh or if BVH returns no-face sentinel.""" spheres = mesh_manager.get_query_spheres(source_mesh, obj=source) warp_mesh = mesh_manager.get_warp_mesh(target_mesh, obj=target) centers = self._centers_in_target_frame(spheres[:, :3], source, target, source_pos, target_pos, orientations) diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index 4f50fd6f21..2a911d217d 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -295,10 +295,7 @@ def _prepare_mesh_collision_cache( def _build_vectorized_cache( self, state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction: str ) -> MeshPairCache | None: - """Build vectorized pair cache for one direction. - - Returns None if no valid pairs exist for this direction. - """ + """Build the MeshPairCache for forward or reverse assignment; None when no pairs qualify.""" centers_list: list[torch.Tensor] = [] radii_list: list[torch.Tensor] = [] pair_child_objs: list = [] @@ -495,11 +492,7 @@ def _compute_no_overlap_loss_mesh( state: RelationSolverState, debug: bool, ) -> torch.Tensor: - """Sphere-to-SDF penetration loss using the vectorized multi-mesh kernel. - - Uses precomputed pair cache (centers, radii, mesh indices) to batch all - sphere queries into a single Warp kernel call per iteration. - """ + """Sphere-to-SDF penetration loss via the vectorized multi-mesh kernel.""" device = state.device total_loss = torch.zeros(state.batch_size, device=device, dtype=torch.float32) clearance_m = self.params.clearance_m diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py index 5a5192df2f..bba8bc0c05 100644 --- a/isaaclab_arena/relations/warp_mesh_manager.py +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -143,7 +143,7 @@ def warn_sdf_sentinel(self, sdf_values: torch.Tensor) -> None: ) def get_collision_mesh(self, obj: ObjectBase) -> trimesh.Trimesh | None: - """Extract or retrieve cached collision mesh for an object.""" + """Return the cached collision mesh, extracting from USD on first access.""" usd_path = getattr(obj, "usd_path", None) if usd_path is None: return obj.get_collision_mesh() From 5a783d97e94004a7c4661754c1ea47893fb4df7c Mon Sep 17 00:00:00 2001 From: zhx06 Date: Fri, 26 Jun 2026 09:51:48 -0700 Subject: [PATCH 3/7] unify AABB and mesh Signed-off-by: zhx06 --- isaaclab_arena/assets/object.py | 2 +- isaaclab_arena/relations/mesh_pair_cache.py | 38 +- .../relations/relation_loss_strategies.py | 3 - isaaclab_arena/relations/relation_solver.py | 410 +++++++++--------- isaaclab_arena/tests/test_mesh_collision.py | 43 +- 5 files changed, 240 insertions(+), 256 deletions(-) diff --git a/isaaclab_arena/assets/object.py b/isaaclab_arena/assets/object.py index 600bfbbe7d..bc1da6f1f8 100644 --- a/isaaclab_arena/assets/object.py +++ b/isaaclab_arena/assets/object.py @@ -78,7 +78,7 @@ def get_bounding_box(self) -> AxisAlignedBoundingBox: return self.bounding_box def get_collision_mesh(self) -> trimesh.Trimesh | None: - """Return None; collision mesh is not available at the object level.""" + """Collision mesh is unavailable for USD-backed objects; subclasses with preloaded geometry may override.""" def get_world_bounding_box(self) -> AxisAlignedBoundingBox: """Get bounding box in world coordinates (local bbox rotated and translated). diff --git a/isaaclab_arena/relations/mesh_pair_cache.py b/isaaclab_arena/relations/mesh_pair_cache.py index 1ad82e1dff..edd9f3884c 100644 --- a/isaaclab_arena/relations/mesh_pair_cache.py +++ b/isaaclab_arena/relations/mesh_pair_cache.py @@ -22,39 +22,39 @@ class MeshPairCache: """Precomputed per-pair collision data for the vectorized multi-mesh kernel.""" all_centers_local: torch.Tensor - """(S, 3) sphere centers in each child's local frame, concatenated across pairs.""" + """(S, 3) sphere centers in each subject's local frame, concatenated across pairs.""" all_radii: torch.Tensor """(S,) sphere radii, concatenated across pairs.""" - pair_child_objs: list[ObjectBase] - """Per-pair child object reference.""" + pair_subject_objs: list[ObjectBase] + """Per-pair subject (sphere source) object reference.""" - pair_parent_objs: list[ObjectBase] - """Per-pair parent/target object reference.""" + pair_obstacle_objs: list[ObjectBase] + """Per-pair obstacle (mesh target) object reference.""" pair_is_anchor: list[bool] - """Per-pair flag: True if parent is a static anchor.""" + """Per-pair flag: True if the obstacle is a static anchor.""" pair_anchor_pos: list[torch.Tensor | None] - """Per-pair world position of anchors (None for non-anchor parents).""" + """Per-pair world position for anchor obstacles (None for non-anchor obstacles).""" pair_anchor_yaw: list[float] - """Per-pair anchor yaw in radians (0.0 for non-anchor parents).""" + """Per-pair anchor yaw in radians (0.0 for non-anchor obstacles).""" - pair_c_bbox_min: torch.Tensor - """(P, B, 3) child bbox min corners for broadphase.""" + pair_subject_bbox_min: torch.Tensor + """(P, B, 3) subject bbox min corners for broadphase.""" - pair_c_bbox_max: torch.Tensor - """(P, B, 3) child bbox max corners for broadphase.""" + pair_subject_bbox_max: torch.Tensor + """(P, B, 3) subject bbox max corners for broadphase.""" - pair_p_bbox_min: torch.Tensor - """(P, B, 3) parent bbox min corners for broadphase.""" + pair_obstacle_bbox_min: torch.Tensor + """(P, B, 3) obstacle bbox min corners for broadphase.""" - pair_p_bbox_max: torch.Tensor - """(P, B, 3) parent bbox max corners for broadphase.""" + pair_obstacle_bbox_max: torch.Tensor + """(P, B, 3) obstacle bbox max corners for broadphase.""" - pair_max_r: torch.Tensor + pair_max_radius: torch.Tensor """(P,) max sphere radius per pair (broadphase margin).""" sphere_pair_id: torch.Tensor @@ -76,8 +76,8 @@ class MeshPairCache: """Total number of sphere queries across all pairs.""" def __post_init__(self) -> None: - assert len(self.pair_child_objs) == self.num_pairs, "pair_child_objs length mismatch" - assert len(self.pair_parent_objs) == self.num_pairs, "pair_parent_objs length mismatch" + assert len(self.pair_subject_objs) == self.num_pairs, "pair_subject_objs length mismatch" + assert len(self.pair_obstacle_objs) == self.num_pairs, "pair_obstacle_objs length mismatch" assert len(self.pair_is_anchor) == self.num_pairs, "pair_is_anchor length mismatch" assert self.all_centers_local.shape[0] == self.total_spheres, "all_centers_local size mismatch" assert self.all_radii.shape[0] == self.total_spheres, "all_radii size mismatch" diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index e622d11db0..f4abd97875 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -454,15 +454,12 @@ class NoCollisionLossStrategy: def __init__( self, slope: float = 10.0, - debug: bool = False, ): """ Args: slope: Gradient magnitude for overlap loss. - debug: If True, print detailed AABB loss component breakdown. """ self.slope = slope - self.debug = debug def compute_loss_batched( self, diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index 2a911d217d..f0fa051d8b 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -9,7 +9,7 @@ import time import torch from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, NamedTuple, cast import warp as wp from isaaclab.utils.math import quat_apply, quat_apply_inverse @@ -46,6 +46,23 @@ class NoOverlapPair: obstacle_max: torch.Tensor +class MeshPairEntry(NamedTuple): + """One directed sphere-to-mesh pair collected during cache construction.""" + + subject: ObjectBase + obstacle: ObjectBase + is_anchor: bool + anchor_pos: torch.Tensor | None # (3,) world position, or None for non-anchors + anchor_yaw: float + centers_local: torch.Tensor # (S, 3) sphere centers in subject-local frame + radii: torch.Tensor # (S,) sphere radii + subject_bbox_min: torch.Tensor # (B, 3) subject bbox min corners, B = batch_size + subject_bbox_max: torch.Tensor # (B, 3) + obstacle_bbox_min: torch.Tensor # (B, 3) obstacle bbox min corners + obstacle_bbox_max: torch.Tensor # (B, 3) + warp_mesh: object # wp.Mesh (untyped to avoid import at runtime) + + class RelationSolver: """Differentiable solver for 3D spatial relations of IsaacLab Arena Objects. @@ -74,8 +91,8 @@ def __init__( self._mesh_orientations: list[dict[ObjectBase, float]] | None = None self._warned_no_mesh: set[str] = set() self._mesh_manager: WarpMeshAndSphereCache | None = None - self._mesh_cache_fwd: MeshPairCache | None = None - self._mesh_cache_rev: MeshPairCache | None = None + self._mesh_cache_forward: MeshPairCache | None = None + self._mesh_cache_reverse: MeshPairCache | None = None def _get_strategy(self, relation: RelationBase) -> RelationLossStrategy | UnaryRelationLossStrategy: """Look up the loss strategy for a relation type; raises ValueError if none registered. @@ -189,6 +206,7 @@ def _compute_no_overlap_loss_aabb( non_anchor_objects = state.optimizable_objects anchor_objects = list(state.anchor_objects) + # Collect On-relation pairs to skip (stacked objects shouldn't repel each other). on_pairs: set[tuple[int, int]] = set() for obj in [*non_anchor_objects, *anchor_objects]: for rel in obj.get_relations(): @@ -285,34 +303,24 @@ def _prepare_mesh_collision_cache( non_anchor_objects = state.optimizable_objects anchor_objects = list(state.anchor_objects) - self._mesh_cache_fwd = self._build_vectorized_cache( - state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction="fwd" - ) - self._mesh_cache_rev = self._build_vectorized_cache( - state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction="rev" + forward_pairs, reverse_pairs = self._collect_mesh_pairs( + state, manager, non_anchor_objects, anchor_objects, on_pairs, device ) + self._mesh_cache_forward = self._finalize_mesh_cache(forward_pairs, device) + self._mesh_cache_reverse = self._finalize_mesh_cache(reverse_pairs, device) - def _build_vectorized_cache( - self, state, manager, non_anchor_objects, anchor_objects, on_pairs, device, direction: str - ) -> MeshPairCache | None: - """Build the MeshPairCache for forward or reverse assignment; None when no pairs qualify.""" - centers_list: list[torch.Tensor] = [] - radii_list: list[torch.Tensor] = [] - pair_child_objs: list = [] - pair_parent_objs: list = [] - pair_is_anchor: list[bool] = [] - pair_anchor_pos: list[torch.Tensor | None] = [] - pair_anchor_yaw: list[float] = [] - pair_c_bbox_min: list[torch.Tensor] = [] - pair_c_bbox_max: list[torch.Tensor] = [] - pair_p_bbox_min: list[torch.Tensor] = [] - pair_p_bbox_max: list[torch.Tensor] = [] - pair_max_r: list[float] = [] - mesh_id_map: dict[int, int] = {} - mesh_id_values: list[int] = [] - mesh_idx_per_sphere: list[int] = [] - pair_slices: list[tuple[int, int]] = [] - offset = 0 + def _collect_mesh_pairs( + self, + state: RelationSolverState, + manager: WarpMeshAndSphereCache, + non_anchor_objects: list, + anchor_objects: list, + on_pairs: set[tuple[int, int]], + device: torch.device, + ) -> tuple[list[MeshPairEntry], list[MeshPairEntry]]: + """Collect forward and reverse mesh pairs in a single pass.""" + forward_pairs: list[MeshPairEntry] = [] + reverse_pairs: list[MeshPairEntry] = [] for i, child in enumerate(non_anchor_objects): child_mesh = manager.get_collision_mesh(child) @@ -325,165 +333,145 @@ def _build_vectorized_cache( child_centers_local = child_spheres[:, :3] child_radii = child_spheres[:, 3] child_bbox = state.get_bbox(child) - c_bbox_min = child_bbox.min_point.to(device) - c_bbox_max = child_bbox.max_point.to(device) - - if direction == "fwd": - for anchor in anchor_objects: - if (id(child), id(anchor)) in on_pairs: - continue - parent_mesh = manager.get_collision_mesh(anchor) - if parent_mesh is None: - if anchor.name not in self._warned_no_mesh: - self._warned_no_mesh.add(anchor.name) - print(f"[NoCollision] '{anchor.name}' has no collision mesh; pair will use AABB fallback.") - continue - warp_mesh = manager.get_warp_mesh(parent_mesh, obj=anchor) - parent_bbox = state.get_bbox(anchor) - p_bbox_min = parent_bbox.min_point.to(device) - p_bbox_max = parent_bbox.max_point.to(device) - pose = anchor.get_initial_pose() - assert pose is not None and isinstance( - pose, Pose - ), f"MESH collision requires anchor '{anchor.name}' to have a fixed Pose initial_pose" - assert abs(pose.rotation_xyzw[0]) < 1e-6 and abs(pose.rotation_xyzw[1]) < 1e-6, ( - f"MESH collision requires anchor '{anchor.name}' to have identity or " - f"pure-Z rotation, got rotation_xyzw={pose.rotation_xyzw}. " - "Roll/pitch anchors are not supported in MESH mode." + c_bbox_min = child_bbox.min_point.to(device).expand(state.batch_size, 3) + c_bbox_max = child_bbox.max_point.to(device).expand(state.batch_size, 3) + + # Forward: child's spheres → anchor's mesh + for anchor in anchor_objects: + if (id(child), id(anchor)) in on_pairs: + continue + anchor_mesh = manager.get_collision_mesh(anchor) + if anchor_mesh is None: + if anchor.name not in self._warned_no_mesh: + self._warned_no_mesh.add(anchor.name) + print(f"[NoCollision] '{anchor.name}' has no collision mesh; pair will use AABB fallback.") + continue + pose = anchor.get_initial_pose() + assert pose is not None and isinstance( + pose, Pose + ), f"MESH collision requires anchor '{anchor.name}' to have a fixed Pose initial_pose" + assert abs(pose.rotation_xyzw[0]) < 1e-6 and abs(pose.rotation_xyzw[1]) < 1e-6, ( + f"MESH collision requires anchor '{anchor.name}' to have identity or " + f"pure-Z rotation, got rotation_xyzw={pose.rotation_xyzw}. " + "Roll/pitch anchors are not supported in MESH mode." + ) + anchor_bbox = state.get_bbox(anchor) + forward_pairs.append( + MeshPairEntry( + subject=child, + obstacle=anchor, + is_anchor=True, + anchor_pos=torch.tensor(pose.position_xyz, dtype=torch.float32, device=device), + anchor_yaw=yaw_from_quat_xyzw(pose.rotation_xyzw), + centers_local=child_centers_local, + radii=child_radii, + subject_bbox_min=c_bbox_min, + subject_bbox_max=c_bbox_max, + obstacle_bbox_min=anchor_bbox.min_point.to(device).expand(state.batch_size, 3), + obstacle_bbox_max=anchor_bbox.max_point.to(device).expand(state.batch_size, 3), + warp_mesh=manager.get_warp_mesh(anchor_mesh, obj=anchor), + ) + ) + + # Forward + Reverse: non-anchor pairs (bidirectional gradient) + for j in range(i + 1, len(non_anchor_objects)): + other = non_anchor_objects[j] + if (id(child), id(other)) in on_pairs: + continue + other_mesh = manager.get_collision_mesh(other) + if other_mesh is None: + if other.name not in self._warned_no_mesh: + self._warned_no_mesh.add(other.name) + print(f"[NoCollision] '{other.name}' has no collision mesh; pair will use AABB fallback.") + continue + other_bbox = state.get_bbox(other) + o_bbox_min = other_bbox.min_point.to(device).expand(state.batch_size, 3) + o_bbox_max = other_bbox.max_point.to(device).expand(state.batch_size, 3) + + # forward: child's spheres → other's mesh + forward_pairs.append( + MeshPairEntry( + subject=child, + obstacle=other, + is_anchor=False, + anchor_pos=None, + anchor_yaw=0.0, + centers_local=child_centers_local, + radii=child_radii, + subject_bbox_min=c_bbox_min, + subject_bbox_max=c_bbox_max, + obstacle_bbox_min=o_bbox_min, + obstacle_bbox_max=o_bbox_max, + warp_mesh=manager.get_warp_mesh(other_mesh, obj=other), ) - anchor_pos = torch.tensor(pose.position_xyz, dtype=torch.float32, device=device) - anchor_yaw = yaw_from_quat_xyzw(pose.rotation_xyzw) - - n_spheres = child_centers_local.shape[0] - mesh_key = id(warp_mesh) - if mesh_key not in mesh_id_map: - mesh_id_map[mesh_key] = len(mesh_id_values) - mesh_id_values.append(warp_mesh.id) - mesh_idx = mesh_id_map[mesh_key] - - centers_list.append(child_centers_local) - radii_list.append(child_radii) - pair_child_objs.append(child) - pair_parent_objs.append(anchor) - pair_is_anchor.append(True) - pair_anchor_pos.append(anchor_pos) - pair_anchor_yaw.append(anchor_yaw) - pair_c_bbox_min.append(c_bbox_min) - pair_c_bbox_max.append(c_bbox_max) - pair_p_bbox_min.append(p_bbox_min) - pair_p_bbox_max.append(p_bbox_max) - pair_max_r.append(child_radii.max().item()) - mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) - pair_slices.append((offset, offset + n_spheres)) - offset += n_spheres - - for j in range(i + 1, len(non_anchor_objects)): - other = non_anchor_objects[j] - if (id(child), id(other)) in on_pairs: - continue - other_mesh = manager.get_collision_mesh(other) - if other_mesh is None: - if other.name not in self._warned_no_mesh: - self._warned_no_mesh.add(other.name) - print(f"[NoCollision] '{other.name}' has no collision mesh; pair will use AABB fallback.") - continue - warp_mesh = manager.get_warp_mesh(other_mesh, obj=other) - other_bbox = state.get_bbox(other) - p_bbox_min = other_bbox.min_point.to(device) - p_bbox_max = other_bbox.max_point.to(device) - - n_spheres = child_centers_local.shape[0] - mesh_key = id(warp_mesh) - if mesh_key not in mesh_id_map: - mesh_id_map[mesh_key] = len(mesh_id_values) - mesh_id_values.append(warp_mesh.id) - mesh_idx = mesh_id_map[mesh_key] - - centers_list.append(child_centers_local) - radii_list.append(child_radii) - pair_child_objs.append(child) - pair_parent_objs.append(other) - pair_is_anchor.append(False) - pair_anchor_pos.append(None) - pair_anchor_yaw.append(0.0) - pair_c_bbox_min.append(c_bbox_min) - pair_c_bbox_max.append(c_bbox_max) - pair_p_bbox_min.append(p_bbox_min) - pair_p_bbox_max.append(p_bbox_max) - pair_max_r.append(child_radii.max().item()) - mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) - pair_slices.append((offset, offset + n_spheres)) - offset += n_spheres - - else: # direction == "rev" - for j in range(i + 1, len(non_anchor_objects)): - other = non_anchor_objects[j] - if (id(child), id(other)) in on_pairs: - continue - other_mesh = manager.get_collision_mesh(other) - if other_mesh is None: - if other.name not in self._warned_no_mesh: - self._warned_no_mesh.add(other.name) - print(f"[NoCollision] '{other.name}' has no collision mesh; pair will use AABB fallback.") - continue - other_spheres = manager.get_query_spheres(other_mesh, obj=other).to(device) - other_centers_local = other_spheres[:, :3] - other_radii = other_spheres[:, 3] - warp_mesh = manager.get_warp_mesh(child_mesh, obj=child) - other_bbox = state.get_bbox(other) - o_bbox_min = other_bbox.min_point.to(device) - o_bbox_max = other_bbox.max_point.to(device) - - n_spheres = other_centers_local.shape[0] - mesh_key = id(warp_mesh) - if mesh_key not in mesh_id_map: - mesh_id_map[mesh_key] = len(mesh_id_values) - mesh_id_values.append(warp_mesh.id) - mesh_idx = mesh_id_map[mesh_key] - - centers_list.append(other_centers_local) - radii_list.append(other_radii) - pair_child_objs.append(other) - pair_parent_objs.append(child) - pair_is_anchor.append(False) - pair_anchor_pos.append(None) - pair_anchor_yaw.append(0.0) - pair_c_bbox_min.append(o_bbox_min) - pair_c_bbox_max.append(o_bbox_max) - pair_p_bbox_min.append(c_bbox_min) - pair_p_bbox_max.append(c_bbox_max) - pair_max_r.append(other_radii.max().item()) - mesh_idx_per_sphere.extend([mesh_idx] * n_spheres) - pair_slices.append((offset, offset + n_spheres)) - offset += n_spheres - - if not centers_list: + ) + + # reverse: other's spheres → child's mesh + other_spheres = manager.get_query_spheres(other_mesh, obj=other).to(device) + reverse_pairs.append( + MeshPairEntry( + subject=other, + obstacle=child, + is_anchor=False, + anchor_pos=None, + anchor_yaw=0.0, + centers_local=other_spheres[:, :3], + radii=other_spheres[:, 3], + subject_bbox_min=o_bbox_min, + subject_bbox_max=o_bbox_max, + obstacle_bbox_min=c_bbox_min, + obstacle_bbox_max=c_bbox_max, + warp_mesh=manager.get_warp_mesh(child_mesh, obj=child), + ) + ) + + return forward_pairs, reverse_pairs + + @staticmethod + def _finalize_mesh_cache(entries: list[MeshPairEntry], device: torch.device) -> MeshPairCache | None: + """Stack collected pair entries into a MeshPairCache; None when no pairs qualify.""" + if not entries: return None - wp_device = str(device) + mesh_id_map: dict[int, int] = {} + mesh_id_values: list[int] = [] + mesh_idx_per_sphere: list[int] = [] + pair_slices: list[tuple[int, int]] = [] + offset = 0 + + for entry in entries: + n_spheres = entry.centers_local.shape[0] + mesh_key = id(entry.warp_mesh) + if mesh_key not in mesh_id_map: + mesh_id_map[mesh_key] = len(mesh_id_values) + mesh_id_values.append(entry.warp_mesh.id) + mesh_idx_per_sphere.extend([mesh_id_map[mesh_key]] * n_spheres) + pair_slices.append((offset, offset + n_spheres)) + offset += n_spheres + pair_sphere_count = torch.tensor([e - s for s, e in pair_slices], dtype=torch.float32, device=device) sphere_pair_id = torch.repeat_interleave( torch.arange(len(pair_slices), device=device), pair_sphere_count.long() ) return MeshPairCache( - all_centers_local=torch.cat(centers_list, dim=0), - all_radii=torch.cat(radii_list, dim=0), - pair_child_objs=pair_child_objs, - pair_parent_objs=pair_parent_objs, - pair_is_anchor=pair_is_anchor, - pair_anchor_pos=pair_anchor_pos, - pair_anchor_yaw=pair_anchor_yaw, - pair_c_bbox_min=torch.stack(pair_c_bbox_min), - pair_c_bbox_max=torch.stack(pair_c_bbox_max), - pair_p_bbox_min=torch.stack(pair_p_bbox_min), - pair_p_bbox_max=torch.stack(pair_p_bbox_max), - pair_max_r=torch.tensor(pair_max_r, device=device), + all_centers_local=torch.cat([e.centers_local for e in entries], dim=0), + all_radii=torch.cat([e.radii for e in entries], dim=0), + pair_subject_objs=[e.subject for e in entries], + pair_obstacle_objs=[e.obstacle for e in entries], + pair_is_anchor=[e.is_anchor for e in entries], + pair_anchor_pos=[e.anchor_pos for e in entries], + pair_anchor_yaw=[e.anchor_yaw for e in entries], + pair_subject_bbox_min=torch.stack([e.subject_bbox_min for e in entries]), + pair_subject_bbox_max=torch.stack([e.subject_bbox_max for e in entries]), + pair_obstacle_bbox_min=torch.stack([e.obstacle_bbox_min for e in entries]), + pair_obstacle_bbox_max=torch.stack([e.obstacle_bbox_max for e in entries]), + pair_max_radius=torch.tensor([e.radii.max().item() for e in entries], device=device), sphere_pair_id=sphere_pair_id, sphere_mesh_idx=torch.tensor(mesh_idx_per_sphere, dtype=torch.int32, device=device), pair_sphere_count=pair_sphere_count, - mesh_id_array=wp.array(np.array(mesh_id_values, dtype=np.uint64), dtype=wp.uint64, device=wp_device), - num_pairs=len(pair_slices), + mesh_id_array=wp.array(np.array(mesh_id_values, dtype=np.uint64), dtype=wp.uint64, device=str(device)), + num_pairs=len(entries), total_spheres=offset, ) @@ -492,27 +480,29 @@ def _compute_no_overlap_loss_mesh( state: RelationSolverState, debug: bool, ) -> torch.Tensor: - """Sphere-to-SDF penetration loss via the vectorized multi-mesh kernel.""" + """Per-env sphere-to-SDF penetration loss; iterates envs, calls the multi-mesh kernel per batch.""" device = state.device total_loss = torch.zeros(state.batch_size, device=device, dtype=torch.float32) clearance_m = self.params.clearance_m slope = self._no_collision_strategy.slope + # Per-env loop (not batched like AABB): per-env yaw and active-pair masking each produce a + # different sphere subset before the kernel launch, so envs cannot be collapsed into one call. for b in range(state.batch_size): - for cache in (self._mesh_cache_fwd, self._mesh_cache_rev): + for cache in (self._mesh_cache_forward, self._mesh_cache_reverse): if cache is None: continue num_pairs = cache.num_pairs - child_positions = torch.stack( - [state.get_position(cache.pair_child_objs[p])[b] for p in range(num_pairs)] + subject_positions = torch.stack( + [state.get_position(cache.pair_subject_objs[p])[b] for p in range(num_pairs)] ) - parent_positions = torch.stack([ + obstacle_positions = torch.stack([ ( cache.pair_anchor_pos[p] if cache.pair_is_anchor[p] - else state.get_position(cache.pair_parent_objs[p])[b].detach() + else state.get_position(cache.pair_obstacle_objs[p])[b].detach() ) for p in range(num_pairs) ]) @@ -521,52 +511,51 @@ def _compute_no_overlap_loss_mesh( has_any_yaw = self._mesh_orientations is not None or any(y != 0.0 for y in anchor_yaws) if has_any_yaw: ori_b = self._mesh_orientations[b] if self._mesh_orientations is not None else {} - child_yaws = torch.tensor( - [ori_b.get(cache.pair_child_objs[p], 0.0) for p in range(num_pairs)], + subject_yaws = torch.tensor( + [ori_b.get(cache.pair_subject_objs[p], 0.0) for p in range(num_pairs)], dtype=torch.float32, device=device, ) - parent_yaws = torch.tensor( - [ori_b.get(cache.pair_parent_objs[p], anchor_yaws[p]) for p in range(num_pairs)], + obstacle_yaws = torch.tensor( + [ori_b.get(cache.pair_obstacle_objs[p], anchor_yaws[p]) for p in range(num_pairs)], dtype=torch.float32, device=device, ) # AABB broadphase (yaw-aware): skip separated pairs. - margins = cache.pair_max_r + clearance_m - batch_idx = min(b, cache.pair_c_bbox_min.shape[1] - 1) - c_bbox_min = cache.pair_c_bbox_min[:, batch_idx, :] - c_bbox_max = cache.pair_c_bbox_max[:, batch_idx, :] - p_bbox_min = cache.pair_p_bbox_min[:, batch_idx, :] - p_bbox_max = cache.pair_p_bbox_max[:, batch_idx, :] + margins = cache.pair_max_radius + clearance_m + s_bbox_min = cache.pair_subject_bbox_min[:, b, :] + s_bbox_max = cache.pair_subject_bbox_max[:, b, :] + o_bbox_min = cache.pair_obstacle_bbox_min[:, b, :] + o_bbox_max = cache.pair_obstacle_bbox_max[:, b, :] if has_any_yaw: - c_bbox_min, c_bbox_max = self._rotate_bbox_extents(c_bbox_min, c_bbox_max, child_yaws) - p_bbox_min, p_bbox_max = self._rotate_bbox_extents(p_bbox_min, p_bbox_max, parent_yaws) + s_bbox_min, s_bbox_max = self._rotate_bbox_extents(s_bbox_min, s_bbox_max, subject_yaws) + o_bbox_min, o_bbox_max = self._rotate_bbox_extents(o_bbox_min, o_bbox_max, obstacle_yaws) - child_min = child_positions + c_bbox_min - child_max = child_positions + c_bbox_max - parent_min = parent_positions + p_bbox_min - parent_max = parent_positions + p_bbox_max + subject_min = subject_positions + s_bbox_min + subject_max = subject_positions + s_bbox_max + obstacle_min = obstacle_positions + o_bbox_min + obstacle_max = obstacle_positions + o_bbox_max - sep_child = (child_min - margins.unsqueeze(1)) > parent_max - sep_parent = (parent_min - margins.unsqueeze(1)) > child_max - separated = sep_child.any(dim=1) | sep_parent.any(dim=1) + sep_subject = (subject_min - margins.unsqueeze(1)) > obstacle_max + sep_obstacle = (obstacle_min - margins.unsqueeze(1)) > subject_max + separated = sep_subject.any(dim=1) | sep_obstacle.any(dim=1) active_pair = ~separated if not active_pair.any(): continue - offsets = child_positions - parent_positions + offsets = subject_positions - obstacle_positions sphere_active_mask = active_pair[cache.sphere_pair_id] active_idx = sphere_active_mask.nonzero(as_tuple=True)[0] active_sphere_pair_id = cache.sphere_pair_id[active_idx] local_centers = cache.all_centers_local[active_idx] - # R(child_yaw - parent_yaw) · local + R(-parent_yaw) · offset + # R(subject_yaw - obstacle_yaw) · local + R(-obstacle_yaw) · offset if has_any_yaw: - net_yaws = (child_yaws - parent_yaws)[active_sphere_pair_id] + net_yaws = (subject_yaws - obstacle_yaws)[active_sphere_pair_id] half_net = net_yaws / 2.0 q_net_z = torch.zeros(len(half_net), 4, device=device, dtype=local_centers.dtype) q_net_z[:, 2] = torch.sin(half_net) @@ -574,12 +563,12 @@ def _compute_no_overlap_loss_mesh( local_centers = quat_apply(q_net_z, local_centers) pair_offsets = offsets[active_sphere_pair_id] - p_yaws = parent_yaws[active_sphere_pair_id] - half_p = p_yaws / 2.0 - q_parent_z = torch.zeros(len(half_p), 4, device=device, dtype=local_centers.dtype) - q_parent_z[:, 2] = torch.sin(half_p) - q_parent_z[:, 3] = torch.cos(half_p) - rotated_offsets = quat_apply_inverse(q_parent_z, pair_offsets) + obs_yaws = obstacle_yaws[active_sphere_pair_id] + half_o = obs_yaws / 2.0 + q_obstacle_z = torch.zeros(len(half_o), 4, device=device, dtype=local_centers.dtype) + q_obstacle_z[:, 2] = torch.sin(half_o) + q_obstacle_z[:, 3] = torch.cos(half_o) + rotated_offsets = quat_apply_inverse(q_obstacle_z, pair_offsets) active_centers = local_centers + rotated_offsets else: active_centers = local_centers + offsets[active_sphere_pair_id] @@ -735,7 +724,6 @@ def solve( f" | iters={iters_run} ({solve_elapsed_ms / iters_run:.2f} ms/iter)" ) - # Store metadata for optional access self._last_loss_history = loss_history self._last_position_history = position_history diff --git a/isaaclab_arena/tests/test_mesh_collision.py b/isaaclab_arena/tests/test_mesh_collision.py index 178133016e..a00046bee3 100644 --- a/isaaclab_arena/tests/test_mesh_collision.py +++ b/isaaclab_arena/tests/test_mesh_collision.py @@ -74,7 +74,6 @@ def _make_table() -> DummyObject: def test_sphere_decomposition_covers_surface(): - """Sphere decomposition should cover >80% of surface sample points.""" mesh = trimesh.creation.cylinder(radius=0.05, height=0.1, sections=32) spheres = greedy_sphere_decomposition(mesh, num_spheres=20, n_surface=500) assert spheres.shape[1] == 4 @@ -95,7 +94,6 @@ def test_sphere_decomposition_covers_surface(): @requires_warp def test_warp_mesh_caching(): - """Same mesh object should return identical Warp mesh from cache.""" mesh = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) manager = WarpMeshAndSphereCache(num_spheres=10) m1 = manager.get_warp_mesh(mesh) @@ -103,14 +101,24 @@ def test_warp_mesh_caching(): assert m1 is m2 +def _batched_aabb_loss(strategy, clearance_m, child_pos, child_bbox, parent_world_bbox): + """Helper: compute single-pair AABB loss via compute_loss_batched.""" + subject_min = (child_pos + child_bbox.min_point).unsqueeze(0).unsqueeze(0) + subject_max = (child_pos + child_bbox.max_point).unsqueeze(0).unsqueeze(0) + obstacle_min = parent_world_bbox.min_point.unsqueeze(0) + obstacle_max = parent_world_bbox.max_point.unsqueeze(0) + loss = strategy.compute_loss_batched(clearance_m, subject_min, subject_max, obstacle_min, obstacle_max) + return loss.squeeze() + + def test_aabb_zero_loss_well_separated(): - """AABB loss is zero when objects are well separated.""" strategy = NoCollisionLossStrategy(slope=10.0) obj_a = _make_cylinder("a") obj_b = _make_cylinder("b") - loss = strategy.compute_loss( + loss = _batched_aabb_loss( + strategy, clearance_m=0.0, child_pos=torch.tensor([0.0, 0.0, 0.0]), child_bbox=obj_a.get_bounding_box(), @@ -120,12 +128,12 @@ def test_aabb_zero_loss_well_separated(): def test_aabb_positive_loss_fully_overlapping(): - """AABB loss is positive when two objects fully overlap.""" strategy = NoCollisionLossStrategy(slope=10000.0) a = _make_cylinder("a") b = _make_cylinder("b") - loss = strategy.compute_loss( + loss = _batched_aabb_loss( + strategy, clearance_m=0.0, child_pos=torch.tensor([0.0, 0.0, 0.0]), child_bbox=a.get_bounding_box(), @@ -135,22 +143,21 @@ def test_aabb_positive_loss_fully_overlapping(): def test_aabb_clearance_m_increases_loss(): - """Near-miss cylinders should have positive AABB loss when clearance_m > 0.""" strategy = NoCollisionLossStrategy(slope=10000.0) a = _make_cylinder("a", radius=0.03) b = _make_cylinder("b", radius=0.03) child_pos = torch.tensor([0.0, 0.0, 0.0]) parent_world_bbox = b.get_bounding_box().translated((0.07, 0.0, 0.0)) - # Separated (0.07 > 0.03+0.03): zero AABB loss without clearance - loss_no_clearance = strategy.compute_loss( + loss_no_clearance = _batched_aabb_loss( + strategy, clearance_m=0.0, child_pos=child_pos, child_bbox=a.get_bounding_box(), parent_world_bbox=parent_world_bbox, ) - # With clearance: boxes expand, loss should be positive - loss_with_clearance = strategy.compute_loss( + loss_with_clearance = _batched_aabb_loss( + strategy, clearance_m=0.05, child_pos=child_pos, child_bbox=a.get_bounding_box(), @@ -164,7 +171,6 @@ def test_aabb_clearance_m_increases_loss(): @requires_warp def test_solver_separates_overlapping_cylinders_mesh_mode(): - """RelationSolver with MESH mode should push overlapping cylinders apart.""" table = _make_table() a = _make_cylinder("cyl_a") b = _make_cylinder("cyl_b") @@ -190,7 +196,6 @@ def test_solver_separates_overlapping_cylinders_mesh_mode(): @requires_warp def test_on_pairs_skipped_in_mesh_mode(): - """On-linked pairs should not be penalized in mesh mode (same as AABB).""" table = _make_table() obj = _make_cylinder("can") obj.add_relation(On(table)) @@ -321,7 +326,6 @@ def test_object_placer_mesh_mode_end_to_end(): @requires_warp def test_validate_no_overlap_mesh_catches_overlap(): - """Direct test of _validate_no_overlap_mesh: overlapping cylinders should fail validation.""" from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams @@ -416,7 +420,7 @@ def test_validate_no_overlap_mesh_respects_anchor_yaw(): @requires_warp def test_mesh_sdf_backward_gradient(): - """mesh_sdf backward should produce non-zero gradients pointing outward for interior points.""" + """mesh_sdf backward gradient points toward the nearest face for an interior point.""" from isaaclab_arena.relations.warp_sdf_kernels import mesh_sdf mesh = trimesh.creation.box(extents=(0.2, 0.2, 0.2)) @@ -474,7 +478,7 @@ def test_solver_mesh_batch_size_two(): @requires_warp def test_broadphase_skips_separated_pairs(): - """Well-separated objects produce zero mesh loss from the solver path.""" + """Separated objects have lower mesh loss than overlapping ones (broadphase filters pairs).""" table = _make_table() a = _make_cylinder("a", radius=0.03) b = _make_cylinder("b", radius=0.03) @@ -486,9 +490,6 @@ def test_broadphase_skips_separated_pairs(): solver = RelationSolver(params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=0, verbose=False)) solver.solve([table, a, b], initial) - # With max_iters=0, loss is from initial positions. - # Objects are well separated, so collision loss should be minimal - # (only On-relation losses contribute). loss = solver.last_loss_per_env[0].item() # Compare with an overlapping case to confirm broadphase actually filters @@ -502,7 +503,6 @@ def test_broadphase_skips_separated_pairs(): @requires_warp def test_broadphase_does_not_skip_overlapping_pairs(): - """Overlapping objects must produce nonzero mesh loss from the solver path.""" table = _make_table() a = _make_cylinder("a", radius=0.03) b = _make_cylinder("b", radius=0.03) @@ -520,7 +520,7 @@ def test_broadphase_does_not_skip_overlapping_pairs(): @requires_warp def test_multi_mesh_sdf_distinct_meshes(): - """Verify mesh_indices routes queries to different meshes (not stuck at index 0).""" + """Regression: mesh_indices routing could silently query only mesh 0.""" from isaaclab_arena.relations.warp_sdf_kernels import multi_mesh_sdf # Tall cylinder vs flat box — maximally different SDF at the query point. @@ -550,7 +550,6 @@ def test_multi_mesh_sdf_distinct_meshes(): @requires_warp def test_multi_mesh_sdf_backward(): - """Backward through multi_mesh_sdf produces correct gradient direction.""" from isaaclab_arena.relations.warp_sdf_kernels import multi_mesh_sdf mesh = trimesh.creation.cylinder(radius=0.05, height=0.1, sections=32) From 52f2914f2353202ef83658ab62b82e3a42c92c7b Mon Sep 17 00:00:00 2001 From: zhx06 Date: Mon, 29 Jun 2026 09:56:16 -0700 Subject: [PATCH 4/7] improve test cases, edit docstring Signed-off-by: zhx06 --- isaaclab_arena/relations/relation_solver.py | 9 ++++----- isaaclab_arena/relations/warp_mesh_manager.py | 1 + isaaclab_arena/tests/test_mesh_collision.py | 3 --- .../tests/test_object_placer_reproducibility.py | 1 - isaaclab_arena/tests/test_usd_scale_helpers.py | 8 +------- 5 files changed, 6 insertions(+), 16 deletions(-) diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index f0fa051d8b..7bda28caaa 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -55,8 +55,8 @@ class MeshPairEntry(NamedTuple): anchor_pos: torch.Tensor | None # (3,) world position, or None for non-anchors anchor_yaw: float centers_local: torch.Tensor # (S, 3) sphere centers in subject-local frame - radii: torch.Tensor # (S,) sphere radii - subject_bbox_min: torch.Tensor # (B, 3) subject bbox min corners, B = batch_size + radii: torch.Tensor # (S,) + subject_bbox_min: torch.Tensor # (B, 3) subject bbox min corners subject_bbox_max: torch.Tensor # (B, 3) obstacle_bbox_min: torch.Tensor # (B, 3) obstacle bbox min corners obstacle_bbox_max: torch.Tensor # (B, 3) @@ -480,14 +480,13 @@ def _compute_no_overlap_loss_mesh( state: RelationSolverState, debug: bool, ) -> torch.Tensor: - """Per-env sphere-to-SDF penetration loss; iterates envs, calls the multi-mesh kernel per batch.""" + """Per-env sphere-to-SDF penetration loss.""" device = state.device total_loss = torch.zeros(state.batch_size, device=device, dtype=torch.float32) clearance_m = self.params.clearance_m slope = self._no_collision_strategy.slope - # Per-env loop (not batched like AABB): per-env yaw and active-pair masking each produce a - # different sphere subset before the kernel launch, so envs cannot be collapsed into one call. + # Per-env loop (not batched like AABB): per-env yaw and active-pair masking produce a different sphere subset per env. for b in range(state.batch_size): for cache in (self._mesh_cache_forward, self._mesh_cache_reverse): if cache is None: diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py index bba8bc0c05..8ca2fb50f9 100644 --- a/isaaclab_arena/relations/warp_mesh_manager.py +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -144,6 +144,7 @@ def warn_sdf_sentinel(self, sdf_values: torch.Tensor) -> None: def get_collision_mesh(self, obj: ObjectBase) -> trimesh.Trimesh | None: """Return the cached collision mesh, extracting from USD on first access.""" + # ObjectBase doesn't guarantee usd_path; only Object subclasses set it. usd_path = getattr(obj, "usd_path", None) if usd_path is None: return obj.get_collision_mesh() diff --git a/isaaclab_arena/tests/test_mesh_collision.py b/isaaclab_arena/tests/test_mesh_collision.py index a00046bee3..c60154c9f0 100644 --- a/isaaclab_arena/tests/test_mesh_collision.py +++ b/isaaclab_arena/tests/test_mesh_collision.py @@ -304,7 +304,6 @@ def test_centers_in_target_frame_applies_both_yaws(): @requires_warp def test_object_placer_mesh_mode_end_to_end(): - """ObjectPlacer.place() with CollisionMode.MESH returns a valid result.""" from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams @@ -442,7 +441,6 @@ def test_mesh_sdf_backward_gradient(): @requires_warp def test_solver_mesh_batch_size_two(): - """Solver MESH mode handles batch_size > 1 (both envs solved independently).""" table = _make_table() a = _make_cylinder("cyl_a") b = _make_cylinder("cyl_b") @@ -478,7 +476,6 @@ def test_solver_mesh_batch_size_two(): @requires_warp def test_broadphase_skips_separated_pairs(): - """Separated objects have lower mesh loss than overlapping ones (broadphase filters pairs).""" table = _make_table() a = _make_cylinder("a", radius=0.03) b = _make_cylinder("b", radius=0.03) diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 5b77028681..d11331564d 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -306,7 +306,6 @@ def test_random_yaw_init_applied_yaw_matches_selected_candidate(): def test_random_yaw_init_composes_marker_yaw(): - """orientations dict carries total yaw (marker + sampled); applied pose matches it.""" marker_yaw = math.pi / 6 solver_params = RelationSolverParams(max_iters=10, verbose=False) desk, box1, box2 = _create_test_objects() diff --git a/isaaclab_arena/tests/test_usd_scale_helpers.py b/isaaclab_arena/tests/test_usd_scale_helpers.py index 66548ef7fd..b84e7a104f 100644 --- a/isaaclab_arena/tests/test_usd_scale_helpers.py +++ b/isaaclab_arena/tests/test_usd_scale_helpers.py @@ -99,12 +99,7 @@ def _test_extract_trimesh_translated_child_nonuniform_scale(simulation_app): def _test_bbox_translated_child_nonuniform_scale(simulation_app): - """BBox uses ComputeLocalBound * scale (post-transform scale on root-local extents). - - For a child translated +1 with verts ±0.5: root-local bound X=[0.5, 1.5], * scale_x=2 → [1.0, 3.0]. - Note: mesh path scales per-prim verts first → X=[0.0, 2.0]. These differ for translated children - under non-uniform scale. AABB is conservative (larger), which is safe for collision checks. - """ + """AABB is conservative vs mesh path for translated children under non-uniform scale.""" import tempfile from pxr import Gf, Usd, UsdGeom @@ -181,7 +176,6 @@ def _test_bbox_translated_child_nonuniform_scale(simulation_app): def _test_both_paths_agree_origin_prim(simulation_app): - """For an origin-centered single prim, mesh and bbox agree exactly.""" import tempfile from pxr import Gf, Usd, UsdGeom From a690c8f5988b827c99102225d1eb6035fbff7456 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 30 Jun 2026 08:43:08 -0700 Subject: [PATCH 5/7] address comments Signed-off-by: zhx06 --- isaaclab_arena/assets/dummy_object.py | 5 +-- isaaclab_arena/cli/isaaclab_arena_cli.py | 17 --------- .../environments/arena_env_builder.py | 15 +------- .../isaaclab_arena_environment.py | 4 +-- isaaclab_arena/relations/mesh_pair_cache.py | 35 ++++++++++--------- isaaclab_arena/relations/object_placer.py | 4 +-- isaaclab_arena/relations/relation_solver.py | 2 +- isaaclab_arena/relations/warp_mesh_manager.py | 2 +- 8 files changed, 27 insertions(+), 57 deletions(-) diff --git a/isaaclab_arena/assets/dummy_object.py b/isaaclab_arena/assets/dummy_object.py index 8c1349dfd5..728606fa40 100644 --- a/isaaclab_arena/assets/dummy_object.py +++ b/isaaclab_arena/assets/dummy_object.py @@ -5,15 +5,12 @@ from __future__ import annotations import torch -from typing import TYPE_CHECKING +import trimesh from isaaclab_arena.relations.relations import IsAnchor, Relation, RelationBase, UnaryRelation from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters from isaaclab_arena.utils.pose import Pose -if TYPE_CHECKING: - import trimesh - class DummyObject: """Dummy object for testing purposes without Isaac Sim dependencies.""" diff --git a/isaaclab_arena/cli/isaaclab_arena_cli.py b/isaaclab_arena/cli/isaaclab_arena_cli.py index cb2846d9e6..ea697e1c8e 100644 --- a/isaaclab_arena/cli/isaaclab_arena_cli.py +++ b/isaaclab_arena/cli/isaaclab_arena_cli.py @@ -76,29 +76,12 @@ def add_isaaclab_arena_cli_args(parser: argparse.ArgumentParser) -> None: " layout." ), ) - arena_group.add_argument( - "--random_yaw_init", - action="store_true", - default=False, - help=( - "Randomly rotate objects (except anchors) around the Z-axis for scene variety. " - "Collisions use a larger enclosing box; the solver won't optimize this rotation. " - "Only affects objects positioned by the placement solver; manually-placed objects are unaffected." - ), - ) arena_group.add_argument( "--list-variations", action="store_true", default=False, help="Print Hydra-configurable variations for the selected environment and exit.", ) - arena_group.add_argument( - "--collision_mode", - type=str, - choices=["bbox", "mesh"], - default="bbox", - help="Collision detection mode: 'bbox' (AABB, default) or 'mesh' (sphere-to-SDF, requires Warp).", - ) def add_env_graph_spec_cli_args(parser: argparse.ArgumentParser) -> None: diff --git a/isaaclab_arena/environments/arena_env_builder.py b/isaaclab_arena/environments/arena_env_builder.py index 4536031afa..d4f04e8286 100644 --- a/isaaclab_arena/environments/arena_env_builder.py +++ b/isaaclab_arena/environments/arena_env_builder.py @@ -36,10 +36,8 @@ ) from isaaclab_arena.recording.common_terms import CoreEpisodeRecorderTermCfg, VariationEpisodeRecorderTermCfg from isaaclab_arena.recording.episode_recorder_manager import EpisodeRecorderTermCfg -from isaaclab_arena.relations.collision_mode import CollisionMode from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.placement_events import PLACEMENT_RESET_EVENT_NAME -from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.tasks.no_task import NoTask from isaaclab_arena.utils.configclass import combine_configclass_instances, make_configclass from isaaclab_arena.utils.isaaclab_utils.simulation_app import reapply_viewer_cfg @@ -86,20 +84,9 @@ def _solve_relations(self) -> None: """ objects_with_relations = self.arena_env.scene.get_objects_with_relations() - # Prefer env-level placer_params; fall back to CLI-constructed defaults. placer_params = self.arena_env.placer_params if placer_params is None: - collision_mode_str = getattr(self.args, "collision_mode", "bbox") - mode = CollisionMode.MESH if collision_mode_str == "mesh" else CollisionMode.BBOX - placer_params = ObjectPlacerParams( - placement_seed=self.args.placement_seed, - random_yaw_init=self.args.random_yaw_init, - solver_params=RelationSolverParams( - collision_mode=mode, - save_position_history=False, - verbose=False, - ), - ) + placer_params = ObjectPlacerParams(placement_seed=self.args.placement_seed) if self.args.resolve_on_reset is not None: placer_params.resolve_on_reset = self.args.resolve_on_reset self._placement_event_cfg = solve_and_apply_relation_placement( diff --git a/isaaclab_arena/environments/isaaclab_arena_environment.py b/isaaclab_arena/environments/isaaclab_arena_environment.py index 708992eba0..5b76e2cc92 100644 --- a/isaaclab_arena/environments/isaaclab_arena_environment.py +++ b/isaaclab_arena/environments/isaaclab_arena_environment.py @@ -52,8 +52,8 @@ def __init__( ``"my_module:RLPolicyCfg"``. episode_recorder_terms: Additional per-episode recorder terms to record alongside the built-in ones, keyed by name. - placer_params: Object placement configuration. When set, used as-is - (CLI flags are ignored). When None, params are built from CLI flags. + placer_params: Object placement configuration. When None, default + ObjectPlacerParams are used. """ self.name = name self.scene = scene diff --git a/isaaclab_arena/relations/mesh_pair_cache.py b/isaaclab_arena/relations/mesh_pair_cache.py index edd9f3884c..b629ccb2bc 100644 --- a/isaaclab_arena/relations/mesh_pair_cache.py +++ b/isaaclab_arena/relations/mesh_pair_cache.py @@ -9,17 +9,20 @@ import torch from dataclasses import dataclass -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import warp as wp +import warp as wp - from isaaclab_arena.assets.object_base import ObjectBase +from isaaclab_arena.assets.object_base import ObjectBase @dataclass(slots=True) class MeshPairCache: - """Precomputed per-pair collision data for the vectorized multi-mesh kernel.""" + """Precomputed per-pair collision data for the vectorized multi-mesh kernel. + + Dimensions: P = num_pairs (ordered subject/obstacle pairs), B = batch_size (num envs), + S = total_spheres (sum of sphere counts across all P pairs; each subject object is decomposed + into multiple covering spheres via greedy_sphere_decomposition). + """ all_centers_local: torch.Tensor """(S, 3) sphere centers in each subject's local frame, concatenated across pairs.""" @@ -28,34 +31,34 @@ class MeshPairCache: """(S,) sphere radii, concatenated across pairs.""" pair_subject_objs: list[ObjectBase] - """Per-pair subject (sphere source) object reference.""" + """(P,) subject (sphere source) object reference per pair.""" pair_obstacle_objs: list[ObjectBase] - """Per-pair obstacle (mesh target) object reference.""" + """(P,) obstacle (mesh target) object reference per pair.""" pair_is_anchor: list[bool] - """Per-pair flag: True if the obstacle is a static anchor.""" + """(P,) True if the obstacle is a static anchor.""" pair_anchor_pos: list[torch.Tensor | None] - """Per-pair world position for anchor obstacles (None for non-anchor obstacles).""" + """(P,) world position for anchor obstacles (None for non-anchors).""" pair_anchor_yaw: list[float] - """Per-pair anchor yaw in radians (0.0 for non-anchor obstacles).""" + """(P,) anchor yaw in radians (0.0 for non-anchors).""" pair_subject_bbox_min: torch.Tensor - """(P, B, 3) subject bbox min corners for broadphase.""" + """(P, B, 3) subject bbox min corners for overlap filtering.""" pair_subject_bbox_max: torch.Tensor - """(P, B, 3) subject bbox max corners for broadphase.""" + """(P, B, 3) subject bbox max corners for overlap filtering.""" pair_obstacle_bbox_min: torch.Tensor - """(P, B, 3) obstacle bbox min corners for broadphase.""" + """(P, B, 3) obstacle bbox min corners for overlap filtering.""" pair_obstacle_bbox_max: torch.Tensor - """(P, B, 3) obstacle bbox max corners for broadphase.""" + """(P, B, 3) obstacle bbox max corners for overlap filtering.""" pair_max_radius: torch.Tensor - """(P,) max sphere radius per pair (broadphase margin).""" + """(P,) max sphere radius per pair (overlap filter margin).""" sphere_pair_id: torch.Tensor """(S,) maps each sphere to its pair index for segment reduction.""" @@ -67,7 +70,7 @@ class MeshPairCache: """(P,) number of spheres per pair (for mean reduction).""" mesh_id_array: wp.array - """Warp uint64 array of mesh IDs for the multi-mesh kernel.""" + """(num_unique_meshes,) Warp uint64 array of mesh IDs for the multi-mesh kernel.""" num_pairs: int """Total number of active object pairs.""" diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 93e9dfba78..6aadde0caa 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -37,7 +37,7 @@ @dataclass class PlacementCandidate: - """A scored solver result used for ranking inside ObjectPlacer.""" + """A candidate object layout with its solver loss and validation outcome.""" loss: float """Loss value returned by the solver.""" @@ -226,7 +226,7 @@ def _place_ranked( self._generate_initial_orientations(objects, anchor_objects_set, generator) ) - # Bake each candidate's total yaw into a conservative enclosing bbox (AABB broadphase). + # Bake each candidate's total yaw into a conservative enclosing bbox for overlap checks. candidate_bboxes = self._rotate_candidate_bboxes(objects, candidate_bboxes, orientations_per_candidate) all_positions = self._solver.solve( diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index 7bda28caaa..01cbd90013 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -47,7 +47,7 @@ class NoOverlapPair: class MeshPairEntry(NamedTuple): - """One directed sphere-to-mesh pair collected during cache construction.""" + """One directed sphere-to-mesh collision pair (subject spheres vs obstacle mesh).""" subject: ObjectBase obstacle: ObjectBase diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py index 8ca2fb50f9..134eb876dc 100644 --- a/isaaclab_arena/relations/warp_mesh_manager.py +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -110,7 +110,7 @@ def greedy_sphere_decomposition( class WarpMeshAndSphereCache: - """Cache for Warp BVH meshes and sphere decompositions used in mesh-based collision queries.""" + """Cache for Warp BVH meshes and sphere decompositions.""" def __init__( self, From 5b83c19e9a7cde573a1ca739f6ff23dae2568a9a Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 30 Jun 2026 09:58:17 -0700 Subject: [PATCH 6/7] fix import order Signed-off-by: zhx06 --- isaaclab_arena/relations/mesh_pair_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/isaaclab_arena/relations/mesh_pair_cache.py b/isaaclab_arena/relations/mesh_pair_cache.py index b629ccb2bc..cdb3160e5c 100644 --- a/isaaclab_arena/relations/mesh_pair_cache.py +++ b/isaaclab_arena/relations/mesh_pair_cache.py @@ -9,10 +9,12 @@ import torch from dataclasses import dataclass +from typing import TYPE_CHECKING import warp as wp -from isaaclab_arena.assets.object_base import ObjectBase +if TYPE_CHECKING: + from isaaclab_arena.assets.object_base import ObjectBase @dataclass(slots=True) From 19aefa1ded1cf5c3277fcd6c0f41f5850d2c9cc0 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 30 Jun 2026 11:26:09 -0700 Subject: [PATCH 7/7] improve docstrings Signed-off-by: zhx06 --- isaaclab_arena/relations/relation_solver.py | 6 ++---- isaaclab_arena/relations/warp_sdf_kernels.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index 01cbd90013..b2f8acc1b5 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -521,7 +521,7 @@ def _compute_no_overlap_loss_mesh( device=device, ) - # AABB broadphase (yaw-aware): skip separated pairs. + # AABB overlap filter (yaw-aware): skip separated pairs. margins = cache.pair_max_radius + clearance_m s_bbox_min = cache.pair_subject_bbox_min[:, b, :] s_bbox_max = cache.pair_subject_bbox_max[:, b, :] @@ -664,8 +664,6 @@ def solve( on_pairs.add((id(rel.parent), id(obj))) self._mesh_orientations = orientations self._prepare_mesh_collision_cache(state, on_pairs) - - if self.params.collision_mode == CollisionMode.MESH: self._mesh_manager.reset_sentinel_warning() # Setup optimizer (only for optimizable positions) @@ -688,7 +686,7 @@ def solve( loss = self._compute_total_loss(state) loss_history.append(loss.item()) - # Constant-zero loss has no grad_fn — skip backward when broadphase culls all pairs. + # Constant-zero loss has no grad_fn — skip backward when overlap filter culls all pairs. if loss.grad_fn is not None: loss.backward() optimizer.step() diff --git a/isaaclab_arena/relations/warp_sdf_kernels.py b/isaaclab_arena/relations/warp_sdf_kernels.py index 1594b6e459..7169608696 100644 --- a/isaaclab_arena/relations/warp_sdf_kernels.py +++ b/isaaclab_arena/relations/warp_sdf_kernels.py @@ -119,7 +119,7 @@ def sdf_sentinel_count(sdf_values: torch.Tensor) -> int: def clamp_sdf_sentinel(sdf_values: torch.Tensor) -> torch.Tensor: - """Replace sentinel SDF values with 0 (treat as "on surface") so they produce gradient.""" + """Replace sentinel SDF values with 0 so no-face hits contribute zero loss rather than large positive.""" return torch.where(sdf_values >= _SDF_SENTINEL, torch.zeros_like(sdf_values), sdf_values)