@@ -113,7 +113,11 @@ def get_parameters_accepting_batches(cls) -> List[str]:
113113
114114 @classmethod
115115 def get_parameters_accepting_batches_and_scalars (cls ) -> List [str ]:
116- return ["perspective_polygons" ]
116+ return [
117+ "perspective_polygons" ,
118+ "transformed_rect_width" ,
119+ "transformed_rect_height" ,
120+ ]
117121
118122 @classmethod
119123 def describe_outputs (cls ) -> List [OutputDefinition ]:
@@ -688,8 +692,8 @@ def run(
688692 List [List [List [int ]]],
689693 List [List [List [List [int ]]]],
690694 ],
691- transformed_rect_width : int ,
692- transformed_rect_height : int ,
695+ transformed_rect_width : Union [ int , List [ int ], np . ndarray ] ,
696+ transformed_rect_height : Union [ int , List [ int ], np . ndarray ] ,
693697 extend_perspective_polygon_by_detections_anchor : Union [
694698 sv .Position , Literal [ALL_POSITIONS ]
695699 ],
@@ -723,23 +727,36 @@ def run(
723727 raise ValueError (
724728 f"Predictions batch size ({ batch_size } ) does not match number of perspective polygons ({ largest_perspective_polygons } )"
725729 )
726- for polygon , detections in zip (largest_perspective_polygons , predictions ):
730+ if isinstance (transformed_rect_height , int ):
731+ transformed_rect_height = [transformed_rect_height ] * batch_size
732+ if isinstance (transformed_rect_width , int ):
733+ transformed_rect_width = [transformed_rect_width ] * batch_size
734+ for polygon , detections , width , height in zip (
735+ largest_perspective_polygons ,
736+ predictions ,
737+ list (transformed_rect_width ),
738+ list (transformed_rect_height ),
739+ ):
727740 if polygon is None :
728741 self .perspective_transformers .append (None )
729742 continue
730743 self .perspective_transformers .append (
731744 generate_transformation_matrix (
732745 src_polygon = polygon ,
733746 detections = detections ,
734- transformed_rect_width = transformed_rect_width ,
735- transformed_rect_height = transformed_rect_height ,
747+ transformed_rect_width = width ,
748+ transformed_rect_height = height ,
736749 detections_anchor = extend_perspective_polygon_by_detections_anchor ,
737750 )
738751 )
739752
740753 result = []
741- for detections , perspective_transformer_w_h , image in zip (
742- predictions , self .perspective_transformers , images
754+ for detections , perspective_transformer_w_h , image , width , height in zip (
755+ predictions ,
756+ self .perspective_transformers ,
757+ images ,
758+ transformed_rect_width ,
759+ transformed_rect_height ,
743760 ):
744761 perspective_transformer , extended_width , extended_height = (
745762 perspective_transformer_w_h
@@ -751,8 +768,8 @@ def run(
751768 src = image .numpy_image ,
752769 M = perspective_transformer ,
753770 dsize = (
754- transformed_rect_width + int (round (extended_width )),
755- transformed_rect_height + int (round (extended_height )),
771+ int ( round ( width )) + int (round (extended_width )),
772+ int ( round ( height )) + int (round (extended_height )),
756773 ),
757774 )
758775 result_image = WorkflowImageData .copy_and_replace (
@@ -765,9 +782,9 @@ def run(
765782 {
766783 OUTPUT_DETECTIONS_KEY : None ,
767784 OUTPUT_IMAGE_KEY : result_image ,
768- OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY : transformed_rect_width
785+ OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY : width
769786 + int (round (extended_width )),
770- OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY : transformed_rect_height
787+ OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY : height
771788 + int (round (extended_height )),
772789 }
773790 )
@@ -776,19 +793,17 @@ def run(
776793 corrected_detections = correct_detections (
777794 detections = detections ,
778795 perspective_transformer = perspective_transformer ,
779- transformed_rect_width = transformed_rect_width
780- + int (round (extended_width )),
781- transformed_rect_height = transformed_rect_height
782- + int (round (extended_height )),
796+ transformed_rect_width = width + int (round (extended_width )),
797+ transformed_rect_height = height + int (round (extended_height )),
783798 )
784799
785800 result .append (
786801 {
787802 OUTPUT_DETECTIONS_KEY : corrected_detections ,
788803 OUTPUT_IMAGE_KEY : result_image ,
789- OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY : transformed_rect_width
804+ OUTPUT_EXTENDED_TRANSFORMED_RECT_WIDTH_KEY : width
790805 + int (round (extended_width )),
791- OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY : transformed_rect_height
806+ OUTPUT_EXTENDED_TRANSFORMED_RECT_HEIGHT_KEY : height
792807 + int (round (extended_height )),
793808 }
794809 )
0 commit comments