diff --git a/mask2former/modeling/matcher.py b/mask2former/modeling/matcher.py index 7c6af7f8..da811b41 100644 --- a/mask2former/modeling/matcher.py +++ b/mask2former/modeling/matcher.py @@ -134,11 +134,16 @@ def memory_efficient_forward(self, outputs, targets): with autocast(enabled=False): out_mask = out_mask.float() tgt_mask = tgt_mask.float() - # Compute the focal loss between masks - cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) - # Compute the dice loss betwen masks - cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) + is_annotation_empty = out_mask.shape[0] == 0 or tgt_mask.shape[0] == 0 + if is_annotation_empty: + # Compute the focal loss between masks + cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask) + # Compute the dice loss between masks + cost_dice = batch_dice_loss(out_mask, tgt_mask) + else: + cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) + cost_dice = batch_dice_loss_jit(out_mask, tgt_mask) # Final cost matrix C = (