From 655ac739ad18af82356a8a6743bcfcc9944eb9ab Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 9 Jun 2026 14:35:01 -0700 Subject: [PATCH 1/3] mesh support for non_collision --- isaaclab_arena/assets/dummy_object.py | 16 +- isaaclab_arena/assets/object_base.py | 12 + isaaclab_arena/relations/object_placer.py | 88 ++- .../relations/object_placer_params.py | 4 +- .../relations/relation_loss_strategies.py | 164 +++++- isaaclab_arena/relations/relation_solver.py | 15 +- .../relations/relation_solver_params.py | 17 + isaaclab_arena/relations/warp_mesh_manager.py | 164 ++++++ isaaclab_arena/relations/warp_sdf_kernels.py | 141 +++++ .../scripts/benchmark_collision_modes.py | 139 +++++ isaaclab_arena/tests/test_mesh_collision.py | 542 ++++++++++++++++++ isaaclab_arena/utils/usd_helpers.py | 71 +++ 12 files changed, 1340 insertions(+), 33 deletions(-) create mode 100644 isaaclab_arena/relations/warp_mesh_manager.py create mode 100644 isaaclab_arena/relations/warp_sdf_kernels.py create mode 100644 isaaclab_arena/scripts/benchmark_collision_modes.py create mode 100644 isaaclab_arena/tests/test_mesh_collision.py diff --git a/isaaclab_arena/assets/dummy_object.py b/isaaclab_arena/assets/dummy_object.py index aabe95bfe..8d66beeda 100644 --- a/isaaclab_arena/assets/dummy_object.py +++ b/isaaclab_arena/assets/dummy_object.py @@ -2,17 +2,21 @@ # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import torch +from typing import TYPE_CHECKING from isaaclab_arena.relations.relations import Relation, RelationBase, UnaryRelation from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, quaternion_to_90_deg_z_quarters from isaaclab_arena.utils.pose import Pose +if TYPE_CHECKING: + import trimesh + class DummyObject: - """ - Dummy object for testing purposes without Isaac Sim dependencies. - """ + """Dummy object for testing purposes without Isaac Sim dependencies.""" def __init__( self, @@ -20,6 +24,7 @@ def __init__( bounding_box: AxisAlignedBoundingBox, initial_pose: Pose | None = None, relations: list[RelationBase] = [], + collision_mesh: trimesh.Trimesh | None = None, **kwargs, ): self.name = name @@ -27,6 +32,7 @@ def __init__( self.bounding_box = bounding_box assert self.bounding_box is not None self.relations = list(relations) + self._collision_mesh = collision_mesh def add_relation(self, relation: RelationBase) -> None: self.relations.append(relation) @@ -63,3 +69,7 @@ def get_initial_pose(self) -> Pose | None: def is_initial_pose_set(self) -> bool: return self.initial_pose is not None + + def get_collision_mesh(self) -> trimesh.Trimesh | None: + """Return the collision mesh, or None to fall back to AABB.""" + return self._collision_mesh diff --git a/isaaclab_arena/assets/object_base.py b/isaaclab_arena/assets/object_base.py index 075d8f1c9..2b32695cb 100644 --- a/isaaclab_arena/assets/object_base.py +++ b/isaaclab_arena/assets/object_base.py @@ -7,9 +7,13 @@ import torch from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import warp as wp from isaaclab.assets import ArticulationCfg, AssetBaseCfg, RigidObjectCfg + +if TYPE_CHECKING: + import trimesh from isaaclab.envs import ManagerBasedEnv from isaaclab.managers import EventTermCfg, SceneEntityCfg from isaaclab.sensors.contact_sensor.contact_sensor_cfg import ContactSensorCfg @@ -73,6 +77,14 @@ def get_world_bounding_box(self) -> AxisAlignedBoundingBox: """Get bounding box in world coordinates (local bbox rotated and translated).""" ... + def get_collision_mesh(self) -> trimesh.Trimesh | None: + """Return the collision mesh for this object, or None. + + When None, the mesh-based collision system falls back to AABB overlap + for any pair involving this object. Subclasses with mesh geometry + should override this method. + """ + def _get_initial_pose_as_pose(self) -> Pose | None: """Return a single ``Pose`` suitable for *init_state* and bounding-box calculations. diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index b9e6c9b2a..101a4760f 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -13,6 +13,7 @@ from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult from isaaclab_arena.relations.relation_solver import RelationSolver +from isaaclab_arena.relations.relation_solver_params import CollisionMode from isaaclab_arena.relations.relations import ( IsAnchor, On, @@ -161,6 +162,11 @@ def _prepare_placement( "Call anchor_object.set_initial_pose(...) before placing." ) + assert not (self.params.random_yaw_init and self.params.solver_params.collision_mode == CollisionMode.MESH), ( + "random_yaw_init is not yet supported with CollisionMode.MESH -- " + "sphere centers are not rotated by candidate yaw." + ) + generator: torch.Generator | None = None if self.params.placement_seed is not None: generator = torch.Generator() @@ -598,6 +604,81 @@ def _validate_no_overlap( return False return True + def _validate_no_overlap_mesh( + self, + positions: dict[ObjectBase, tuple[float, float, float]], + ) -> bool: + """Validate no-overlap using sphere-to-SDF mesh queries. + + Mirrors the AABB validator's pair-skipping logic (On pairs, anchor-anchor). + Skips pairs where either object lacks a collision mesh (the solver's loss + path still penalizes those via AABB fallback during optimization). + """ + from isaaclab_arena.relations.warp_mesh_manager import WarpMeshManager + from isaaclab_arena.relations.warp_sdf_kernels import mesh_sdf + + on_pairs: set[tuple] = set() + anchor_ids: set[int] = set() + for obj in positions: + for rel in obj.get_relations(): + if isinstance(rel, On) and rel.parent in positions: + on_pairs.add((id(obj), id(rel.parent))) + on_pairs.add((id(rel.parent), id(obj))) + if any(isinstance(r, IsAnchor) for r in obj.get_relations()): + anchor_ids.add(id(obj)) + + clearance_m = self.params.solver_params.clearance_m + tolerance = max(0.0, clearance_m - 1e-6) + manager = WarpMeshManager( + num_spheres=self.params.solver_params.num_spheres, + device="cpu", + ) + + warned_no_mesh: set[str] = set() + objects = list(positions.keys()) + for i in range(len(objects)): + for j in range(i + 1, len(objects)): + a, b = objects[i], objects[j] + if id(a) in anchor_ids and id(b) in anchor_ids: + continue + if (id(a), id(b)) in on_pairs: + continue + + a_mesh = a.get_collision_mesh() + b_mesh = b.get_collision_mesh() + if a_mesh is None or b_mesh is None: + for obj, mesh in [(a, a_mesh), (b, b_mesh)]: + if mesh is None and obj.name not in warned_no_mesh: + warned_no_mesh.add(obj.name) + print( + f" [NoCollision] MESH mode: '{obj.name}' has no collision mesh, skipping mesh" + " validation" + ) + continue + + a_pos = torch.tensor(positions[a], dtype=torch.float32) + b_pos = torch.tensor(positions[b], dtype=torch.float32) + + # Forward: a's spheres against b's mesh + spheres_a = manager.get_query_spheres(a_mesh, obj=a) + warp_b = manager.get_warp_mesh(b_mesh, obj=b) + centers_a_in_b = spheres_a[:, :3] + a_pos - b_pos + if (mesh_sdf(centers_a_in_b, warp_b) < spheres_a[:, 3] + tolerance).any(): + if self.params.verbose: + print(f" Mesh overlap between '{a.name}' and '{b.name}'") + return False + + # Reverse: b's spheres against a's mesh + spheres_b = manager.get_query_spheres(b_mesh, obj=b) + warp_a = manager.get_warp_mesh(a_mesh, obj=a) + centers_b_in_a = spheres_b[:, :3] + b_pos - a_pos + if (mesh_sdf(centers_b_in_a, warp_a) < spheres_b[:, 3] + tolerance).any(): + if self.params.verbose: + print(f" Mesh overlap between '{b.name}' and '{a.name}'") + return False + + return True + def _validate_placement( self, positions: dict[ObjectBase, tuple[float, float, float]], @@ -612,7 +693,12 @@ def _validate_placement( Returns: True if no overlaps exist and On relations hold, False otherwise. """ - return self._validate_no_overlap(positions, env_bboxes) and self._validate_on_relations(positions, env_bboxes) + if not self._validate_no_overlap(positions, env_bboxes): + return False + if self.params.solver_params.collision_mode == CollisionMode.MESH: + if not self._validate_no_overlap_mesh(positions): + return False + return self._validate_on_relations(positions, env_bboxes) def _apply_poses( self, diff --git a/isaaclab_arena/relations/object_placer_params.py b/isaaclab_arena/relations/object_placer_params.py index 353e4844c..818596850 100644 --- a/isaaclab_arena/relations/object_placer_params.py +++ b/isaaclab_arena/relations/object_placer_params.py @@ -5,7 +5,9 @@ from dataclasses import dataclass, field -from isaaclab_arena.relations.relation_solver_params import RelationSolverParams +from isaaclab_arena.relations.relation_solver_params import CollisionMode, RelationSolverParams + +__all__ = ["CollisionMode", "ObjectPlacerParams"] @dataclass diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index db3e35a9d..fcdb714f8 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import torch from abc import ABC, abstractmethod from dataclasses import dataclass @@ -71,7 +73,7 @@ class UnaryRelationLossStrategy(ABC): @abstractmethod def compute_loss( self, - relation: "Relation", + relation: Relation, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: @@ -95,7 +97,7 @@ class RelationLossStrategy(ABC): @abstractmethod def compute_loss( self, - relation: "Relation", + relation: Relation, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -136,7 +138,7 @@ def __init__(self, slope: float = 10.0, debug: bool = False): def compute_loss( self, - relation: "NextTo", + relation: NextTo, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -247,7 +249,7 @@ def __init__(self, slope: float = 10.0, debug: bool = False): def compute_loss( self, - relation: "On", + relation: On, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -351,7 +353,7 @@ def __init__(self, slope: float = 10.0, margin_m: float = 0.1, debug: bool = Fal def compute_loss( self, - relation: "NotNextTo", + relation: NotNextTo, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, @@ -412,25 +414,46 @@ def compute_loss( class NoCollisionLossStrategy: """Loss strategy for no-overlap constraints between objects. - Computes loss based on: - 1. X overlap: zero when child and parent are separated along X; else overlap length - 2. Y overlap: zero when separated along Y; else overlap length - 3. Z overlap: zero when separated along Z; else overlap length - 4. Volume loss: slope * (overlap_x * overlap_y * overlap_z) + Supports two modes controlled by ``collision_mode``: + - BBOX (default): AABB volume-overlap loss (X/Y/Z axis overlap product). + - MESH: Sphere-to-SDF penetration loss using actual mesh geometry. - This is a standalone strategy (not a RelationLossStrategy) because no-overlap - is a built-in solver behavior, not a user-specified relation. + In MESH mode, falls back to AABB per-pair when either object lacks a + collision mesh. This is a standalone strategy (not a RelationLossStrategy) + because no-overlap is a built-in solver behavior, not a user-specified relation. """ - def __init__(self, slope: float = 10.0, debug: bool = False): + def __init__( + self, + slope: float = 10.0, + debug: bool = False, + collision_mode=None, + num_spheres: int = 30, + ): """ Args: - slope: Gradient magnitude for overlap volume loss (default: 10.0). - Loss scales with slope times overlap volume. - debug: If True, print detailed loss component breakdown. + slope: Gradient magnitude for overlap loss. + debug: If True, print detailed AABB loss component breakdown. + collision_mode: CollisionMode enum value (BBOX or MESH). Defaults to BBOX. + num_spheres: Number of spheres for mesh decomposition (MESH mode only). """ + from isaaclab_arena.relations.relation_solver_params import CollisionMode + self.slope = slope self.debug = debug + self._mode = collision_mode if collision_mode is not None else CollisionMode.BBOX + self._num_spheres = num_spheres + self._CollisionMode = CollisionMode + self._warned_no_mesh: set[str] = set() + self._mesh_managers: dict[str, object] = {} # keyed by device string + + if self._mode == CollisionMode.MESH: + try: + import warp # noqa: F401 + except ImportError as e: + raise ImportError( + "CollisionMode.MESH requires the 'warp' package. Install it with: pip install warp-lang" + ) from e def compute_loss( self, @@ -438,23 +461,51 @@ def compute_loss( child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, parent_world_bbox: AxisAlignedBoundingBox, + child_obj=None, + parent_obj=None, + parent_pos: torch.Tensor | None = None, ) -> torch.Tensor: - """Compute loss for no-overlap constraint. + """Compute collision loss, dispatching based on mode and mesh availability. Args: - clearance_m: Minimum clearance between bounding boxes in meters. - child_pos: Child object position (N, 3) in world coords. - child_bbox: Child object local bounding box (N=1). - parent_world_bbox: Parent bounding box in world coordinates. + clearance_m: Minimum clearance in meters. + child_pos: Child position (N, 3) in world coords -- gradient target. + child_bbox: Child AABB (used in BBOX mode). + parent_world_bbox: Parent world AABB (used in BBOX mode). + child_obj: Object with get_collision_mesh() (MESH mode). + parent_obj: Object with get_collision_mesh() (MESH mode). + parent_pos: Parent position tensor, or None for anchors (MESH mode). Returns: Loss tensor of shape (N,). """ + if self._mode == self._CollisionMode.MESH and child_obj is not None and parent_obj is not None: + child_mesh = child_obj.get_collision_mesh() + parent_mesh = parent_obj.get_collision_mesh() + if child_mesh is not None and parent_mesh is not None: + return self._compute_mesh_loss( + clearance_m, child_pos, child_obj, child_mesh, parent_pos, parent_obj, parent_mesh + ) + for obj, mesh in [(child_obj, child_mesh), (parent_obj, parent_mesh)]: + name = getattr(obj, "name", "?") + if mesh is None and name not in self._warned_no_mesh: + self._warned_no_mesh.add(name) + print(f" [NoCollision] MESH mode: '{name}' has no collision mesh, falling back to AABB") + + return self._compute_aabb_loss(clearance_m, child_pos, child_bbox, parent_world_bbox) + + def _compute_aabb_loss( + self, + clearance_m: float, + child_pos: torch.Tensor, + child_bbox: AxisAlignedBoundingBox, + parent_world_bbox: AxisAlignedBoundingBox, + ) -> torch.Tensor: + """AABB volume-overlap loss.""" single_input = child_pos.dim() == 1 if single_input: child_pos = child_pos.unsqueeze(0) - # Parent world extents from the world bounding box, expanded by clearance_m c = clearance_m parent_x_min = parent_world_bbox.min_point[:, 0] - c parent_x_max = parent_world_bbox.max_point[:, 0] + c @@ -463,16 +514,13 @@ def compute_loss( parent_z_min = parent_world_bbox.min_point[:, 2] - c parent_z_max = parent_world_bbox.max_point[:, 2] + c - # Child world extents child_world_min = child_pos + child_bbox.min_point child_world_max = child_pos + child_bbox.max_point - # 1. Per-axis overlap: zero when separated; else overlap length (default slope 1.0 gives length in m) overlap_x = interval_overlap_axis_loss(child_world_min[:, 0], child_world_max[:, 0], parent_x_min, parent_x_max) overlap_y = interval_overlap_axis_loss(child_world_min[:, 1], child_world_max[:, 1], parent_y_min, parent_y_max) overlap_z = interval_overlap_axis_loss(child_world_min[:, 2], child_world_max[:, 2], parent_z_min, parent_z_max) - # 2. Volume loss: slope * product of per-axis overlap lengths (overlap volume when slope 1.0) overlap_volume = overlap_x * overlap_y * overlap_z total_loss = self.slope * overlap_volume @@ -496,6 +544,68 @@ def compute_loss( return total_loss.squeeze(0) if single_input else total_loss + def _get_mesh_manager(self, device: str = "cuda:0"): + if device not in self._mesh_managers: + from isaaclab_arena.relations.warp_mesh_manager import WarpMeshManager + + self._mesh_managers[device] = WarpMeshManager(num_spheres=self._num_spheres, device=device) + return self._mesh_managers[device] + + def _compute_mesh_loss( + self, + clearance_m: float, + child_pos: torch.Tensor, + child_obj, + child_mesh, + parent_pos: torch.Tensor | None, + parent_obj, + parent_mesh, + ) -> torch.Tensor: + """Sphere-to-SDF penetration loss using mesh geometry.""" + from isaaclab_arena.relations.warp_sdf_kernels import sphere_penetration_loss + + single_input = child_pos.dim() == 1 + if single_input: + child_pos = child_pos.unsqueeze(0) + + if parent_pos is None: + from isaaclab_arena.utils.pose import Pose + + pose = parent_obj.get_initial_pose() + assert pose is not None, f"parent_pos=None but '{getattr(parent_obj, 'name', '?')}' has no initial pose" + assert isinstance( + pose, Pose + ), f"Anchor '{getattr(parent_obj, 'name', '?')}' must have a fixed Pose for mesh collision" + identity = (0.0, 0.0, 0.0, 1.0) + assert pose.rotation_xyzw == identity, ( + f"Mesh collision with rotated anchor '{getattr(parent_obj, 'name', '?')}' " + f"is not yet supported (rotation={pose.rotation_xyzw})" + ) + parent_pos = torch.tensor(pose.position_xyz, dtype=child_pos.dtype, device=child_pos.device) + + assert parent_pos is not None + parent_pos_resolved: torch.Tensor = parent_pos + if parent_pos_resolved.dim() == 1: + parent_pos_resolved = parent_pos_resolved.unsqueeze(0) + + device = child_pos.device + manager = self._get_mesh_manager(str(device)) + spheres = manager.get_query_spheres(child_mesh, obj=child_obj).to(device) + centers_local = spheres[:, :3] + radii = spheres[:, 3] + warp_mesh = manager.get_warp_mesh(parent_mesh, obj=parent_obj) + + batch_size = child_pos.shape[0] + total_loss = torch.zeros(batch_size, device=device, dtype=child_pos.dtype) + + for b in range(batch_size): + centers_world = centers_local + child_pos[b] - parent_pos_resolved[b] + total_loss[b] = self.slope * sphere_penetration_loss( + centers_world, radii, warp_mesh, clearance_m=clearance_m + ) + + return total_loss.squeeze(0) if single_input else total_loss + class AtPositionLossStrategy(UnaryRelationLossStrategy): """Loss strategy for AtPosition relations. @@ -514,7 +624,7 @@ def __init__(self, slope: float = 10.0): def compute_loss( self, - relation: "AtPosition", + relation: AtPosition, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: @@ -570,7 +680,7 @@ def __init__(self, slope: float = 100.0): def compute_loss( self, - relation: "PositionLimits", + relation: PositionLimits, child_pos: torch.Tensor, child_bbox: AxisAlignedBoundingBox, ) -> torch.Tensor: diff --git a/isaaclab_arena/relations/relation_solver.py b/isaaclab_arena/relations/relation_solver.py index d11d456c6..37b4273ea 100644 --- a/isaaclab_arena/relations/relation_solver.py +++ b/isaaclab_arena/relations/relation_solver.py @@ -42,7 +42,11 @@ def __init__( """ self.params = params or RelationSolverParams() # High slope (vs 10-100 for relation strategies) so overlap avoidance dominates. - self._no_collision_strategy = NoCollisionLossStrategy(slope=10000.0) + self._no_collision_strategy = NoCollisionLossStrategy( + collision_mode=self.params.collision_mode, + slope=10000.0, + num_spheres=self.params.num_spheres, + ) self._last_loss_history: list[float] = [] self._last_position_history: list = [] self._last_loss_per_env: torch.Tensor | None = None @@ -183,6 +187,9 @@ def _compute_no_overlap_loss( child_pos=child_pos, child_bbox=child_bbox, parent_world_bbox=anchor_world_bbox, + child_obj=child, + parent_obj=anchor, + parent_pos=None, ) if debug: print(f" [NoOverlap] {child.name} vs {anchor.name}: loss={loss.mean().item():.6f}") @@ -203,6 +210,9 @@ def _compute_no_overlap_loss( child_pos=child_pos, child_bbox=child_bbox, parent_world_bbox=other_world_bbox, + child_obj=child, + parent_obj=other, + parent_pos=other_pos.detach(), ) # Reverse: gradient flows to other (object j) @@ -212,6 +222,9 @@ def _compute_no_overlap_loss( child_pos=other_pos, child_bbox=other_bbox, parent_world_bbox=child_world_bbox, + child_obj=other, + parent_obj=child, + parent_pos=child_pos.detach(), ) if debug: diff --git a/isaaclab_arena/relations/relation_solver_params.py b/isaaclab_arena/relations/relation_solver_params.py index 62f535bcb..e09ff411f 100644 --- a/isaaclab_arena/relations/relation_solver_params.py +++ b/isaaclab_arena/relations/relation_solver_params.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field +from enum import Enum from isaaclab_arena.relations.relation_loss_strategies import ( AtPositionLossStrategy, @@ -17,6 +18,16 @@ from isaaclab_arena.relations.relations import AtPosition, NextTo, NotNextTo, On, PositionLimits, RelationBase +class CollisionMode(Enum): + """Selects which collision detection method the solver uses for no-overlap constraints.""" + + BBOX = "bbox" + """Axis-aligned bounding box overlap volume (fast, conservative).""" + + MESH = "mesh" + """Sphere-to-SDF queries against actual mesh geometry (accurate, slower).""" + + def _default_strategies() -> dict[type[RelationBase], RelationLossStrategy | UnaryRelationLossStrategy]: """Factory for default loss strategies.""" return { @@ -47,6 +58,12 @@ class RelationSolverParams: save_position_history: bool = True """Save position snapshots during optimization for visualization/debugging. Disable to reduce memory.""" + collision_mode: CollisionMode = CollisionMode.BBOX + """Which collision detection method to use for no-overlap constraints.""" + + num_spheres: int = 30 + """Number of bounding spheres per object for MESH mode. Higher = more accurate but slower.""" + clearance_m: float = 0.01 """Minimum clearance (meters) enforced between every pair of non-anchor objects. The solver adds a no-overlap loss for all pairs automatically. Set to 0.0 to only diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py new file mode 100644 index 000000000..ae52c2765 --- /dev/null +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Warp mesh management and greedy sphere decomposition for mesh-based collision.""" + +from __future__ import annotations + +import numpy as np +import torch +from collections import defaultdict +from heapq import heappop, heappush +from typing import TYPE_CHECKING + +import warp as wp + +if TYPE_CHECKING: + import trimesh + + +def _mesh_content_hash(mesh: trimesh.Trimesh) -> int: + """Content-based hash for a trimesh. Safe across GC cycles unlike id().""" + return hash((mesh.vertices.tobytes(), mesh.faces.tobytes())) + + +def greedy_sphere_decomposition( + mesh: trimesh.Trimesh, + num_spheres: int = 30, + sphere_radius: float = 0.01, + n_candidates: int = 200, + n_surface: int = 1000, + seed: int = 42, +) -> np.ndarray: + """Decompose a mesh into bounding spheres via greedy set-cover. + + Based on the greedy_sample_mesh algorithm by Caelan Garrett (NVIDIA). + Uses trimesh.proximity.max_tangent_sphere for candidate generation, + then greedy selection maximising surface coverage. + + Args: + mesh: Input trimesh (must be watertight or convex-hull-repairable). + num_spheres: Maximum number of output spheres. + sphere_radius: Inflation added to tangent sphere radii (safety margin). + n_candidates: Number of candidate sphere centers sampled. + n_surface: Number of surface points for coverage tracking. + seed: RNG seed for reproducible surface sampling. + + Returns: + (K, 4) array of [cx, cy, cz, radius] in mesh-local frame. K <= num_spheres. + """ + import trimesh as _trimesh + + n_candidates = max(num_spheres, n_candidates) + n_surface = max(n_candidates, n_surface) + + rng = np.random.default_rng(seed) + points = _trimesh.sample.sample_surface(mesh, n_surface, seed=rng)[0] + cloud = _trimesh.PointCloud(points) + + # Compute tangent spheres at candidate surface points + work_mesh = mesh if mesh.is_watertight else mesh.convex_hull + candidates = points[:n_candidates] + try: + centers, radii = _trimesh.proximity.max_tangent_sphere(work_mesh, candidates) + except (IndexError, ValueError): + # Fallback: uniform spheres at surface samples + centers = candidates[:num_spheres] + radii = np.full(len(centers), sphere_radius) + return np.column_stack([centers, radii]) + + radii = radii + sphere_radius + + # Filter degenerate spheres + max_radius = np.linalg.norm(mesh.extents) / 2 + valid = (radii <= max_radius) & np.isfinite(radii) + centers, radii = centers[valid], radii[valid] + + if len(centers) == 0: + # Last resort: uniform spheres + pts = points[:num_spheres] + return np.column_stack([pts, np.full(len(pts), sphere_radius)]) + + # Build coverage graph: which surface points does each sphere cover? + outgoing: dict[int, set[int]] = defaultdict(set) + incoming: dict[int, set[int]] = defaultdict(set) + for idx, (center, radius) in enumerate(zip(centers, radii)): + covered = cloud.kdtree.query_ball_point(center, r=radius, eps=1e-6) + for pt_idx in covered: + outgoing[idx].add(pt_idx) + incoming[pt_idx].add(idx) + + # Greedy set-cover selection + selected: list[int] = [] + queue: list[tuple[int, int]] = [] + for idx in outgoing: + heappush(queue, (-len(outgoing[idx]), idx)) + + while queue and len(selected) < num_spheres: + neg_count, idx = heappop(queue) + if len(outgoing[idx]) != -neg_count: + heappush(queue, (-len(outgoing[idx]), idx)) + continue + if neg_count == 0: + break + # Remove covered points from other spheres + for pt_idx in list(outgoing[idx]): + for other_idx in incoming[pt_idx]: + outgoing[other_idx].discard(pt_idx) + selected.append(idx) + + if not selected: + pts = points[:num_spheres] + return np.column_stack([pts, np.full(len(pts), sphere_radius)]) + + return np.column_stack([centers[selected], radii[selected]]) + + +class WarpMeshManager: + """Manages Warp mesh creation, caching, and sphere decomposition for collision objects. + + Caches results by content hash (in-memory trimesh) or (usd_path, scale) for USD objects. + """ + + def __init__( + self, + num_spheres: int = 30, + sphere_radius: float = 0.01, + device: str = "cuda:0", + ): + self._num_spheres = num_spheres + self._sphere_radius = sphere_radius + self._device = device + self._warp_mesh_cache: dict[tuple, wp.Mesh] = {} + self._sphere_cache: dict[tuple, torch.Tensor] = {} + + def _cache_key(self, mesh: trimesh.Trimesh, obj=None) -> tuple: + """Compute cache key. Uses (usd_path, scale) for USD objects, content hash otherwise.""" + usd_path = getattr(obj, "usd_path", None) if obj is not None else None + if usd_path is not None: + scale = getattr(obj, "scale", (1.0, 1.0, 1.0)) + return (usd_path, scale, self._num_spheres, self._sphere_radius) + return (_mesh_content_hash(mesh), self._num_spheres, self._sphere_radius) + + def get_warp_mesh(self, mesh: trimesh.Trimesh, obj=None) -> wp.Mesh: + """Get or create a Warp BVH mesh for SDF queries.""" + key = self._cache_key(mesh, obj) + if key not in self._warp_mesh_cache: + vertices = wp.array(np.asarray(mesh.vertices, dtype=np.float32), dtype=wp.vec3, device=self._device) + indices = wp.array(np.asarray(mesh.faces, dtype=np.int32).flatten(), dtype=wp.int32, device=self._device) + self._warp_mesh_cache[key] = wp.Mesh(points=vertices, indices=indices) + return self._warp_mesh_cache[key] + + def get_query_spheres(self, mesh: trimesh.Trimesh, obj=None) -> torch.Tensor: + """Get or compute sphere decomposition as (K, 4) tensor [cx, cy, cz, radius].""" + key = self._cache_key(mesh, obj) + if key not in self._sphere_cache: + spheres_np = greedy_sphere_decomposition( + mesh, + num_spheres=self._num_spheres, + sphere_radius=self._sphere_radius, + ) + self._sphere_cache[key] = torch.from_numpy(spheres_np).float() + return self._sphere_cache[key] diff --git a/isaaclab_arena/relations/warp_sdf_kernels.py b/isaaclab_arena/relations/warp_sdf_kernels.py new file mode 100644 index 000000000..4bf3fc053 --- /dev/null +++ b/isaaclab_arena/relations/warp_sdf_kernels.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Warp SDF kernels and PyTorch autograd bridge for mesh-based collision loss.""" + +from __future__ import annotations + +import torch + +import warp as wp + + +@wp.kernel +def _sdf_query_kernel( + mesh_id: wp.uint64, + query_points: wp.array(dtype=wp.vec3), + sdf_out: wp.array(dtype=wp.float32), + grad_out: wp.array(dtype=wp.vec3), +): + """Query signed distance and gradient for each point against a Warp mesh.""" + tid = wp.tid() + p = query_points[tid] + + face_index = int(0) + face_u = float(0.0) + face_v = float(0.0) + sign = float(0.0) + + found = wp.mesh_query_point_sign_normal(mesh_id, p, 1.0e6, sign, face_index, face_u, face_v) + + if found: + closest = wp.mesh_eval_position(mesh_id, face_index, face_u, face_v) + diff = p - closest + dist = wp.length(diff) + sdf_out[tid] = sign * dist + if dist > 1.0e-8: + grad_out[tid] = sign * wp.normalize(diff) + else: + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + else: + sdf_out[tid] = float(1.0e6) + grad_out[tid] = wp.vec3(0.0, 0.0, 0.0) + + +class _MeshSDFFunction(torch.autograd.Function): + """Autograd bridge: forward computes SDF values, backward propagates via analytical gradients.""" + + @staticmethod + def forward(ctx, points: torch.Tensor, mesh: wp.Mesh) -> torch.Tensor: + """Compute SDF values for query points against a Warp mesh. + + Args: + points: (N, 3) float32 tensor of query positions (must be on same device as mesh). + mesh: Warp Mesh with BVH built. + + Returns: + (N,) tensor of signed distance values (negative = inside). + """ + device = points.device + n = points.shape[0] + wp_device = str(device) + + wp_points = wp.from_torch(points.contiguous(), dtype=wp.vec3) + sdf_wp = wp.zeros(n, dtype=wp.float32, device=wp_device) + grad_wp = wp.zeros(n, dtype=wp.vec3, device=wp_device) + + wp.launch( + kernel=_sdf_query_kernel, + dim=n, + inputs=[mesh.id, wp_points, sdf_wp, grad_wp], + device=wp_device, + ) + + sdf_torch = wp.to_torch(sdf_wp) + grad_torch = wp.to_torch(grad_wp).reshape(n, 3) + + ctx.save_for_backward(grad_torch) + return sdf_torch + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """Backprop through SDF: dL/dpoints = dL/dsdf * dsdf/dpoints.""" + (grad_sdf,) = ctx.saved_tensors + # grad_output: (N,), grad_sdf: (N, 3) -- analytical SDF gradient + grad_points = grad_output.unsqueeze(-1) * grad_sdf + return grad_points, None + + +def mesh_sdf(points: torch.Tensor, warp_mesh: wp.Mesh) -> torch.Tensor: + """Differentiable signed distance query. + + Args: + points: (N, 3) query points on the same device as the mesh. + warp_mesh: Warp Mesh object with BVH. + + Returns: + (N,) signed distance values. Negative = penetrating. + """ + return _MeshSDFFunction.apply(points, warp_mesh) + + +def sphere_penetration_loss( + sphere_centers_world: torch.Tensor, + sphere_radii: torch.Tensor, + warp_mesh: wp.Mesh, + clearance_m: float = 0.0, +) -> torch.Tensor: + """Compute ReLU penetration loss for spheres against a mesh SDF. + + Loss per sphere = ReLU(effective_radius - sdf). + Total loss = mean over all spheres. + + Args: + sphere_centers_world: (K, 3) world-space sphere centers. + sphere_radii: (K,) sphere radii. + warp_mesh: Target Warp mesh to check against. + clearance_m: Additional clearance added to radii. + + Returns: + Scalar loss tensor (differentiable w.r.t. sphere_centers_world). + """ + sdf_values = mesh_sdf(sphere_centers_world, warp_mesh) + + # SDF returns 1e6 for points where no mesh face was found (degenerate query). + # These read as "collision-free" (relu(r - 1e6) = 0) which is silently wrong. + _SDF_SENTINEL = 1.0e5 + if not hasattr(sphere_penetration_loss, "_warned_sentinel"): + sphere_penetration_loss._warned_sentinel = False # type: ignore[attr-defined] + if not sphere_penetration_loss._warned_sentinel and (sdf_values >= _SDF_SENTINEL).any(): # type: ignore[attr-defined] + sphere_penetration_loss._warned_sentinel = True # type: ignore[attr-defined] + n_bad = int((sdf_values >= _SDF_SENTINEL).sum().item()) + print( + f" [MeshSDF] WARNING: {n_bad}/{len(sdf_values)} sphere queries returned sentinel SDF " + "(no mesh face found). Collision detection may be incomplete for these points." + ) + + effective_radii = sphere_radii + clearance_m + penetration = torch.relu(effective_radii - sdf_values) + return penetration.mean() diff --git a/isaaclab_arena/scripts/benchmark_collision_modes.py b/isaaclab_arena/scripts/benchmark_collision_modes.py new file mode 100644 index 000000000..c20be8512 --- /dev/null +++ b/isaaclab_arena/scripts/benchmark_collision_modes.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark script comparing BBOX vs MESH collision modes. + +Measures wall-clock time, iterations to convergence, and placement success +for a configurable number of objects and sphere budgets. + +Usage: + python isaaclab_arena/scripts/benchmark_collision_modes.py +""" + +from __future__ import annotations + +import numpy as np +import time +import trimesh + +from isaaclab_arena.assets.dummy_object import DummyObject +from isaaclab_arena.relations.object_placer import ObjectPlacer +from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams +from isaaclab_arena.relations.relation_solver_params import CollisionMode, RelationSolverParams +from isaaclab_arena.relations.relations import IsAnchor, On +from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +from isaaclab_arena.utils.pose import Pose + + +def _make_cylinder(name: str, radius: float = 0.033, height: float = 0.1) -> DummyObject: + mesh = trimesh.creation.cylinder(radius=radius, height=height, sections=32) + return DummyObject( + name=name, + bounding_box=AxisAlignedBoundingBox( + min_point=(-radius, -radius, -height / 2), + max_point=(radius, radius, height / 2), + ), + collision_mesh=mesh, + ) + + +def _make_table() -> DummyObject: + mesh = trimesh.creation.box(extents=(0.6, 0.6, 0.05)) + table = DummyObject( + name="table", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.3, -0.3, -0.025), max_point=(0.3, 0.3, 0.025)), + collision_mesh=mesh, + ) + table.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + table.add_relation(IsAnchor()) + return table + + +def _build_scene(num_objects: int) -> list[DummyObject]: + table = _make_table() + objects = [table] + for i in range(num_objects): + obj = _make_cylinder(f"cyl_{i}", radius=0.025 + 0.005 * (i % 3)) + obj.add_relation(On(table)) + objects.append(obj) + return objects + + +def _run_benchmark( + mode: CollisionMode, + num_objects: int, + num_spheres: int, + max_iters: int = 400, + attempts: int = 5, +) -> dict: + objects = _build_scene(num_objects) + params = ObjectPlacerParams( + solver_params=RelationSolverParams( + collision_mode=mode, + num_spheres=num_spheres, + max_iters=max_iters, + verbose=False, + ), + max_placement_attempts=attempts, + verbose=False, + ) + placer = ObjectPlacer(params=params) + + t0 = time.perf_counter() + result = placer.place(objects) + elapsed = time.perf_counter() - t0 + + return { + "mode": mode.value, + "num_objects": num_objects, + "num_spheres": num_spheres, + "time_s": elapsed, + "valid": result.success, + "loss": result.final_loss, + } + + +def main(): + print("=" * 70) + print("Collision Mode Benchmark: BBOX vs MESH") + print("=" * 70) + print() + + configs = [ + (5, [10, 30, 50]), + (8, [10, 30, 50]), + ] + + results = [] + + for num_objects, sphere_counts in configs: + # BBOX baseline + r = _run_benchmark(CollisionMode.BBOX, num_objects, num_spheres=30) + results.append(r) + print(f"BBOX | {num_objects} objects | time={r['time_s']:.3f}s | valid={r['valid']} | loss={r['loss']:.6f}") + + # MESH variants + for ns in sphere_counts: + r = _run_benchmark(CollisionMode.MESH, num_objects, num_spheres=ns) + results.append(r) + print( + f"MESH | {num_objects} objects | spheres={ns:3d} | " + f"time={r['time_s']:.3f}s | valid={r['valid']} | loss={r['loss']:.6f}" + ) + print() + + print("-" * 70) + print("Summary:") + bbox_times = [r["time_s"] for r in results if r["mode"] == "bbox"] + mesh_times = [r["time_s"] for r in results if r["mode"] == "mesh"] + if bbox_times and mesh_times: + print(f" BBOX avg: {np.mean(bbox_times):.3f}s") + print(f" MESH avg: {np.mean(mesh_times):.3f}s") + print(f" Slowdown: {np.mean(mesh_times) / np.mean(bbox_times):.1f}x") + + +if __name__ == "__main__": + main() diff --git a/isaaclab_arena/tests/test_mesh_collision.py b/isaaclab_arena/tests/test_mesh_collision.py new file mode 100644 index 000000000..47435d9d3 --- /dev/null +++ b/isaaclab_arena/tests/test_mesh_collision.py @@ -0,0 +1,542 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for mesh-based collision detection: sphere decomposition, dispatch, and end-to-end solver.""" + +from __future__ import annotations + +import numpy as np +import torch +import trimesh + +import pytest + +from isaaclab_arena.assets.dummy_object import DummyObject +from isaaclab_arena.relations.relation_loss_strategies import NoCollisionLossStrategy +from isaaclab_arena.relations.relation_solver import RelationSolver +from isaaclab_arena.relations.relation_solver_params import CollisionMode, RelationSolverParams +from isaaclab_arena.relations.relations import IsAnchor, On +from isaaclab_arena.relations.warp_mesh_manager import WarpMeshManager, greedy_sphere_decomposition +from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +from isaaclab_arena.utils.pose import Pose + +try: + import warp as wp + + wp.init() + _WARP_AVAILABLE = True +except Exception: + _WARP_AVAILABLE = False + +requires_warp = pytest.mark.skipif(not _WARP_AVAILABLE, reason="Warp not available") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_cylinder(name: str, radius: float = 0.033, height: float = 0.1) -> DummyObject: + mesh = trimesh.creation.cylinder(radius=radius, height=height, sections=32) + return DummyObject( + name=name, + bounding_box=AxisAlignedBoundingBox( + min_point=(-radius, -radius, -height / 2), + max_point=(radius, radius, height / 2), + ), + collision_mesh=mesh, + ) + + +def _make_box_obj(name: str, sx: float, sy: float, sz: float) -> DummyObject: + mesh = trimesh.creation.box(extents=(sx, sy, sz)) + return DummyObject( + name=name, + bounding_box=AxisAlignedBoundingBox( + min_point=(-sx / 2, -sy / 2, -sz / 2), + max_point=(sx / 2, sy / 2, sz / 2), + ), + collision_mesh=mesh, + ) + + +def _make_table() -> DummyObject: + mesh = trimesh.creation.box(extents=(1.0, 1.0, 0.05)) + table = DummyObject( + name="table", + bounding_box=AxisAlignedBoundingBox(min_point=(-0.5, -0.5, -0.025), max_point=(0.5, 0.5, 0.025)), + collision_mesh=mesh, + ) + table.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + table.add_relation(IsAnchor()) + return table + + +# --------------------------------------------------------------------------- +# Unit: greedy_sphere_decomposition +# --------------------------------------------------------------------------- + + +def test_sphere_decomposition_covers_surface(): + """Sphere decomposition should cover >80% of surface sample points.""" + mesh = trimesh.creation.cylinder(radius=0.05, height=0.1, sections=32) + spheres = greedy_sphere_decomposition(mesh, num_spheres=20, n_surface=500) + assert spheres.shape[1] == 4 + assert spheres.shape[0] <= 20 + + # Check coverage: what fraction of surface points are within radius of some sphere? + surface_pts = trimesh.sample.sample_surface(mesh, 200)[0] + centers = spheres[:, :3] + radii = spheres[:, 3] + covered = 0 + for pt in surface_pts: + dists = np.linalg.norm(centers - pt, axis=1) + if (dists < radii).any(): + covered += 1 + coverage = covered / len(surface_pts) + assert coverage > 0.8, f"Coverage only {coverage:.1%}" + + +def test_sphere_count_respects_budget(): + """Output sphere count should not exceed num_spheres.""" + mesh = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) + spheres = greedy_sphere_decomposition(mesh, num_spheres=5) + assert len(spheres) <= 5 + + +# --------------------------------------------------------------------------- +# Unit: WarpMeshManager caching +# --------------------------------------------------------------------------- + + +@requires_warp +def test_warp_mesh_caching(): + """Same mesh object should return identical Warp mesh from cache.""" + mesh = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) + manager = WarpMeshManager(num_spheres=10) + m1 = manager.get_warp_mesh(mesh) + m2 = manager.get_warp_mesh(mesh) + assert m1 is m2 + + +@requires_warp +def test_cache_key_differs_for_different_meshes(): + """Different trimeshes should produce different cache entries.""" + mesh_a = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) + mesh_b = trimesh.creation.cylinder(radius=0.05, height=0.1) + manager = WarpMeshManager(num_spheres=10) + ma = manager.get_warp_mesh(mesh_a) + mb = manager.get_warp_mesh(mesh_b) + assert ma is not mb + + +# --------------------------------------------------------------------------- +# Unit: NoCollisionLossStrategy routing +# --------------------------------------------------------------------------- + + +def test_dispatch_routes_to_aabb_in_bbox_mode(): + """BBOX mode should use AABB even when objects have meshes.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.BBOX, slope=10.0) + + obj_a = _make_cylinder("a") + obj_b = _make_cylinder("b") + obj_b.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = obj_b.get_bounding_box().translated((0.5, 0.0, 0.0)) + + # Separated: should be zero regardless of mode + loss = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=obj_a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=obj_a, + parent_obj=obj_b, + parent_pos=torch.tensor([0.5, 0.0, 0.0]), + ) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) + + +def test_dispatch_falls_back_when_obj_is_none(): + """Missing object refs should fall back to AABB.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10.0) + + bbox = AxisAlignedBoundingBox(min_point=(0.0, 0.0, 0.0), max_point=(0.1, 0.1, 0.1)) + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = bbox.translated((0.5, 0.0, 0.0)) + + loss = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=bbox, + parent_world_bbox=parent_world_bbox, + child_obj=None, + parent_obj=None, + ) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) + + +def test_dispatch_falls_back_when_no_mesh(): + """MESH mode with objects lacking collision_mesh should use AABB.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10.0) + + # No collision_mesh + obj_a = DummyObject( + name="a", bounding_box=AxisAlignedBoundingBox(min_point=(0.0, 0.0, 0.0), max_point=(0.1, 0.1, 0.1)) + ) + obj_b = DummyObject( + name="b", bounding_box=AxisAlignedBoundingBox(min_point=(0.0, 0.0, 0.0), max_point=(0.1, 0.1, 0.1)) + ) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = obj_b.get_bounding_box().translated((0.5, 0.0, 0.0)) + + loss = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=obj_a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=obj_a, + parent_obj=obj_b, + parent_pos=torch.tensor([0.5, 0.0, 0.0]), + ) + assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) + + +# --------------------------------------------------------------------------- +# Unit: NoCollisionLossStrategy mesh mode (requires warp) +# --------------------------------------------------------------------------- + + +@requires_warp +def test_mesh_zero_loss_separated_cylinders(): + """Two cylinders far apart should produce zero mesh collision loss.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + a = _make_cylinder("a") + b = _make_cylinder("b") + b.set_initial_pose(Pose(position_xyz=(1.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = b.get_bounding_box().translated((1.0, 0.0, 0.0)) + + loss = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([1.0, 0.0, 0.0]), + ) + assert loss.item() == 0.0 + + +@requires_warp +def test_mesh_positive_loss_overlapping_cylinders(): + """Two cylinders at the same position should produce positive mesh loss.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + a = _make_cylinder("a") + b = _make_cylinder("b") + b.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = b.get_bounding_box().translated((0.0, 0.0, 0.0)) + + loss = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([0.0, 0.0, 0.0]), + ) + assert loss.item() > 0.0 + + +@requires_warp +def test_mesh_loss_respects_clearance_m(): + """Near-miss cylinders should have positive loss when clearance_m > 0.""" + dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + a = _make_cylinder("a", radius=0.03) + b = _make_cylinder("b", radius=0.03) + b.set_initial_pose(Pose(position_xyz=(0.07, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = b.get_bounding_box().translated((0.07, 0.0, 0.0)) + + # Without clearance: separated (0.07 > 0.03+0.03) + loss_no_clearance = dispatch.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([0.07, 0.0, 0.0]), + ) + + # With large clearance: should trigger + loss_with_clearance = dispatch.compute_loss( + clearance_m=0.05, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([0.07, 0.0, 0.0]), + ) + assert loss_with_clearance.item() > loss_no_clearance.item() + + +# --------------------------------------------------------------------------- +# Integration: solver with mesh mode +# --------------------------------------------------------------------------- + + +@requires_warp +def test_solver_separates_overlapping_cylinders_mesh_mode(): + """RelationSolver with MESH mode should push overlapping cylinders apart.""" + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + objects = [table, a, b] + initial = [{table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.03), b: (0.01, 0.0, 0.03)}] + + solver = RelationSolver( + params=RelationSolverParams( + collision_mode=CollisionMode.MESH, max_iters=200, convergence_threshold=1e-4, verbose=False + ) + ) + result = solver.solve(objects, initial)[0] + + pos_a = np.array(result[a]) + pos_b = np.array(result[b]) + dist = np.linalg.norm(pos_a[:2] - pos_b[:2]) + # Must be separated by at least sum of radii (0.033 + 0.033 = 0.066) + assert dist > 0.066, f"Cylinders not separated: dist={dist:.4f}, need > 0.066" + + +@requires_warp +def test_on_pairs_skipped_in_mesh_mode(): + """On-linked pairs should not be penalized in mesh mode (same as AABB).""" + table = _make_table() + obj = _make_cylinder("can") + obj.add_relation(On(table)) + + objects = [table, obj] + # Place can directly on table surface -- should converge without fighting On + initial = [{table: (0.0, 0.0, 0.0), obj: (0.0, 0.0, 0.03)}] + + solver = RelationSolver( + params=RelationSolverParams( + collision_mode=CollisionMode.MESH, max_iters=100, convergence_threshold=1e-4, verbose=False + ) + ) + result = solver.solve(objects, initial)[0] + + # Should stay near table surface, not be pushed far away + z = result[obj][2] + assert 0.0 < z < 0.15, f"Object pushed too far: z={z}" + + +# --------------------------------------------------------------------------- +# Guard: yaw + mesh incompatibility +# --------------------------------------------------------------------------- + + +def test_random_yaw_mesh_mode_assertion(): + """random_yaw_init=True + CollisionMode.MESH should raise AssertionError.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH), + random_yaw_init=True, + ) + placer = ObjectPlacer(params=params) + + table = _make_table() + obj = _make_cylinder("can") + obj.add_relation(On(table)) + + with pytest.raises(AssertionError, match="random_yaw_init"): + placer.place([table, obj]) + + +# --------------------------------------------------------------------------- +# Missing tests identified in review +# --------------------------------------------------------------------------- + + +@requires_warp +def test_mesh_zero_loss_well_separated_cylinders(): + """Mesh mode correctly reports zero loss when objects are clearly separated. + + Uses large enough separation that sphere decomposition inflation (sphere_radius=0.01) + cannot bridge the gap, proving mesh mode reads actual geometry. + """ + a = _make_cylinder("a", radius=0.03, height=0.1) + b = _make_cylinder("b", radius=0.03, height=0.1) + + # Separated by 0.15 in X — well clear of sum-of-radii + inflation + a_pos = (0.0, 0.0, 0.0) + b_pos = (0.15, 0.0, 0.0) + + strategy = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + child_pos = torch.tensor(a_pos, dtype=torch.float32) + loss = strategy.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=b.get_bounding_box().translated(b_pos), + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor(b_pos, dtype=torch.float32), + ) + assert loss.item() == 0.0, f"Well-separated cylinders should have zero mesh loss, got {loss.item()}" + + +@requires_warp +def test_object_placer_mesh_mode_end_to_end(): + """ObjectPlacer.place() with CollisionMode.MESH returns a valid result.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, max_iters=300, verbose=False), + max_placement_attempts=5, + verbose=False, + ) + placer = ObjectPlacer(params=params) + result = placer.place([table, a, b]) + assert result.success, f"Placement failed with loss={result.final_loss}" + + +@requires_warp +def test_dispatch_routes_to_mesh_in_mesh_mode(): + """MESH mode with mesh-bearing objects routes to mesh strategy (produces different loss than AABB).""" + a = _make_cylinder("a", radius=0.03) + b = _make_cylinder("b", radius=0.03) + b.set_initial_pose(Pose(position_xyz=(0.04, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) + + child_pos = torch.tensor([0.0, 0.0, 0.0]) + parent_world_bbox = b.get_bounding_box().translated((0.04, 0.0, 0.0)) + + # AABB: overlapping (distance 0.04 < 0.03+0.03=0.06 on X, full overlap on Y/Z) + aabb_strategy = NoCollisionLossStrategy(collision_mode=CollisionMode.BBOX, slope=10000.0) + loss_aabb = aabb_strategy.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([0.04, 0.0, 0.0]), + ) + + # MESH: may or may not overlap depending on sphere placement, but should differ from AABB + mesh_strategy = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + loss_mesh = mesh_strategy.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=a.get_bounding_box(), + parent_world_bbox=parent_world_bbox, + child_obj=a, + parent_obj=b, + parent_pos=torch.tensor([0.04, 0.0, 0.0]), + ) + + # Key assertion: AABB loss is positive (boxes overlap), proving dispatch reached AABB path + assert loss_aabb.item() > 0.0 + # Mesh loss must differ from AABB — proves a different code path ran + assert ( + loss_mesh.item() != loss_aabb.item() + ), f"Mesh and AABB losses are identical ({loss_mesh.item()}) — dispatch may not have routed to mesh" + + +@requires_warp +def test_validate_no_overlap_mesh_catches_overlap(): + """Direct test of _validate_no_overlap_mesh: overlapping cylinders should fail validation.""" + from isaaclab_arena.relations.object_placer import ObjectPlacer + from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams + + table = _make_table() + a = _make_cylinder("cyl_a") + b = _make_cylinder("cyl_b") + a.add_relation(On(table)) + b.add_relation(On(table)) + + params = ObjectPlacerParams( + solver_params=RelationSolverParams(collision_mode=CollisionMode.MESH, verbose=False), + verbose=False, + ) + placer = ObjectPlacer(params=params) + + # Overlapping positions + positions = {table: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.05), b: (0.0, 0.0, 0.05)} + assert not placer._validate_no_overlap_mesh(positions) + + # Separated positions + positions_sep = {table: (0.0, 0.0, 0.0), a: (0.2, 0.0, 0.05), b: (-0.2, 0.0, 0.05)} + assert placer._validate_no_overlap_mesh(positions_sep) + + +@requires_warp +def test_mesh_sdf_backward_gradient(): + """mesh_sdf backward should produce non-zero gradients pointing outward for interior points.""" + from isaaclab_arena.relations.warp_sdf_kernels import mesh_sdf + + mesh = trimesh.creation.box(extents=(0.2, 0.2, 0.2)) + manager = WarpMeshManager(num_spheres=10, device="cpu") + warp_mesh = manager.get_warp_mesh(mesh) + + # Off-center point inside the box, closer to +X face + points = torch.tensor([[0.08, 0.0, 0.0]], dtype=torch.float32, requires_grad=True) + sdf = mesh_sdf(points, warp_mesh) + assert sdf.item() < 0.0, "Point inside box should have negative SDF" + + sdf.backward() + assert points.grad is not None + grad = points.grad[0] + # Gradient should point toward the nearest face (+X direction) + assert grad[0].item() > 0.0, f"Gradient X should be positive (toward +X face), got {grad[0].item()}" + assert abs(grad[0].item()) > abs(grad[1].item()), "X component should dominate (closest to +X face)" + + +@requires_warp +def test_rotated_anchor_raises(): + """Mesh collision with a rotated anchor must raise AssertionError.""" + a = _make_cylinder("child", radius=0.03, height=0.1) + table = _make_cylinder("table", radius=0.2, height=0.05) + # Give the anchor a non-identity rotation + table.set_initial_pose(Pose(position_xyz=(0.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.383, 0.924))) + + strategy = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) + child_pos = torch.tensor([0.05, 0.0, 0.0], dtype=torch.float32) + dummy_bbox = a.get_bounding_box() + parent_bbox = table.get_bounding_box().translated((0.0, 0.0, 0.0)) + + with pytest.raises(AssertionError, match="rotated anchor"): + strategy.compute_loss( + clearance_m=0.0, + child_pos=child_pos, + child_bbox=dummy_bbox, + parent_world_bbox=parent_bbox, + child_obj=a, + parent_obj=table, + parent_pos=None, + ) diff --git a/isaaclab_arena/utils/usd_helpers.py b/isaaclab_arena/utils/usd_helpers.py index 9238ae275..ce884e6e8 100644 --- a/isaaclab_arena/utils/usd_helpers.py +++ b/isaaclab_arena/utils/usd_helpers.py @@ -3,12 +3,19 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import numpy as np from contextlib import contextmanager +from typing import TYPE_CHECKING from pxr import Gf, Usd, UsdGeom, UsdLux, UsdPhysics from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox +if TYPE_CHECKING: + import trimesh + def get_all_prims( stage: Usd.Stage, prim: Usd.Prim | None = None, prims_list: list[Usd.Prim] | None = None @@ -196,3 +203,67 @@ def compute_local_bounding_box_from_prim( min_point=(local_min[0], local_min[1], local_min[2]), max_point=(local_max[0], local_max[1], local_max[2]), ) + + +def extract_trimesh_from_usd( + usd_path: str, + scale: tuple[float, float, float] = (1.0, 1.0, 1.0), +) -> trimesh.Trimesh: + """Extract all mesh geometry from a USD file into a single trimesh. + + Traverses the stage, applies per-prim world transforms, fan-triangulates + non-triangle faces, and applies scale to the combined result. + + Args: + usd_path: Path to the .usd/.usda/.usdc file. + scale: (sx, sy, sz) scale factors applied to the final mesh. + + Returns: + Combined trimesh in the USD default prim's local frame. + + Raises: + ValueError: If the file cannot be opened or contains no mesh prims. + """ + import trimesh as _trimesh + + stage = Usd.Stage.Open(usd_path) + if stage is None: + raise ValueError(f"Failed to open USD: {usd_path}") + + all_verts: list[np.ndarray] = [] + all_faces: list[list[int]] = [] + offset = 0 + + for prim in stage.Traverse(): + if not prim.IsA(UsdGeom.Mesh): + continue + mesh_prim = UsdGeom.Mesh(prim) + pts = mesh_prim.GetPointsAttr().Get() + fvc = mesh_prim.GetFaceVertexCountsAttr().Get() + fvi = mesh_prim.GetFaceVertexIndicesAttr().Get() + if pts is None or fvc is None or fvi is None: + continue + + xform = UsdGeom.Xformable(prim) + world_tf = np.array(xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default())).T + + verts = np.asarray(pts, dtype=np.float64) + verts_h = np.hstack([verts, np.ones((len(verts), 1))]) + verts_world = (verts_h @ world_tf)[:, :3] + verts_world[:, 0] *= scale[0] + verts_world[:, 1] *= scale[1] + verts_world[:, 2] *= scale[2] + + # Fan-triangulate faces + idx = 0 + for count in fvc: + for k in range(1, count - 1): + all_faces.append([fvi[idx] + offset, fvi[idx + k] + offset, fvi[idx + k + 1] + offset]) + idx += count + + all_verts.append(verts_world) + offset += len(verts_world) + + if not all_verts: + raise ValueError(f"No mesh geometry found in {usd_path}") + return _trimesh.Trimesh(vertices=np.vstack(all_verts), faces=np.array(all_faces, dtype=np.int32)) From 5e86ed0a08823f922f3c5567b8bde096b2621ba8 Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 9 Jun 2026 17:17:57 -0700 Subject: [PATCH 2/3] address comments --- isaaclab_arena/relations/object_placer.py | 22 ++++--- .../scripts/benchmark_collision_modes.py | 0 isaaclab_arena/tests/test_mesh_collision.py | 60 ------------------- isaaclab_arena/utils/usd_helpers.py | 2 +- 4 files changed, 16 insertions(+), 68 deletions(-) mode change 100644 => 100755 isaaclab_arena/scripts/benchmark_collision_modes.py diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index 101a4760f..fd14b71e9 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -604,6 +604,17 @@ def _validate_no_overlap( return False return True + def _get_cpu_mesh_manager(self): + """Lazily create a CPU WarpMeshManager, cached across validation calls.""" + if not hasattr(self, "_cpu_mesh_manager"): + from isaaclab_arena.relations.warp_mesh_manager import WarpMeshManager + + self._cpu_mesh_manager = WarpMeshManager( + num_spheres=self.params.solver_params.num_spheres, + device="cpu", + ) + return self._cpu_mesh_manager + def _validate_no_overlap_mesh( self, positions: dict[ObjectBase, tuple[float, float, float]], @@ -614,7 +625,6 @@ def _validate_no_overlap_mesh( Skips pairs where either object lacks a collision mesh (the solver's loss path still penalizes those via AABB fallback during optimization). """ - from isaaclab_arena.relations.warp_mesh_manager import WarpMeshManager from isaaclab_arena.relations.warp_sdf_kernels import mesh_sdf on_pairs: set[tuple] = set() @@ -629,10 +639,7 @@ def _validate_no_overlap_mesh( clearance_m = self.params.solver_params.clearance_m tolerance = max(0.0, clearance_m - 1e-6) - manager = WarpMeshManager( - num_spheres=self.params.solver_params.num_spheres, - device="cpu", - ) + manager = self._get_cpu_mesh_manager() warned_no_mesh: set[str] = set() objects = list(positions.keys()) @@ -693,11 +700,12 @@ def _validate_placement( Returns: True if no overlaps exist and On relations hold, False otherwise. """ - if not self._validate_no_overlap(positions, env_bboxes): - return False if self.params.solver_params.collision_mode == CollisionMode.MESH: if not self._validate_no_overlap_mesh(positions): return False + else: + if not self._validate_no_overlap(positions, env_bboxes): + return False return self._validate_on_relations(positions, env_bboxes) def _apply_poses( diff --git a/isaaclab_arena/scripts/benchmark_collision_modes.py b/isaaclab_arena/scripts/benchmark_collision_modes.py old mode 100644 new mode 100755 diff --git a/isaaclab_arena/tests/test_mesh_collision.py b/isaaclab_arena/tests/test_mesh_collision.py index 47435d9d3..2e0715d01 100644 --- a/isaaclab_arena/tests/test_mesh_collision.py +++ b/isaaclab_arena/tests/test_mesh_collision.py @@ -99,13 +99,6 @@ def test_sphere_decomposition_covers_surface(): assert coverage > 0.8, f"Coverage only {coverage:.1%}" -def test_sphere_count_respects_budget(): - """Output sphere count should not exceed num_spheres.""" - mesh = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) - spheres = greedy_sphere_decomposition(mesh, num_spheres=5) - assert len(spheres) <= 5 - - # --------------------------------------------------------------------------- # Unit: WarpMeshManager caching # --------------------------------------------------------------------------- @@ -121,17 +114,6 @@ def test_warp_mesh_caching(): assert m1 is m2 -@requires_warp -def test_cache_key_differs_for_different_meshes(): - """Different trimeshes should produce different cache entries.""" - mesh_a = trimesh.creation.box(extents=(0.1, 0.1, 0.1)) - mesh_b = trimesh.creation.cylinder(radius=0.05, height=0.1) - manager = WarpMeshManager(num_spheres=10) - ma = manager.get_warp_mesh(mesh_a) - mb = manager.get_warp_mesh(mesh_b) - assert ma is not mb - - # --------------------------------------------------------------------------- # Unit: NoCollisionLossStrategy routing # --------------------------------------------------------------------------- @@ -161,25 +143,6 @@ def test_dispatch_routes_to_aabb_in_bbox_mode(): assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) -def test_dispatch_falls_back_when_obj_is_none(): - """Missing object refs should fall back to AABB.""" - dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10.0) - - bbox = AxisAlignedBoundingBox(min_point=(0.0, 0.0, 0.0), max_point=(0.1, 0.1, 0.1)) - child_pos = torch.tensor([0.0, 0.0, 0.0]) - parent_world_bbox = bbox.translated((0.5, 0.0, 0.0)) - - loss = dispatch.compute_loss( - clearance_m=0.0, - child_pos=child_pos, - child_bbox=bbox, - parent_world_bbox=parent_world_bbox, - child_obj=None, - parent_obj=None, - ) - assert torch.isclose(loss, torch.tensor(0.0), atol=1e-5) - - def test_dispatch_falls_back_when_no_mesh(): """MESH mode with objects lacking collision_mesh should use AABB.""" dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10.0) @@ -212,29 +175,6 @@ def test_dispatch_falls_back_when_no_mesh(): # --------------------------------------------------------------------------- -@requires_warp -def test_mesh_zero_loss_separated_cylinders(): - """Two cylinders far apart should produce zero mesh collision loss.""" - dispatch = NoCollisionLossStrategy(collision_mode=CollisionMode.MESH, slope=10000.0) - a = _make_cylinder("a") - b = _make_cylinder("b") - b.set_initial_pose(Pose(position_xyz=(1.0, 0.0, 0.0), rotation_xyzw=(0.0, 0.0, 0.0, 1.0))) - - child_pos = torch.tensor([0.0, 0.0, 0.0]) - parent_world_bbox = b.get_bounding_box().translated((1.0, 0.0, 0.0)) - - loss = dispatch.compute_loss( - clearance_m=0.0, - child_pos=child_pos, - child_bbox=a.get_bounding_box(), - parent_world_bbox=parent_world_bbox, - child_obj=a, - parent_obj=b, - parent_pos=torch.tensor([1.0, 0.0, 0.0]), - ) - assert loss.item() == 0.0 - - @requires_warp def test_mesh_positive_loss_overlapping_cylinders(): """Two cylinders at the same position should produce positive mesh loss.""" diff --git a/isaaclab_arena/utils/usd_helpers.py b/isaaclab_arena/utils/usd_helpers.py index ce884e6e8..1fe0904e7 100644 --- a/isaaclab_arena/utils/usd_helpers.py +++ b/isaaclab_arena/utils/usd_helpers.py @@ -245,7 +245,7 @@ def extract_trimesh_from_usd( continue xform = UsdGeom.Xformable(prim) - world_tf = np.array(xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default())).T + world_tf = np.array(xform.ComputeLocalToWorldTransform(Usd.TimeCode.Default())) verts = np.asarray(pts, dtype=np.float64) verts_h = np.hstack([verts, np.ones((len(verts), 1))]) From 729d892c6fdbf0813bdae767326e965b8959942e Mon Sep 17 00:00:00 2001 From: zhx06 Date: Tue, 9 Jun 2026 17:30:57 -0700 Subject: [PATCH 3/3] address agent comments --- isaaclab_arena/relations/relation_loss_strategies.py | 1 + isaaclab_arena/relations/warp_mesh_manager.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/isaaclab_arena/relations/relation_loss_strategies.py b/isaaclab_arena/relations/relation_loss_strategies.py index fcdb714f8..8e68b0613 100644 --- a/isaaclab_arena/relations/relation_loss_strategies.py +++ b/isaaclab_arena/relations/relation_loss_strategies.py @@ -596,6 +596,7 @@ def _compute_mesh_loss( warp_mesh = manager.get_warp_mesh(parent_mesh, obj=parent_obj) batch_size = child_pos.shape[0] + parent_pos_resolved = parent_pos_resolved.expand(batch_size, -1) total_loss = torch.zeros(batch_size, device=device, dtype=child_pos.dtype) for b in range(batch_size): diff --git a/isaaclab_arena/relations/warp_mesh_manager.py b/isaaclab_arena/relations/warp_mesh_manager.py index ae52c2765..4bc78451d 100644 --- a/isaaclab_arena/relations/warp_mesh_manager.py +++ b/isaaclab_arena/relations/warp_mesh_manager.py @@ -138,7 +138,7 @@ def _cache_key(self, mesh: trimesh.Trimesh, obj=None) -> tuple: """Compute cache key. Uses (usd_path, scale) for USD objects, content hash otherwise.""" usd_path = getattr(obj, "usd_path", None) if obj is not None else None if usd_path is not None: - scale = getattr(obj, "scale", (1.0, 1.0, 1.0)) + scale = tuple(getattr(obj, "scale", (1.0, 1.0, 1.0))) return (usd_path, scale, self._num_spheres, self._sphere_radius) return (_mesh_content_hash(mesh), self._num_spheres, self._sphere_radius)