Skip to content

Commit 850491e

Browse files
authored
Support single color in utils.draw_bounding_boxes (#4075)
1 parent eb1b982 commit 850491e

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

test/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ def test_draw_boxes():
113113
assert_equal(img, img_cp)
114114

115115

116+
@pytest.mark.parametrize('colors', [
117+
None,
118+
['red', 'blue', '#FF00FF', (1, 34, 122)],
119+
'red',
120+
'#FF00FF',
121+
(1, 34, 122)
122+
])
123+
def test_draw_boxes_colors(colors):
124+
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
125+
utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors)
126+
127+
116128
def test_draw_boxes_vanilla():
117129
img = torch.full((3, 100, 100), 0, dtype=torch.uint8)
118130
img_cp = img.clone()

torchvision/utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def draw_bounding_boxes(
141141
image: torch.Tensor,
142142
boxes: torch.Tensor,
143143
labels: Optional[List[str]] = None,
144-
colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None,
144+
colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
145145
fill: Optional[bool] = False,
146146
width: int = 1,
147147
font: Optional[str] = None,
@@ -159,8 +159,9 @@ def draw_bounding_boxes(
159159
the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and
160160
`0 <= ymin < ymax < H`.
161161
labels (List[str]): List containing the labels of bounding boxes.
162-
colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes. The colors can
163-
be represented as `str` or `Tuple[int, int, int]`.
162+
colors (Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]): List containing the colors
163+
or a single color for all of the bounding boxes. The colors can be represented as `str` or
164+
`Tuple[int, int, int]`.
164165
fill (bool): If `True` fills the bounding box with specified color.
165166
width (int): Width of bounding box.
166167
font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may
@@ -200,8 +201,10 @@ def draw_bounding_boxes(
200201
for i, bbox in enumerate(img_boxes):
201202
if colors is None:
202203
color = None
203-
else:
204+
elif isinstance(colors, list):
204205
color = colors[i]
206+
else:
207+
color = colors
205208

206209
if fill:
207210
if color is None:

0 commit comments

Comments
 (0)