99using NeuralNetworkNET . Extensions ;
1010using NeuralNetworkNET . Helpers ;
1111using NeuralNetworkNET . SupervisedLearning . Progress ;
12+ using SixLabors . ImageSharp ;
13+ using SixLabors . ImageSharp . Advanced ;
14+ using SixLabors . ImageSharp . PixelFormats ;
1215
1316namespace NeuralNetworkNET . APIs . Datasets
1417{
@@ -25,11 +28,14 @@ public static class Cifar10
2528 // 32*32 RGB images
2629 private const int SampleSize = 3072 ;
2730
31+ // The number of pixels in a single 32*32 image (the size of each color channel plane)
32+ private const int ImageSize = 1024 ;
33+
2834 private const String DatasetURL = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz" ;
2935
3036 [ NotNull , ItemNotNull ]
3137 private static readonly IReadOnlyList < String > TrainingBinFilenames = Enumerable . Range ( 1 , 5 ) . Select ( i => $ "data_batch_{ i } .bin") . ToArray ( ) ;
32-
38+
3339 private const String TestBinFilename = "test_batch.bin" ;
3440
3541 #endregion
@@ -38,12 +44,13 @@ public static class Cifar10
3844 /// Downloads the CIFAR-10 training datasets and returns a new <see cref="ITrainingDataset"/> instance
3945 /// </summary>
4046 /// <param name="size">The desired dataset batch size</param>
47+ /// <param name="callback">The optional progress callback for the dataset download</param>
4148 /// <param name="token">An optional cancellation token for the operation</param>
4249 [ PublicAPI ]
4350 [ Pure , ItemCanBeNull ]
44- public static async Task < ITrainingDataset > GetTrainingDatasetAsync ( int size , CancellationToken token = default )
51+ public static async Task < ITrainingDataset > GetTrainingDatasetAsync ( int size , [ CanBeNull ] IProgress < HttpProgress > callback = null , CancellationToken token = default )
4552 {
46- IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , token ) ;
53+ IReadOnlyDictionary < String , Func < Stream > > map = await DatasetsDownloader . GetArchiveAsync ( DatasetURL , callback , token ) ;
4754 if ( map == null ) return null ;
4855 IReadOnlyList < ( float [ ] , float [ ] ) > [ ] data = new IReadOnlyList < ( float [ ] , float [ ] ) > [ TrainingBinFilenames . Count ] ;
4956 Parallel . For ( 0 , TrainingBinFilenames . Count , i => data [ i ] = ParseSamples ( map [ TrainingBinFilenames [ i ] ] , TrainingSamplesInBinFiles ) ) . AssertCompleted ( ) ;
@@ -54,25 +61,45 @@ public static async Task<ITrainingDataset> GetTrainingDatasetAsync(int size, Can
5461 /// Downloads the CIFAR-10 test datasets and returns a new <see cref="ITestDataset"/> instance
5562 /// </summary>
5663 /// <param name="progress">The optional progress callback to use</param>
64+ /// <param name="callback">The optional progress callback for the dataset download</param>
5765 /// <param name="token">An optional cancellation token for the operation</param>
[PublicAPI]
[Pure, ItemCanBeNull]
public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<TrainingProgressEventArgs> progress = null, [CanBeNull] IProgress<HttpProgress> callback = null, CancellationToken token = default)
{
    // Fetch the archive map (file name -> stream factory); a null map is propagated to the caller as a null dataset
    IReadOnlyDictionary<String, Func<Stream>> archive = await DatasetsDownloader.GetArchiveAsync(DatasetURL, callback, token);
    if (archive == null)
    {
        return null;
    }

    // Parse the test batch file and wrap the resulting samples into a test dataset
    Func<Stream> testFile = archive[TestBinFilename];
    IReadOnlyList<(float[], float[])> samples = ParseSamples(testFile, TrainingSamplesInBinFiles);
    return DatasetLoader.Test(samples, progress);
}
6775
/// <summary>
/// Downloads the CIFAR-10 archive and exports the samples from every training batch and from the test batch to the target directory
/// </summary>
/// <param name="directory">The target directory, which is created if it doesn't exist yet</param>
/// <param name="token">The cancellation token for the operation</param>
/// <returns><see langword="true"/> if all the batch files were processed and no cancellation was requested, <see langword="false"/> if no archive was returned or the export was stopped</returns>
[PublicAPI]
public static async Task<bool> ExportDatasetAsync([NotNull] DirectoryInfo directory, CancellationToken token = default)
{
    IReadOnlyDictionary<String, Func<Stream>> archive = await DatasetsDownloader.GetArchiveAsync(DatasetURL, null, token);
    if (archive == null)
    {
        return false;
    }
    if (!directory.Exists)
    {
        directory.Create();
    }

    // The batch files to export: the training batches followed by the test batch
    IEnumerable<String> filenames = TrainingBinFilenames.Concat(new[] { TestBinFilename });
    ParallelLoopResult outcome = Parallel.ForEach(filenames, (filename, loop) =>
    {
        Func<Stream> factory = archive[filename];
        ExportSamples(directory, (filename, factory), TrainingSamplesInBinFiles, token);

        // Ask the other workers to stop as soon as a cancellation is observed
        if (token.IsCancellationRequested)
        {
            loop.Stop();
        }
    });

    if (!outcome.IsCompleted)
    {
        return false;
    }
    return !token.IsCancellationRequested;
}
94+
6895 #region Tools
6996
7097 /// <summary>
7198 /// Parses a CIFAR-10 .bin file
7299 /// </summary>
73100 /// <param name="factory">A <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
74101 /// <param name="count">The number of samples to parse</param>
75- private static unsafe IReadOnlyList < ( float [ ] , float [ ] ) > ParseSamples ( Func < Stream > factory , int count )
102+ private static unsafe IReadOnlyList < ( float [ ] , float [ ] ) > ParseSamples ( [ NotNull ] Func < Stream > factory , int count )
76103 {
77104 using ( Stream stream = factory ( ) )
78105 {
@@ -89,8 +116,12 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
fixed (float* px = x)
{
    stream.Read(temp, 0, SampleSize);

    // The sample buffer holds three consecutive 32*32 color planes, so each plane is
    // normalized into the matching offset of the output vector. Assigning all three
    // values to px[j] would keep only the last plane in the first ImageSize items
    // and leave the remaining items of the vector at zero.
    for (int j = 0; j < ImageSize; j++)
    {
        px[j] = ptemp[j] / 255f; // Normalized samples
        px[j + ImageSize] = ptemp[j + ImageSize] / 255f;
        px[j + 2 * ImageSize] = ptemp[j + 2 * ImageSize] / 255f;
    }
}
95126 data [ i ] = ( x , y ) ;
96127 }
@@ -99,6 +130,38 @@ public static async Task<ITestDataset> GetTestDatasetAsync([CanBeNull] Action<Tr
99130 }
100131 }
101132
/// <summary>
/// Exports the samples in a CIFAR-10 .bin file as individual .bmp images
/// </summary>
/// <param name="folder">The target folder to use to save the images</param>
/// <param name="source">The source filename and a <see cref="Func{TResult}"/> that returns the <see cref="Stream"/> to read</param>
/// <param name="count">The number of samples to parse</param>
/// <param name="token">A token for the operation, checked before each sample is processed</param>
/// <exception cref="EndOfStreamException">The source stream ends before <paramref name="count"/> samples have been read</exception>
private static unsafe void ExportSamples([NotNull] DirectoryInfo folder, (String Name, Func<Stream> Factory) source, int count, CancellationToken token)
{
    using (Stream stream = source.Factory())
    {
        byte[] temp = new byte[SampleSize];
        fixed (byte* ptemp = temp)
        {
            for (int i = 0; i < count; i++)
            {
                if (token.IsCancellationRequested) return;

                // Each sample is a single label byte followed by SampleSize bytes of pixel data
                int label = stream.ReadByte();
                if (label == -1)
                {
                    throw new EndOfStreamException($"Unexpected end of stream while reading the label of sample {i} in {source.Name}");
                }

                // Stream.Read is allowed to return fewer bytes than requested,
                // so keep reading until the whole sample buffer has been filled
                int read = 0;
                while (read < SampleSize)
                {
                    int bytes = stream.Read(temp, read, SampleSize - read);
                    if (bytes == 0)
                    {
                        throw new EndOfStreamException($"Unexpected end of stream while reading the pixels of sample {i} in {source.Name}");
                    }
                    read += bytes;
                }

                using (Image<Rgb24> image = new Image<Rgb24>(32, 32))
                fixed (Rgb24* p0 = &image.DangerousGetPinnableReferenceToPixelBuffer())
                {
                    // The three color planes are stored one after the other: interleave them into RGB pixels
                    for (int j = 0; j < ImageSize; j++)
                    {
                        p0[j] = new Rgb24(ptemp[j], ptemp[j + ImageSize], ptemp[j + 2 * ImageSize]);
                    }

                    // File.Create truncates an existing file, so no stale bytes can be left after the new image
                    using (FileStream file = File.Create(Path.Combine(folder.FullName, $"[{source.Name}][{i}][{label}].bmp")))
                    {
                        image.SaveAsBmp(file);
                    }
                }
            }
        }
    }
}
164+
102165 #endregion
103166 }
104167}
0 commit comments