Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions doctr/models/detection/differentiable_binarization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,40 @@ def draw_thresh_map(

return polygon, canvas, mask

def _draw_polygon_on_maps(
    self,
    poly: np.ndarray,
    box: np.ndarray,
    box_size: float,
    idx: int,
    class_idx: int,
    seg_target: np.ndarray,
    seg_mask: np.ndarray,
    thresh_target: np.ndarray,
    thresh_mask: np.ndarray,
) -> None:
    """Draw a single ground-truth polygon on the segmentation and threshold maps, in place.

    The polygon is shrunk before being filled into ``seg_target``. When the box is
    smaller than ``self.min_size_box``, or the shrunken polygon is empty or invalid,
    nothing is drawn and the box area is set to ``False`` in ``seg_mask`` instead.

    Args:
        poly: polygon vertices, iterated as coordinate pairs
        box: absolute box of the polygon; ``box[1]:box[3]`` indexes the second-to-last
            axis of the maps and ``box[0]:box[2]`` the last one (both ends inclusive)
            -- presumably (xmin, ymin, xmax, ymax), confirm against the caller
        box_size: size compared against ``self.min_size_box``
        idx: index on the first axis of the maps
        class_idx: index on the second axis of the maps
        seg_target: segmentation target, filled with 1.0 inside the shrunken polygon
        seg_mask: segmentation mask, set to ``False`` over ignored boxes
        thresh_target: threshold target handed to ``self.draw_thresh_map``
        thresh_mask: threshold mask handed to ``self.draw_thresh_map``
    """

    def _ignore_box() -> None:
        # Exclude the whole box (bounds included) from the segmentation mask
        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False

    # Mask boxes that are too small
    if box_size < self.min_size_box:
        _ignore_box()
        return

    # Negative shrink for gt, as described in paper
    polygon = Polygon(poly)
    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
    subject = [tuple(coor) for coor in poly]
    padding = pyclipper.PyclipperOffset()
    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrunken = padding.Execute(-distance)

    # Draw polygon on gt only if the shrunken polygon is valid
    if len(shrunken) == 0:
        _ignore_box()
        return
    shrunken = np.array(shrunken[0]).reshape(-1, 2)
    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
        _ignore_box()
        return
    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)

    # Draw on both thresh map and thresh mask.
    # NOTE(review): the returned arrays are discarded, so this relies on `draw_thresh_map`
    # updating the passed slices in place -- confirm against its implementation.
    self.draw_thresh_map(poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx])

def build_target(
self,
target: list[dict[str, np.ndarray]],
Expand Down Expand Up @@ -321,32 +355,8 @@ def build_target(
boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])

for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
# Mask boxes that are too small
if box_size < self.min_size_box:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue

# Negative shrink for gt, as described in paper
polygon = Polygon(poly)
distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
subject = [tuple(coor) for coor in poly]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunken = padding.Execute(-distance)

# Draw polygon on gt if it is valid
if len(shrunken) == 0:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
shrunken = np.array(shrunken[0]).reshape(-1, 2)
if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
continue
cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)

# Draw on both thresh map and thresh mask
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
self._draw_polygon_on_maps(
poly, box, box_size, idx, class_idx, seg_target, seg_mask, thresh_target, thresh_mask
)

thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min
Expand Down
201 changes: 134 additions & 67 deletions doctr/models/layout/lw_detr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,15 +449,9 @@

return object_query, output_proposals, invalid_mask

def forward(
self,
input: torch.Tensor,
masks: torch.Tensor | None = None,
target: list[dict[str, np.ndarray]] | None = None,
return_model_output: bool = False,
return_preds: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
def _extract_features(
self, input: torch.Tensor, masks: torch.Tensor | None

Check warning on line 453 in doctr/models/layout/lw_detr/pytorch.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/layout/lw_detr/pytorch.py#L453

Redefining built-in 'input'
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
feats = self.feat_extractor(input, masks)

sources: list[torch.Tensor] = []
Expand All @@ -469,20 +463,26 @@
if mask is None: # pragma: no cover
raise ValueError("No attention mask was provided")

return sources, feats_masks

def _setup_queries(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the reference-point and query embeddings for the current mode.

    Returns:
        a ``(reference_points, query_feat)`` tuple: the full embedding weights when
        ``self.training`` is set, otherwise only their first ``self.num_queries`` rows
    """
    ref_weight = self.reference_point_embed.weight
    feat_weight = self.query_feat.weight

    if not self.training:
        # only use one group in inference
        keep = self.num_queries
        return ref_weight[:keep], feat_weight[:keep]

    return ref_weight, feat_weight

def _prepare_encoder_inputs(
self, sources: list[torch.Tensor], feats_masks: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], torch.Tensor, torch.Tensor, torch.Tensor]:
source_flatten_list: list[torch.Tensor] = []
mask_flatten_list: list[torch.Tensor] = []
spatial_shapes_list: list[tuple[int, int]] = []
for source, mask in zip(sources, feats_masks):
batch_size, num_channels, height, width = source.shape
_, _, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes_list.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
Expand All @@ -492,19 +492,22 @@
source_flatten = torch.cat(source_flatten_list, 1)
mask_flatten = torch.cat(mask_flatten_list, 1)

tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)

object_query_embedding, output_proposals, invalid_mask = self.gen_encoder_output_proposals(
source_flatten, mask_flatten, spatial_shapes_list
)

return source_flatten, mask_flatten, spatial_shapes_list, object_query_embedding, output_proposals, invalid_mask

def _encoder_group_predictions(
self,
object_query_embedding: torch.Tensor,
output_proposals: torch.Tensor,
invalid_mask: torch.Tensor,
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
group_detr = self.group_detr if self.training else 1
topk = self.num_queries

topk_coords_logits_list: list[torch.Tensor] = []

# encoder predictions on the selected top-k proposals, kept undetached for the auxiliary loss
all_group_enc_logits: list[torch.Tensor] = []
all_group_enc_coords: list[torch.Tensor] = []

Expand All @@ -513,7 +516,6 @@
group_object_query = self.enc_output_norm[group_id](group_object_query)

group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)

group_enc_outputs_class_masked = group_enc_outputs_class.masked_fill(invalid_mask, float("-inf"))

group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
Expand All @@ -526,22 +528,25 @@
1,
group_topk_proposals.unsqueeze(-1).repeat(1, 1, 6),
)
# the auxiliary loss supervises only the selected proposals,
# so gather the matching class logits as well
group_topk_logits_undetach = torch.gather(
group_enc_outputs_class,
1,
group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.num_classes),
)
all_group_enc_logits.append(group_topk_logits_undetach)
all_group_enc_coords.append(group_topk_coords_logits_undetach)

# the decoder consumes detached proposals as initial reference points
topk_coords_logits_list.append(group_topk_coords_logits_undetach.detach())

topk_coords_logits = torch.cat(topk_coords_logits_list, 1)
reference_points = refine_obb_boxes(topk_coords_logits, reference_points)
return topk_coords_logits_list, all_group_enc_logits, all_group_enc_coords

def _decode_and_predict(
self,
tgt: torch.Tensor,
reference_points: torch.Tensor,
source_flatten: torch.Tensor,
mask_flatten: torch.Tensor,
spatial_shapes_list: list[tuple[int, int]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
last_hidden_states, intermediate, intermediate_reference_points = self.decoder(
inputs_embeds=tgt,
reference_points=reference_points,
Expand All @@ -554,62 +559,124 @@
pred_boxes_delta = self.bbox_embed(last_hidden_states)
pred_boxes = refine_obb_boxes(intermediate_reference_points[-1], pred_boxes_delta)

out: dict[str, Any] = {}
return logits, pred_boxes, intermediate, intermediate_reference_points

if self.exportable:
out["logits"] = logits
out["pred_boxes"] = pred_boxes
return out
def _prepare_outputs(
self,
logits: torch.Tensor,
pred_boxes: torch.Tensor,
return_model_output: bool,
target: list[dict[str, np.ndarray]] | None,
return_preds: bool,
) -> dict[str, Any]:
out: dict[str, Any] = {}

if return_model_output or target is None or return_preds:
out["logits"] = logits

if target is None or return_preds:
# Disable for torch.compile compatibility

@torch.compiler.disable
def _postprocess(logits, boxes):
return self.postprocessor(logits, boxes)

out["preds"] = _postprocess(logits.detach().cpu().numpy(), pred_boxes.detach().cpu().numpy())

return out

def _compute_losses(
    self,
    logits: torch.Tensor,
    pred_boxes: torch.Tensor,
    target: list[dict[str, np.ndarray]],
    input: torch.Tensor,
    group_detr: int,
    intermediate: torch.Tensor,
    intermediate_reference_points: list[torch.Tensor],
    all_group_enc_logits: list[torch.Tensor],
    all_group_enc_coords: list[torch.Tensor],
) -> torch.Tensor:
    """Compute the total training loss: main loss plus decoder and encoder auxiliary losses.

    Every term is the sum of ``self.compute_loss`` over the query groups divided by
    ``group_detr``.

    Args:
        logits: class logits of the final decoder layer, split into groups along dim 1
        pred_boxes: boxes of the final decoder layer, split into groups along dim 1
        target: raw targets, converted with ``self.build_target``
        input: model input; only its last two dimensions are read, to derive the box scale
        group_detr: number of query groups the predictions are split into
        intermediate: decoder outputs indexed by layer on the first axis
        intermediate_reference_points: reference points indexed by decoder layer
        all_group_enc_logits: per-group class logits of the selected encoder proposals
        all_group_enc_coords: per-group boxes of the selected encoder proposals

    Returns:
        the accumulated loss
    """
    processed_targets = self.build_target(target, self.class_names)

    # ProbIoU is computed in pixel coordinates
    box_scale = float(max(input.shape[-2], input.shape[-1]))

    def _group_mean(group_logits, group_boxes):
        # Group DETR: each group is matched independently, then the sum is averaged
        total: float | torch.Tensor = 0.0
        for g_logits, g_boxes in zip(group_logits, group_boxes):
            total = total + self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale)
        return total / group_detr

    # Main loss from final decoder layer
    loss = _group_mean(logits.chunk(group_detr, dim=1), pred_boxes.chunk(group_detr, dim=1))

    # Auxiliary losses from intermediate decoder layers (the last entry is skipped)
    # (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i)
    for i in range(intermediate.shape[0] - 1):
        aux_logits = self.class_embed(intermediate[i])
        aux_boxes_delta = self.bbox_embed(intermediate[i])
        aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta)
        loss = loss + _group_mean(aux_logits.chunk(group_detr, dim=1), aux_boxes.chunk(group_detr, dim=1))

    # Auxiliary losses for the selected encoder proposals
    loss = loss + _group_mean(all_group_enc_logits, all_group_enc_coords)

    return loss

def forward(
self,
input: torch.Tensor,
masks: torch.Tensor | None = None,
target: list[dict[str, np.ndarray]] | None = None,
return_model_output: bool = False,
return_preds: bool = False,
**kwargs: Any,
) -> dict[str, Any]:
sources, feats_masks = self._extract_features(input, masks)
reference_points, query_feat = self._setup_queries()

batch_size = sources[0].shape[0]
tgt = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)

source_flatten, mask_flatten, spatial_shapes_list, object_query_embedding, output_proposals, invalid_mask = (
self._prepare_encoder_inputs(sources, feats_masks)
)

topk_coords_logits_list, all_group_enc_logits, all_group_enc_coords = self._encoder_group_predictions(
object_query_embedding, output_proposals, invalid_mask
)

topk_coords_logits = torch.cat(topk_coords_logits_list, 1)
reference_points = refine_obb_boxes(topk_coords_logits, reference_points)

logits, pred_boxes, intermediate, intermediate_reference_points = self._decode_and_predict(
tgt, reference_points, source_flatten, mask_flatten, spatial_shapes_list
)

if self.exportable:
return {"logits": logits, "pred_boxes": pred_boxes}

out = self._prepare_outputs(logits, pred_boxes, return_model_output, target, return_preds)

if target is not None:
# Build target
processed_targets = self.build_target(target, self.class_names)

# ProbIoU is computed in pixel coordinates
box_scale = float(max(input.shape[-2], input.shape[-1]))

# Main loss from final decoder layer (group DETR: each group is matched independently)
split_logits = logits.chunk(group_detr, dim=1)
split_boxes = pred_boxes.chunk(group_detr, dim=1)

main_loss: float | torch.Tensor = 0.0
for g_logits, g_boxes in zip(split_logits, split_boxes):
main_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale)
loss = main_loss / group_detr

# Auxiliary losses from intermediate decoder layers
# (`intermediate_reference_points[i]` is the reference INPUT to decoder layer i)
for i in range(intermediate.shape[0] - 1):
aux_logits = self.class_embed(intermediate[i])
aux_boxes_delta = self.bbox_embed(intermediate[i])
aux_boxes = refine_obb_boxes(intermediate_reference_points[i], aux_boxes_delta)

split_aux_logits = aux_logits.chunk(group_detr, dim=1)
split_aux_boxes = aux_boxes.chunk(group_detr, dim=1)

aux_loss: float | torch.Tensor = 0.0
for g_logits, g_boxes in zip(split_aux_logits, split_aux_boxes):
aux_loss += self.compute_loss(g_logits, g_boxes, processed_targets, box_scale=box_scale)
loss += aux_loss / group_detr

# Auxiliary losses for the selected encoder proposals
enc_loss: float | torch.Tensor = 0.0
for group_logits, group_coords in zip(all_group_enc_logits, all_group_enc_coords):
enc_loss += self.compute_loss(group_logits, group_coords, processed_targets, box_scale=box_scale)
loss += enc_loss / group_detr

out["loss"] = loss
group_detr = self.group_detr if self.training else 1
out["loss"] = self._compute_losses(
logits,
pred_boxes,
target,
input,
group_detr,
intermediate,
intermediate_reference_points,
all_group_enc_logits,
all_group_enc_coords,
)

return out

Expand Down
Loading