1+ use rayon:: iter:: ParallelIterator ;
12use std:: ops:: RangeTo ;
23use std:: path:: { Path , PathBuf } ;
4+ use std:: sync:: { Arc , Mutex } ;
35use std:: time:: Instant ;
46use image:: { DynamicImage , GenericImage , ImageBuffer , Rgba , RgbaImage } ;
57use image:: imageops:: FilterType ;
@@ -16,6 +18,7 @@ use crate::generator::coco::{BoundingBox, CocoCategoryInfo, CocoGenerator};
1618use crate :: generator:: config:: TargetGeneratorConfig ;
1719use crate :: objects:: { ObjectManager } ;
1820use moka:: sync:: { Cache , CacheBuilder } ;
21+ use rayon:: iter:: { IntoParallelIterator , ParallelBridge } ;
1922
2023pub mod coco;
2124pub mod error;
@@ -29,9 +32,9 @@ pub struct TargetGenerator {
2932 backgrounds_path : PathBuf ,
3033 pub object_manager : ObjectManager ,
3134 background_loader : BackgroundLoader ,
32- coco_generator : CocoGenerator ,
35+ coco_generator : Arc < Mutex < CocoGenerator > > ,
3336 config : TargetGeneratorConfig ,
34- resized_cache : Cache < u32 , DynamicImage > ,
37+ resized_cache : Cache < String , DynamicImage > ,
3538}
3639
3740impl TargetGenerator {
@@ -43,7 +46,7 @@ impl TargetGenerator {
4346 let categories = object_manager. categories ( ) ;
4447 let config = TargetGeneratorConfig :: default ( ) ;
4548
46- let resized_cache: Cache < u32 , DynamicImage > = CacheBuilder :: new ( config. cache_size as u64 * 1024 * 1024 )
49+ let resized_cache: Cache < String , DynamicImage > = CacheBuilder :: new ( config. cache_size as u64 * 1024 * 1024 )
4750 . weigher ( |_key, value : & DynamicImage | -> u32 { // evict based on size in MBs
4851 value. as_bytes ( ) . len ( ) as u32
4952 } )
@@ -54,13 +57,13 @@ impl TargetGenerator {
5457 backgrounds_path : background_path. as_ref ( ) . to_path_buf ( ) ,
5558 object_manager,
5659 background_loader : BackgroundLoader :: new ( background_path) ?,
57- coco_generator : CocoGenerator :: new ( annotations_path, categories) ,
60+ coco_generator : Arc :: new ( Mutex :: new ( CocoGenerator :: new ( annotations_path, categories) ) ) ,
5861 config,
5962 resized_cache,
6063 } )
6164 }
6265
63- pub fn generate_target ( & mut self , pixels_per_meter : f32 , number_of_objects : u16 ) -> Result < RgbaImage , GenerationError > {
66+ pub fn generate_target ( & self , pixels_per_meter : f32 , number_of_objects : u16 ) -> Result < RgbaImage , GenerationError > {
6467 trace ! ( "Beginning to generate a target..." ) ;
6568
6669 if number_of_objects == 0 {
@@ -74,7 +77,7 @@ impl TargetGenerator {
7477 let mut placed_objects = vec ! [ ] ;
7578
7679 // add background image to coco here
77- let background_id = self . coco_generator . add_image ( w, h, background. filename . clone ( ) , background. date_captured . clone ( ) ) ;
80+ let background_id = self . coco_generator . lock ( ) . unwrap ( ) . add_image ( w, h, background. filename . clone ( ) , background. date_captured . clone ( ) ) ;
7881
7982 for obj in set {
8083 let clone = & obj. dynamic_image . clone ( ) ;
@@ -86,11 +89,11 @@ impl TargetGenerator {
8689 debug ! ( "Resizing object to {}x{}" , obj_w, obj_h) ;
8790
8891 // overlay respects transparent pixels unlike copy_from
89- let resized = if let Some ( resized) = self . resized_cache . get ( & obj_w) {
92+ let resized = if let Some ( resized) = self . resized_cache . get ( & format ! ( "{}x{}_{}" , obj_w, obj_h , obj . object_class ) ) {
9093 resized. clone ( )
9194 } else {
9295 let resized = clone. resize ( obj_w, obj_h, FilterType :: Gaussian ) ;
93- self . resized_cache . insert ( obj_w, resized. clone ( ) ) ;
96+ self . resized_cache . insert ( format ! ( "{}x{}_{}" , obj_w, obj_h , obj . object_class ) , resized. clone ( ) ) ;
9497 resized
9598 } ;
9699
@@ -112,17 +115,28 @@ impl TargetGenerator {
112115 } ;
113116
114117 // add annotation to coco here
115- self . coco_generator . add_annotation ( background_id, obj. object_class as u32 , 0 , vec ! [ ] , ( obj_w * obj_h) as f64 , bbox) ;
118+ self . coco_generator . lock ( ) . unwrap ( ) . add_annotation ( background_id, obj. object_class , 0 , vec ! [ ] , ( obj_w * obj_h) as f64 , bbox) ;
116119
117120 placed_objects. push ( bbox) ;
118121 }
119122
120123 Ok ( image)
121124 }
122125
123- pub fn generate_targets < A : AsRef < Path > > ( & self , amount : u32 , range_to : RangeTo < u32 > , path : A ) -> Result < ( ) , GenerationError > {
126+ pub fn generate_targets < A : AsRef < Path > + Sync > ( & mut self , amount : u32 , range_to : RangeTo < u32 > , path : A ) -> Result < ( ) , GenerationError > {
124127 let start = Instant :: now ( ) ; // start timer
128+ debug ! ( "Generating {} targets..." , amount) ;
125129
130+ let threadpool = rayon:: ThreadPoolBuilder :: new ( ) . num_threads ( self . config . worker_threads as usize ) . build ( ) . unwrap ( ) ;
131+
132+ threadpool. install ( || {
133+ ( 0 ..amount) . into_par_iter ( ) . for_each ( |i| {
134+ let b = self . generate_target ( STANDARD_PPM , thread_rng ( ) . gen_range ( 1 ..range_to. end ) as u16 ) . unwrap ( ) ;
135+ let path = path. as_ref ( ) . join ( format ! ( "{}.png" , i) ) ;
136+ b. save ( path. clone ( ) ) . unwrap ( ) ;
137+ debug ! ( "Saved generated target to {}" , path. display( ) ) ;
138+ } ) ;
139+ } ) ;
126140
127141
128142 debug ! ( "Generation completed, generated {} in average {}ms" , amount, start. elapsed( ) . as_millis( ) / amount as u128 ) ;
@@ -161,7 +175,7 @@ impl TargetGenerator {
161175 }
162176
163177 pub fn close ( & self ) {
164- self . coco_generator . save ( ) ;
178+ self . coco_generator . lock ( ) . unwrap ( ) . save ( ) ;
165179 }
166180}
167181
@@ -178,5 +192,19 @@ pub fn test_generate_target() {
178192 b. save ( "output_1.png" . to_string ( ) ) . unwrap ( ) ;
179193 debug ! ( "Saved generated target to output_1.png" ) ;
180194
195+ tg. close ( ) ;
196+ }
197+
198+ #[ test]
199+ #[ ignore]
200+ pub fn test_generate_targets ( ) {
201+ SimpleLogger :: new ( ) . init ( ) . unwrap ( ) ;
202+
203+ let mut tg = TargetGenerator :: new ( "output" , "backgrounds" , "objects" , "output/annotations.json" ) . unwrap ( ) ;
204+ tg. config . permit_duplicates = true ;
205+ tg. config . permit_collisions = false ;
206+ tg. config . visualize_bboxes = true ;
207+ tg. generate_targets ( 500 , ..6u32 , "output" ) . unwrap ( ) ;
208+
181209 tg. close ( ) ;
182210}
0 commit comments