Skip to content

Commit 01b44cb

Browse files
committed
Add common utils to BoundingBox class
1 parent 17ded38 commit 01b44cb

File tree

2 files changed

+166
-1
lines changed

2 files changed

+166
-1
lines changed

open_image_models/detection/core/base.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Protocol
2+
from typing import Any, Optional, Protocol
33

44
import numpy as np
55

@@ -19,6 +19,64 @@ class BoundingBox:
1919
y2: int
2020
"""Y-coordinate of the bottom-right corner"""
2121

22+
@property
23+
def width(self) -> int:
24+
"""Returns the width of the bounding box."""
25+
return self.x2 - self.x1
26+
27+
@property
28+
def height(self) -> int:
29+
"""Returns the height of the bounding box."""
30+
return self.y2 - self.y1
31+
32+
@property
33+
def area(self) -> int:
34+
"""Returns the area of the bounding box."""
35+
return self.width * self.height
36+
37+
@property
38+
def center(self) -> tuple[float, float]:
39+
"""
40+
Returns the (x, y) coordinates of the center of the bounding box.
41+
"""
42+
cx = (self.x1 + self.x2) / 2.0
43+
cy = (self.y1 + self.y2) / 2.0
44+
45+
return cx, cy
46+
47+
def intersection(self, other: "BoundingBox") -> Optional["BoundingBox"]:
48+
"""
49+
Returns the intersection of this bounding box with another bounding box. If they do not intersect, returns None.
50+
"""
51+
x1 = max(self.x1, other.x1)
52+
y1 = max(self.y1, other.y1)
53+
x2 = min(self.x2, other.x2)
54+
y2 = min(self.y2, other.y2)
55+
56+
if x2 > x1 and y2 > y1:
57+
return BoundingBox(x1, y1, x2, y2)
58+
59+
return None
60+
61+
def iou(self, other: "BoundingBox") -> float:
62+
"""
63+
Computes the Intersection-over-Union (IoU) between this bounding box and another bounding box.
64+
"""
65+
inter = self.intersection(other)
66+
67+
if inter is None:
68+
return 0.0
69+
70+
inter_area = inter.area
71+
union_area = self.area + other.area - inter_area
72+
return inter_area / union_area if union_area > 0 else 0.0
73+
74+
def to_xywh(self) -> tuple[int, int, int, int]:
75+
"""
76+
Converts bounding box to (x, y, width, height) format, where (x, y) is the top-left corner.
77+
"""
78+
return self.x1, self.y1, self.width, self.height
79+
2280

2381
@dataclass(frozen=True)
2482
class DetectionResult:
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from math import isclose
2+
3+
import pytest
4+
5+
from open_image_models.detection.core.base import BoundingBox
6+
7+
# pylint: disable=too-many-positional-arguments
8+
9+
10+
@pytest.mark.parametrize(
11+
"x1, y1, x2, y2, expected_width, expected_height, expected_area, expected_center",
12+
[
13+
(0, 0, 10, 10, 10, 10, 100, (5.0, 5.0)),
14+
(2, 3, 5, 7, 3, 4, 12, (3.5, 5.0)),
15+
(10, 10, 15, 13, 5, 3, 15, (12.5, 11.5)),
16+
],
17+
)
18+
def test_bounding_box_properties(x1, y1, x2, y2, expected_width, expected_height, expected_area, expected_center):
19+
bbox = BoundingBox(x1, y1, x2, y2)
20+
assert bbox.width == expected_width
21+
assert bbox.height == expected_height
22+
assert bbox.area == expected_area
23+
24+
actual_center = bbox.center
25+
assert isclose(actual_center[0], expected_center[0], rel_tol=1e-6)
26+
assert isclose(actual_center[1], expected_center[1], rel_tol=1e-6)
27+
28+
29+
@pytest.mark.parametrize(
30+
"bbox, expected_xywh",
31+
[
32+
(BoundingBox(0, 0, 10, 10), (0, 0, 10, 10)),
33+
(BoundingBox(2, 3, 5, 7), (2, 3, 3, 4)),
34+
(BoundingBox(10, 10, 15, 13), (10, 10, 5, 3)),
35+
],
36+
)
37+
def test_to_xywh(bbox, expected_xywh):
38+
xywh = bbox.to_xywh()
39+
assert xywh == expected_xywh
40+
41+
42+
@pytest.mark.parametrize(
43+
"bbox1, bbox2, expected_intersection",
44+
[
45+
# Overlapping case
46+
(
47+
BoundingBox(0, 0, 10, 10),
48+
BoundingBox(5, 5, 15, 15),
49+
BoundingBox(5, 5, 10, 10),
50+
),
51+
# One box completely inside another
52+
(
53+
BoundingBox(0, 0, 10, 10),
54+
BoundingBox(2, 2, 5, 5),
55+
BoundingBox(2, 2, 5, 5),
56+
),
57+
# Touching edges (should return None if there's no positive overlap)
58+
(
59+
BoundingBox(0, 0, 10, 10),
60+
BoundingBox(10, 10, 12, 12),
61+
None,
62+
),
63+
# No overlap at all
64+
(
65+
BoundingBox(0, 0, 5, 5),
66+
BoundingBox(6, 6, 10, 10),
67+
None,
68+
),
69+
],
70+
)
71+
def test_intersection(bbox1, bbox2, expected_intersection):
72+
inter = bbox1.intersection(bbox2)
73+
assert inter == expected_intersection
74+
75+
76+
@pytest.mark.parametrize(
77+
"bbox1, bbox2, expected_iou",
78+
[
79+
# Same box, IoU should equal 1.0
80+
(
81+
BoundingBox(0, 0, 10, 10),
82+
BoundingBox(0, 0, 10, 10),
83+
1.0,
84+
),
85+
# Partial overlap
86+
(
87+
BoundingBox(0, 0, 10, 10),
88+
BoundingBox(5, 5, 15, 15),
89+
25 / 175, # intersection=25, union=175, so IoU should equal 25/175
90+
),
91+
# No overlap, IoU should equal 0.0
92+
(
93+
BoundingBox(0, 0, 5, 5),
94+
BoundingBox(6, 6, 10, 10),
95+
0.0,
96+
),
97+
# One box inside another, ratio of areas
98+
(
99+
BoundingBox(0, 0, 10, 10),
100+
BoundingBox(2, 2, 8, 8),
101+
36 / 100,
102+
),
103+
],
104+
)
105+
def test_iou(bbox1, bbox2, expected_iou):
106+
iou_value = bbox1.iou(bbox2)
107+
assert isclose(iou_value, expected_iou, rel_tol=1e-5)

0 commit comments

Comments
 (0)