Commit b16f4fc

Merge pull request #1674 from roboflow/Update_Perspective_Correction_For_Dimensionality
Update perspective correction for dimensionality
2 parents 36b1cd2 + 2ace3e0 commit b16f4fc

File tree

2 files changed (+96, -18 lines)

inference/core/workflows/core_steps/transformations/perspective_correction/v1.py

Lines changed: 33 additions & 18 deletions
@@ -113,7 +113,11 @@ def get_parameters_accepting_batches(cls) -> List[str]:
 
     @classmethod
     def get_parameters_accepting_batches_and_scalars(cls) -> List[str]:
-        return ["perspective_polygons"]
+        return [
+            "perspective_polygons",
+            "transformed_rect_width",
+            "transformed_rect_height",
+        ]
 
     @classmethod
     def describe_outputs(cls) -> List[OutputDefinition]:
@@ -688,8 +692,8 @@ def run(
             List[List[List[int]]],
             List[List[List[List[int]]]],
         ],
-        transformed_rect_width: int,
-        transformed_rect_height: int,
+        transformed_rect_width: Union[int, List[int], np.ndarray],
+        transformed_rect_height: Union[int, List[int], np.ndarray],
         extend_perspective_polygon_by_detections_anchor: Union[
             sv.Position, Literal[ALL_POSITIONS]
         ],
@@ -723,23 +727,36 @@ def run(
                 raise ValueError(
                     f"Predictions batch size ({batch_size}) does not match number of perspective polygons ({largest_perspective_polygons})"
                 )
-            for polygon, detections in zip(largest_perspective_polygons, predictions):
+            if isinstance(transformed_rect_height, int):
+                transformed_rect_height = [transformed_rect_height] * batch_size
+            if isinstance(transformed_rect_width, int):
+                transformed_rect_width = [transformed_rect_width] * batch_size
+            for polygon, detections, width, height in zip(
+                largest_perspective_polygons,
+                predictions,
+                list(transformed_rect_width),
+                list(transformed_rect_height),
+            ):
                 if polygon is None:
                     self.perspective_transformers.append(None)
                     continue
                 self.perspective_transformers.append(
                     generate_transformation_matrix(
                         src_polygon=polygon,
                         detections=detections,
-                        transformed_rect_width=transformed_rect_width,
-                        transformed_rect_height=transformed_rect_height,
+                        transformed_rect_width=width,
+                        transformed_rect_height=height,
                         detections_anchor=extend_perspective_polygon_by_detections_anchor,
                     )
                 )
 
         result = []
-        for detections, perspective_transformer_w_h, image in zip(
-            predictions, self.perspective_transformers, images
+        for detections, perspective_transformer_w_h, image, width, height in zip(
+            predictions,
+            self.perspective_transformers,
+            images,
+            transformed_rect_width,
+            transformed_rect_height,
         ):
             perspective_transformer, extended_width, extended_height = (
                 perspective_transformer_w_h
@@ -751,8 +768,8 @@ def run(
                     src=image.numpy_image,
                     M=perspective_transformer,
                     dsize=(
-                        transformed_rect_width + int(round(extended_width)),
-                        transformed_rect_height + int(round(extended_height)),
+                        int(round(width)) + int(round(extended_width)),
+                        int(round(height)) + int(round(extended_height)),
                     ),
                 )
                 result_image = WorkflowImageData.copy_and_replace(
@@ -765,9 +782,9 @@ def run(
                     {
                         OUTPUT_DETECTIONS_KEY: None,
                         OUTPUT_IMAGE_KEY: result_image,
-                        OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY: transformed_rect_width
+                        OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY: width
                         + int(round(extended_width)),
-                        OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY: transformed_rect_height
+                        OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY: height
                         + int(round(extended_height)),
                     }
                 )
@@ -776,19 +793,17 @@ def run(
             corrected_detections = correct_detections(
                 detections=detections,
                 perspective_transformer=perspective_transformer,
-                transformed_rect_width=transformed_rect_width
-                + int(round(extended_width)),
-                transformed_rect_height=transformed_rect_height
-                + int(round(extended_height)),
+                transformed_rect_width=width + int(round(extended_width)),
+                transformed_rect_height=height + int(round(extended_height)),
             )
 
             result.append(
                 {
                     OUTPUT_DETECTIONS_KEY: corrected_detections,
                     OUTPUT_IMAGE_KEY: result_image,
-                    OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY: transformed_rect_width
+                    OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY: width
                     + int(round(extended_width)),
-                    OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY: transformed_rect_height
+                    OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY: height
                     + int(round(extended_height)),
                 }
            )
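
For reference, a minimal standalone sketch (not part of the diff) of the broadcasting rule this change introduces: a scalar int passed as transformed_rect_width or transformed_rect_height is repeated once per batch element before being zipped with the predictions, while lists and arrays are used as-is. The helper name normalize_rect_dimension is hypothetical and only illustrates the logic above.

# Hypothetical sketch; the helper name is illustrative, not taken from the module.
from typing import List, Union

import numpy as np


def normalize_rect_dimension(
    value: Union[int, List[int], np.ndarray], batch_size: int
) -> List[int]:
    # Scalars are broadcast to one value per batch element;
    # per-item lists/arrays are converted to a plain list and used as-is.
    if isinstance(value, int):
        return [value] * batch_size
    return list(value)


# Example: a scalar width applies to every image, heights vary per image.
widths = normalize_rect_dimension(1000, batch_size=3)   # [1000, 1000, 1000]
heights = normalize_rect_dimension([600, 700, 800], 3)  # [600, 700, 800]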

tests/workflows/unit_tests/core_steps/transformations/test_perspective_correction.py

Lines changed: 63 additions & 0 deletions
@@ -413,3 +413,66 @@ def test_warp_image():
     assert isinstance(
         result[0]["warped_image"], WorkflowImageData
     ), f"warped_image must be of type WorkflowImageData"
+
+
+def test_warp_image_batch_dims():
+    # given
+    dummy_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    dummy_predictions = sv.Detections(xyxy=np.array([[10, 10, 20, 20]]))
+    perspective_correction_block = PerspectiveCorrectionBlockV1()
+
+    workflow_image_data = WorkflowImageData(
+        parent_metadata=ImageParentMetadata(parent_id="test"), numpy_image=dummy_image
+    )
+
+    # when
+    result = perspective_correction_block.run(
+        images=[workflow_image_data],
+        predictions=[dummy_predictions],
+        perspective_polygons=[[[1, 1], [99, 1], [99, 99], [1, 99]]],
+        transformed_rect_width=[200],
+        transformed_rect_height=[200],
+        extend_perspective_polygon_by_detections_anchor=None,
+        warp_image=True,
+    )
+
+    # then
+    assert "warped_image" in result[0], "warped_image key must be present in the result"
+    assert isinstance(
+        result[0]["warped_image"], WorkflowImageData
+    ), f"warped_image must be of type WorkflowImageData"
+
+
+def test_batch_input():
+    # given
+    batch_size = 3
+    dummy_images = [
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    ] * batch_size
+    dummy_predictions = [sv.Detections(xyxy=np.array([[10, 10, 20, 20]]))] * batch_size
+    perspective_correction_block = PerspectiveCorrectionBlockV1()
+
+    workflow_image_data = [
+        WorkflowImageData(
+            parent_metadata=ImageParentMetadata(parent_id="test"),
+            numpy_image=dummy_image,
+        )
+        for dummy_image in dummy_images
+    ]
+
+    # when
+    result = perspective_correction_block.run(
+        images=workflow_image_data,
+        predictions=dummy_predictions,
+        perspective_polygons=[[[1, 1], [99, 1], [99, 99], [1, 99]]],
+        transformed_rect_width=[200] * batch_size,
+        transformed_rect_height=[200] * batch_size,
+        extend_perspective_polygon_by_detections_anchor=None,
+        warp_image=True,
+    )
+
+    # then
+    assert "warped_image" in result[0], "warped_image key must be present in the result"
+    assert isinstance(
+        result[0]["warped_image"], WorkflowImageData
+    ), f"warped_image must be of type WorkflowImageData"
