From 9d155d8162fe63ff703bbbb26865d57b63b765c4 Mon Sep 17 00:00:00 2001 From: Thamires Taboza Date: Sun, 28 Jun 2026 23:53:37 -0300 Subject: [PATCH] refactor: Remove too-many-statements in multiple files This change resolves the too-many-statements Pylint warnings by extracting monolithic blocks into specialized private functions. The refactoring directly addresses code smells within: - doctr/models/detection/differentiable_binarization/base.py - doctr/models/layout/lw_detr/pytorch.py - doctr/transforms/modules/pytorch.py - doctr/utils/metrics.py Isolating these heavy operations ensures the main execution paths remain readable and lowers cognitive load without changing behavior. --- .../differentiable_binarization/base.py | 62 +++--- doctr/models/layout/lw_detr/pytorch.py | 201 ++++++++++++------ doctr/transforms/modules/pytorch.py | 181 ++++++++-------- doctr/utils/metrics.py | 192 ++++++++--------- 4 files changed, 354 insertions(+), 282 deletions(-) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 5f8d1e90e5..697e134249 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -267,6 +267,40 @@ def draw_thresh_map( return polygon, canvas, mask + def _draw_polygon_on_maps( + self, + poly: np.ndarray, + box: np.ndarray, + box_size: float, + idx: int, + class_idx: int, + seg_target: np.ndarray, + seg_mask: np.ndarray, + thresh_target: np.ndarray, + thresh_mask: np.ndarray, + ) -> None: + if box_size < self.min_size_box: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + return + + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrunken = padding.Execute(-distance) + + if len(shrunken) == 0: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + return + shrunken = np.array(shrunken[0]).reshape(-1, 2) + if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: + seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False + return + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) + + self.draw_thresh_map(poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]) + def build_target( self, target: list[dict[str, np.ndarray]], @@ -321,32 +355,8 @@ def build_target( boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) for poly, box, box_size in zip(polys, abs_boxes, boxes_size): - # Mask boxes that are too small - if box_size < self.min_size_box: - seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - - # Negative shrink for gt, as described in paper - polygon = Polygon(poly) - distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length - subject = [tuple(coor) for coor in poly] - padding = pyclipper.PyclipperOffset() - padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) - shrunken = padding.Execute(-distance) - - # Draw polygon on gt if it is valid - if len(shrunken) == 0: - seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - shrunken = np.array(shrunken[0]).reshape(-1, 2) - if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: - seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False - continue - cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) - - # Draw on both thresh map and thresh mask - poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( - poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] + self._draw_polygon_on_maps( + poly, box, box_size, idx, class_idx, seg_target, seg_mask, thresh_target, thresh_mask ) thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min diff --git a/doctr/models/layout/lw_detr/pytorch.py b/doctr/models/layout/lw_detr/pytorch.py index 521f6df7cc..8f9b4a7211 100644 --- a/doctr/models/layout/lw_detr/pytorch.py +++ b/doctr/models/layout/lw_detr/pytorch.py @@ -449,15 +449,9 @@ def gen_encoder_output_proposals( return object_query, output_proposals, invalid_mask - def forward( - self, - input: torch.Tensor, - masks: torch.Tensor | None = None, - target: list[dict[str, np.ndarray]] | None = None, - return_model_output: bool = False, - return_preds: bool = False, - **kwargs: Any, - ) -> dict[str, Any]: + def _extract_features( + self, input: torch.Tensor, masks: torch.Tensor | None + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: feats = self.feat_extractor(input, masks) sources: list[torch.Tensor] = [] @@ -469,20 +463,26 @@ def forward( if mask is None: # pragma: no cover raise ValueError("No attention mask was provided") + return sources, feats_masks + + def _setup_queries(self) -> tuple[torch.Tensor, torch.Tensor]: if self.training: reference_points = self.reference_point_embed.weight query_feat = self.query_feat.weight else: - # only use one group in inference reference_points = self.reference_point_embed.weight[: self.num_queries] query_feat = self.query_feat.weight[: self.num_queries] - # Prepare encoder inputs (by flattening) + return reference_points, query_feat + + def _prepare_encoder_inputs( + self, sources: list[torch.Tensor], feats_masks: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], torch.Tensor, torch.Tensor, torch.Tensor]: source_flatten_list: list[torch.Tensor] = [] mask_flatten_list: list[torch.Tensor] = [] spatial_shapes_list: list[tuple[int, int]] = [] for source, mask in zip(sources, feats_masks): - batch_size, num_channels, height, width = source.shape + _, _, height, width = source.shape spatial_shape = (height, width) spatial_shapes_list.append(spatial_shape) source = source.flatten(2).transpose(1, 2) @@ -492,19 +492,22 @@ def forward( source_flatten = torch.cat(source_flatten_list, 1) mask_flatten = torch.cat(mask_flatten_list, 1) - tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) - reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) - object_query_embedding, output_proposals, invalid_mask = self.gen_encoder_output_proposals( source_flatten, mask_flatten, spatial_shapes_list ) + return source_flatten, mask_flatten, spatial_shapes_list, object_query_embedding, output_proposals, invalid_mask + + def _encoder_group_predictions( + self, + object_query_embedding: torch.Tensor, + output_proposals: torch.Tensor, + invalid_mask: torch.Tensor, + ) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: group_detr = self.group_detr if self.training else 1 topk = self.num_queries topk_coords_logits_list: list[torch.Tensor] = [] - - # encoder predictions on the selected top-k proposals, kept undetached for the auxiliary loss all_group_enc_logits: list[torch.Tensor] = [] all_group_enc_coords: list[torch.Tensor] = [] @@ -513,7 +516,6 @@ def forward( group_object_query = self.enc_output_norm[group_id](group_object_query) group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query) - group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf")) group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query) @@ -526,8 +528,6 @@ def forward( 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6), ) - # the auxiliary loss supervises only the selected proposals, - # so gather the matching class logits as well group_topk_logits_undetach = torch.gather( group_enc_outputs_class, 1, @@ -535,13 +535,18 @@ def forward( ) all_group_enc_logits.append(group_topk_logits_undetach) all_group_enc_coords.append(group_topk_coords_logits_undetach) - - # the decoder consumes detached proposals as initial reference points topk_coords_logits_list.append(group_topk_coords_logits_undetach.detach()) - topk_coords_logits = torch.cat(topk_coords_logits_list, 1) - reference_points = refine_obb_boxes(topk_coords_logits, reference_points) + return topk_coords_logits_list, all_group_enc_logits, all_group_enc_coords + def _decode_and_predict( + self, + tgt: torch.Tensor, + reference_points: torch.Tensor, + source_flatten: torch.Tensor, + mask_flatten: torch.Tensor, + spatial_shapes_list: list[tuple[int, int]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]: last_hidden_states, intermediate, intermediate_reference_points = self.decoder( inputs_embeds=tgt, reference_points=reference_points, @@ -554,62 +559,124 @@ def forward( pred_boxes_delta = self.bbox_embed(last_hidden_states) pred_boxes = refine_obb_boxes(intermediate_reference_points[-1], pred_boxes_delta) - out: dict[str, Any] = {} + return logits, pred_boxes, intermediate, intermediate_reference_points - if self.exportable: - out["logits"] = logits - out["pred_boxes"] = pred_boxes - return out + def _prepare_outputs( + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + return_model_output: bool, + target: list[dict[str, np.ndarray]] | None, + return_preds: bool, + ) -> dict[str, Any]: + out: dict[str, Any] = {} if return_model_output or target is None or return_preds: out["logits"] = logits if target is None or return_preds: - # Disable for torch.compile compatibility + @torch.compiler.disable def _postprocess(logits, boxes): return self.postprocessor(logits, boxes) out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy()) + return out + + def _compute_losses( + self, + logits: torch.Tensor, + pred_boxes: torch.Tensor, + target: list[dict[str, np.ndarray]], + input: torch.Tensor, + group_detr: int, + intermediate: torch.Tensor, + intermediate_reference_points: list[torch.Tensor], + all_group_enc_logits: list[torch.Tensor], + all_group_enc_coords: list[torch.Tensor], + ) -> torch.Tensor: + processed_targets = self.build_target(target, self.class_names) + + box_scale = float(max(input.shape[-2], input.shape[-1])) + + split_logits = logits.chunk(group_detr, dim=1) + split_boxes = pred_boxes.chunk(group_detr, dim=1) + + main_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_logits, split_boxes): + main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss = main_loss / group_detr + + for i in range(intermediate.shape[0] - 1): + aux_logits = self.class_embed(intermediate[i]) + aux_boxes_delta = self.bbox_embed(intermediate[i]) + aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) + + split_aux_logits = aux_logits.chunk(group_detr, dim=1) + split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) + + aux_loss: float | torch.Tensor = 0.0 + for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): + aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) + loss += aux_loss / group_detr + + enc_loss: float | torch.Tensor = 0.0 + for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): + enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) + loss += enc_loss / group_detr + + return loss + + def forward( + self, + input: torch.Tensor, + masks: torch.Tensor | None = None, + target: list[dict[str, np.ndarray]] | None = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> dict[str, Any]: + sources, feats_masks = self._extract_features(input, masks) + reference_points, query_feat = self._setup_queries() + + batch_size = sources[0].shape[0] + tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1) + reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1) + + source_flatten, mask_flatten, spatial_shapes_list, object_query_embedding, output_proposals, invalid_mask = ( + self._prepare_encoder_inputs(sources, feats_masks) + ) + + topk_coords_logits_list, all_group_enc_logits, all_group_enc_coords = self._encoder_group_predictions( + object_query_embedding, output_proposals, invalid_mask + ) + + topk_coords_logits = torch.cat(topk_coords_logits_list, 1) + reference_points = refine_obb_boxes(topk_coords_logits, reference_points) + + logits, pred_boxes, intermediate, intermediate_reference_points = self._decode_and_predict( + tgt, reference_points, source_flatten, mask_flatten, spatial_shapes_list + ) + + if self.exportable: + return {"logits": logits, "pred_boxes": pred_boxes} + + out = self._prepare_outputs(logits, pred_boxes, return_model_output, target, return_preds) + if target is not None: - # Build target - processed_targets = self.build_target(target, self.class_names) - - # ProbIoU is computed in pixel coordinates - box_scale = float(max(input.shape[-2], input.shape[-1])) - - # Main loss from final decoder layer (group DETR: each group is matched independently) - split_logits = logits.chunk(group_detr, dim=1) - split_boxes = pred_boxes.chunk(group_detr, dim=1) - - main_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_logits, split_boxes): - main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) - loss = main_loss / group_detr - - # Auxiliary losses from intermediate decoder layers - # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i) - for i in range(intermediate.shape[0] - 1): - aux_logits = self.class_embed(intermediate[i]) - aux_boxes_delta = self.bbox_embed(intermediate[i]) - aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta) - - split_aux_logits = aux_logits.chunk(group_detr, dim=1) - split_aux_boxes = aux_boxes.chunk(group_detr, dim=1) - - aux_loss: float | torch.Tensor = 0.0 - for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes): - aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale) - loss += aux_loss / group_detr - - # Auxiliary losses for the selected encoder proposals - enc_loss: float | torch.Tensor = 0.0 - for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords): - enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale) - loss += enc_loss / group_detr - - out["loss"] = loss + group_detr = self.group_detr if self.training else 1 + out["loss"] = self._compute_losses( + logits, + pred_boxes, + target, + input, + group_detr, + intermediate, + intermediate_reference_points, + all_group_enc_logits, + all_group_enc_coords, + ) return out diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py index c6b7881b56..8fd87b60a6 100644 --- a/doctr/transforms/modules/pytorch.py +++ b/doctr/transforms/modules/pytorch.py @@ -100,8 +100,6 @@ def forward( target = sample.target mask = sample.mask - # Resize mask alongside image if provided - # Masks should use nearest interpolation to preserve label integrity resize_mask = mask is not None if resize_mask and mask is not None and mask.ndim == 2: mask = mask.unsqueeze(0) @@ -110,108 +108,123 @@ def forward( actual_ratio = img.shape[-2] / img.shape[-1] if not self.preserve_aspect_ratio or (target_ratio == actual_ratio): - # If we don't preserve the aspect ratio or the wanted aspect ratio is the same than the original one - # We can use with the regular resize - img = super().forward(img) - - if resize_mask: - mask = F.resize( - mask, - self.size, - interpolation=F.InterpolationMode.NEAREST, - antialias=False, - ).squeeze(0) + return self._forward_simple(img, target, mask, resize_mask, sample) - if self.return_padding_mask: - padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device) + return self._forward_with_aspect_ratio(img, target, mask, resize_mask, sample) - if target is not None: - if self.return_padding_mask: - return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) + def _forward_simple( + self, + img: torch.Tensor, + target: np.ndarray | str | None, + mask: torch.Tensor | None, + resize_mask: bool, + sample: Sample, + ) -> Sample: + img = super().forward(img) + + if resize_mask: + mask = F.resize( + mask, + self.size, + interpolation=F.InterpolationMode.NEAREST, + antialias=False, + ).squeeze(0) + + if self.return_padding_mask: + padding_mask = torch.zeros(self.size, dtype=torch.bool, device=img.device) + + if target is not None: if self.return_padding_mask: - return sample.replace(image=img, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, mask=mask if resize_mask else sample.mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) + if self.return_padding_mask: + return sample.replace(image=img, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, mask=mask if resize_mask else sample.mask) + + def _forward_with_aspect_ratio( + self, + img: torch.Tensor, + target: np.ndarray | str | None, + mask: torch.Tensor | None, + resize_mask: bool, + sample: Sample, + ) -> Sample: + target_ratio = self.size[0] / self.size[1] + actual_ratio = img.shape[-2] / img.shape[-1] + if actual_ratio > target_ratio: + tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) else: - # Resize - if actual_ratio > target_ratio: - tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) - else: - tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) - # Scale image - img = F.resize(img, tmp_size, self.interpolation, antialias=True) + img = F.resize(img, tmp_size, self.interpolation, antialias=True) - if resize_mask: - mask = F.resize( - mask, - tmp_size, - interpolation=F.InterpolationMode.NEAREST, - antialias=False, - ).squeeze(0) + if resize_mask: + mask = F.resize( + mask, + tmp_size, + interpolation=F.InterpolationMode.NEAREST, + antialias=False, + ).squeeze(0) - raw_shape = img.shape[-2:] + raw_shape = img.shape[-2:] - if isinstance(self.size, (tuple, list)): - # Pad (inverted in pytorch) - _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) + if isinstance(self.size, (tuple, list)): + _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) - if self.symmetric_pad: - half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) - _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) - # Pad image - img = pad(img, _pad) + if self.symmetric_pad: + half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) + _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) - if resize_mask and mask is not None: - mask = pad(mask, _pad) + img = pad(img, _pad) - if self.return_padding_mask: - h, w = self.size - padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) - left, right, top, bottom = _pad - padding_mask[top : h - bottom, left : w - right] = True + if resize_mask and mask is not None: + mask = pad(mask, _pad) - # In case boxes are provided, resize boxes if needed (for detection task if preserve aspect ratio) - if target is not None: - if self.symmetric_pad: - offset = ( - half_pad[0] / img.shape[-1], - half_pad[1] / img.shape[-2], - ) - else: - offset = (0, 0) - - if isinstance(target, str) or (isinstance(target, np.ndarray) and target.shape == (1,)): - # Special case for orientation targets and other non-box targets, which should not be resized - pass - elif isinstance(target, dict): - target = { - cls_name: self._resize_target( - arr, - raw_shape, - img.shape[-2:], - symmetric_pad=self.symmetric_pad, - offset=offset, - ) - for cls_name, arr in target.items() - } - else: - target = self._resize_target( - target, + if self.return_padding_mask: + h, w = self.size + padding_mask = torch.zeros((h, w), dtype=torch.bool, device=img.device) + left, right, top, bottom = _pad + padding_mask[top : h - bottom, left : w - right] = True + + if target is not None: + if self.symmetric_pad: + offset = ( + half_pad[0] / img.shape[-1], + half_pad[1] / img.shape[-2], + ) + else: + offset = (0, 0) + + if isinstance(target, str) or (isinstance(target, np.ndarray) and target.shape == (1,)): + pass + elif isinstance(target, dict): + target = { + cls_name: self._resize_target( + arr, raw_shape, img.shape[-2:], symmetric_pad=self.symmetric_pad, offset=offset, ) + for cls_name, arr in target.items() + } + else: + target = self._resize_target( + target, + raw_shape, + img.shape[-2:], + symmetric_pad=self.symmetric_pad, + offset=offset, + ) - if target is not None: - if self.return_padding_mask: - return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) + if target is not None: if self.return_padding_mask: - return sample.replace(image=img, mask=mask if resize_mask else padding_mask) - return sample.replace(image=img, mask=mask if resize_mask else sample.mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, target=target, mask=mask if resize_mask else sample.mask) + if self.return_padding_mask: + return sample.replace(image=img, mask=mask if resize_mask else padding_mask) + return sample.replace(image=img, mask=mask if resize_mask else sample.mask) def __repr__(self) -> str: interpolate_str = self.interpolation.value diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py index 5d01cc8bd1..3fc394d144 100644 --- a/doctr/utils/metrics.py +++ b/doctr/utils/metrics.py @@ -656,124 +656,106 @@ def summary(self) -> dict[str, float | dict[float, float]]: if len(self._gts) == 0: raise AssertionError("No samples added") - # Determine classes - if self.num_classes is None: - labels = [] - for g in self._gts: - labels.extend(g["labels"].tolist()) - for p in self._preds: - labels.extend(p["labels"].tolist()) - classes = np.unique(labels) - else: - classes = np.arange(self.num_classes) - + classes = self._get_classes() ap_per_iou = {} for iou_thresh in self.iou_thresholds: class_aps = [] - for c in classes: - # Collect GTs per image - gt_by_image = {} - total_gt = 0 - - for img_idx, gt in enumerate(self._gts): - mask = gt["labels"] == c - gt_boxes = gt["boxes"][mask] - - gt_by_image[img_idx] = { - "boxes": gt_boxes, - "matched": np.zeros(len(gt_boxes), dtype=bool), - } - - total_gt += len(gt_boxes) - - if total_gt == 0: - continue - - # Collect all detections globally - detections = [] - - for img_idx, pred in enumerate(self._preds): - mask = pred["labels"] == c - - pred_boxes = pred["boxes"][mask] - pred_scores = pred["scores"][mask] - - for box, score in zip(pred_boxes, pred_scores): - detections.append({ - "image_id": img_idx, - "box": box, - "score": float(score), - }) - - if len(detections) == 0: - class_aps.append(0.0) - continue - - # Global sorting COCO-style - detections.sort(key=lambda x: -x["score"]) - - tp = np.zeros(len(detections)) - fp = np.zeros(len(detections)) - - # Evaluate detections - for det_idx, det in enumerate(detections): - img_idx = det["image_id"] - pred_box = det["box"] - - gt_data = gt_by_image[img_idx] - gt_boxes = gt_data["boxes"] - - if len(gt_boxes) == 0: - fp[det_idx] = 1 - continue + ap = self._evaluate_class(c, iou_thresh) + if ap is not None: + class_aps.append(ap) + ap_per_iou[float(iou_thresh)] = float(np.mean(class_aps)) if class_aps else 0.0 - # Compute IoUs - if self.use_polygons: - iou_mat = polygon_iou( - gt_boxes, - np.expand_dims(pred_box, axis=0), - ) - else: - iou_mat = box_iou( - gt_boxes, - np.expand_dims(pred_box, axis=0), - ) - - ious = iou_mat[:, 0] - - best_gt = np.argmax(ious) - best_iou = ious[best_gt] + map_value = float(np.mean(list(ap_per_iou.values()))) - if best_iou >= iou_thresh and not gt_data["matched"][best_gt]: - tp[det_idx] = 1 - gt_data["matched"][best_gt] = True - else: - fp[det_idx] = 1 + return { + "mAP@[.5:.95]": map_value, + "AP@[.5]": ap_per_iou.get(0.5, 0.0), + "AP@[.75]": ap_per_iou.get(0.75, 0.0), + "AP_per_IoU": ap_per_iou, + } - # Precision / Recall - tp_cum = np.cumsum(tp) - fp_cum = np.cumsum(fp) + def _get_classes(self) -> np.ndarray: + if self.num_classes is None: + labels = [] + for g in self._gts: + labels.extend(g["labels"].tolist()) + for p in self._preds: + labels.extend(p["labels"].tolist()) + return np.unique(labels) + return np.arange(self.num_classes) + + def _collect_gt_by_image(self, class_id) -> tuple[dict, int]: + gt_by_image = {} + total_gt = 0 + for img_idx, gt in enumerate(self._gts): + mask = gt["labels"] == class_id + gt_boxes = gt["boxes"][mask] + gt_by_image[img_idx] = { + "boxes": gt_boxes, + "matched": np.zeros(len(gt_boxes), dtype=bool), + } + total_gt += len(gt_boxes) + return gt_by_image, total_gt + + def _collect_detections(self, class_id) -> list[dict]: + detections = [] + for img_idx, pred in enumerate(self._preds): + mask = pred["labels"] == class_id + pred_boxes = pred["boxes"][mask] + pred_scores = pred["scores"][mask] + for box, score in zip(pred_boxes, pred_scores): + detections.append({ + "image_id": img_idx, + "box": box, + "score": float(score), + }) + return detections + + def _match_detections(self, detections, gt_by_image, iou_thresh) -> tuple[np.ndarray, np.ndarray]: + tp = np.zeros(len(detections)) + fp = np.zeros(len(detections)) + for det_idx, det in enumerate(detections): + img_idx = det["image_id"] + pred_box = det["box"] + gt_data = gt_by_image[img_idx] + gt_boxes = gt_data["boxes"] + if len(gt_boxes) == 0: + fp[det_idx] = 1 + continue + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, np.expand_dims(pred_box, axis=0)) + else: + iou_mat = box_iou(gt_boxes, np.expand_dims(pred_box, axis=0)) + ious = iou_mat[:, 0] + best_gt = np.argmax(ious) + best_iou = ious[best_gt] + if best_iou >= iou_thresh and not gt_data["matched"][best_gt]: + tp[det_idx] = 1 + gt_data["matched"][best_gt] = True + else: + fp[det_idx] = 1 + return tp, fp - recall = tp_cum / total_gt - precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8) + def _evaluate_class(self, class_id, iou_thresh) -> float | None: + gt_by_image, total_gt = self._collect_gt_by_image(class_id) + if total_gt == 0: + return None - ap = self._compute_ap(recall, precision) - class_aps.append(ap) + detections = self._collect_detections(class_id) + if len(detections) == 0: + return 0.0 - ap_per_iou[float(iou_thresh)] = float(np.mean(class_aps)) if len(class_aps) > 0 else 0.0 + detections.sort(key=lambda x: -x["score"]) + tp, fp = self._match_detections(detections, gt_by_image, iou_thresh) - map_value = float(np.mean(list(ap_per_iou.values()))) - ap50 = ap_per_iou.get(0.5, 0.0) - ap75 = ap_per_iou.get(0.75, 0.0) + tp_cum = np.cumsum(tp) + fp_cum = np.cumsum(fp) + recall = tp_cum / total_gt + precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-8) - return { - "mAP@[.5:.95]": map_value, - "AP@[.5]": ap50, - "AP@[.75]": ap75, - "AP_per_IoU": ap_per_iou, - } + return self._compute_ap(recall, precision) def _compute_ap(self, recall: np.ndarray, precision: np.ndarray) -> float: """Computes the Average Precision using the 101-point interpolation method from COCO