44using Microsoft . VisualStudio . TestTools . UnitTesting ;
55using NeuralNetworkNET . APIs ;
66using NeuralNetworkNET . APIs . Interfaces . Data ;
7+ using NeuralNetworkNET . APIs . Structs ;
78using NeuralNetworkNET . Extensions ;
89using NeuralNetworkNET . Helpers ;
910using NeuralNetworkNET . SupervisedLearning . Data ;
@@ -44,14 +45,12 @@ public void BatchDivisionTest1()
4445 x = Enumerable . Range ( 0 , 20000 * 784 ) . Select ( _ => ThreadSafeRandom . NextUniform ( 100 ) ) . ToArray ( ) . AsSpan ( ) . AsMatrix ( 20000 , 784 ) ,
4546 y = Enumerable . Range ( 0 , 20000 * 10 ) . Select ( _ => ThreadSafeRandom . NextUniform ( 100 ) ) . ToArray ( ) . AsSpan ( ) . AsMatrix ( 20000 , 10 ) ;
4647 BatchesCollection batches = BatchesCollection . From ( ( x , y ) , 1000 ) ;
47- HashSet < int >
48- set1 = new HashSet < int > ( ) ;
48+ HashSet < int > set1 = new HashSet < int > ( ) ;
4949 for ( int i = 0 ; i < 20000 ; i ++ )
5050 {
5151 set1 . Add ( GetUid ( x , i ) ^ GetUid ( y , i ) ) ;
5252 }
53- HashSet < int >
54- set2 = new HashSet < int > ( ) ;
53+ HashSet < int > set2 = new HashSet < int > ( ) ;
5554 for ( int i = 0 ; i < batches . BatchesCount ; i ++ )
5655 {
5756 int h = batches . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -62,8 +61,7 @@ public void BatchDivisionTest1()
6261 }
6362 Assert . IsTrue ( set1 . OrderBy ( h => h ) . SequenceEqual ( set2 . OrderBy ( h => h ) ) ) ;
6463 batches . CrossShuffle ( ) ;
65- HashSet < int >
66- set3 = new HashSet < int > ( ) ;
64+ HashSet < int > set3 = new HashSet < int > ( ) ;
6765 for ( int i = 0 ; i < batches . BatchesCount ; i ++ )
6866 {
6967 int h = batches . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -83,14 +81,12 @@ public void BatchDivisionTest2()
8381 x = Enumerable . Range ( 0 , 20000 * 784 ) . Select ( _ => ThreadSafeRandom . NextUniform ( 100 ) ) . ToArray ( ) . AsSpan ( ) . AsMatrix ( 20000 , 784 ) ,
8482 y = Enumerable . Range ( 0 , 20000 * 10 ) . Select ( _ => ThreadSafeRandom . NextUniform ( 100 ) ) . ToArray ( ) . AsSpan ( ) . AsMatrix ( 20000 , 10 ) ;
8583 BatchesCollection batches = BatchesCollection . From ( ( x , y ) , 1547 ) ;
86- HashSet < int >
87- set1 = new HashSet < int > ( ) ;
84+ HashSet < int > set1 = new HashSet < int > ( ) ;
8885 for ( int i = 0 ; i < 20000 ; i ++ )
8986 {
9087 set1 . Add ( GetUid ( x , i ) ^ GetUid ( y , i ) ) ;
9188 }
92- HashSet < int >
93- set2 = new HashSet < int > ( ) ;
89+ HashSet < int > set2 = new HashSet < int > ( ) ;
9490 for ( int i = 0 ; i < batches . BatchesCount ; i ++ )
9591 {
9692 int h = batches . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -101,8 +97,7 @@ public void BatchDivisionTest2()
10197 }
10298 Assert . IsTrue ( set1 . OrderBy ( h => h ) . SequenceEqual ( set2 . OrderBy ( h => h ) ) ) ;
10399 batches . CrossShuffle ( ) ;
104- HashSet < int >
105- set3 = new HashSet < int > ( ) ;
100+ HashSet < int > set3 = new HashSet < int > ( ) ;
106101 for ( int i = 0 ; i < batches . BatchesCount ; i ++ )
107102 {
108103 int h = batches . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -145,14 +140,12 @@ public void ReshapeTest()
145140 BatchesCollection
146141 batches1 = BatchesCollection . From ( ( x , y ) , 1000 ) ,
147142 batches2 = BatchesCollection . From ( ( x , y ) , 1000 ) ;
148- HashSet < int >
149- set = new HashSet < int > ( ) ;
143+ HashSet < int > set = new HashSet < int > ( ) ;
150144 for ( int i = 0 ; i < 20000 ; i ++ )
151145 {
152146 set . Add ( GetUid ( x , i ) ^ GetUid ( y , i ) ) ;
153147 }
154- HashSet < int >
155- set1 = new HashSet < int > ( ) ;
148+ HashSet < int > set1 = new HashSet < int > ( ) ;
156149 for ( int i = 0 ; i < batches1 . BatchesCount ; i ++ )
157150 {
158151 int h = batches1 . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -163,8 +156,7 @@ public void ReshapeTest()
163156 }
164157 Assert . IsTrue ( set . OrderBy ( h => h ) . SequenceEqual ( set1 . OrderBy ( h => h ) ) ) ;
165158 batches2 . BatchSize = 1437 ;
166- HashSet < int >
167- set2 = new HashSet < int > ( ) ;
159+ HashSet < int > set2 = new HashSet < int > ( ) ;
168160 for ( int i = 0 ; i < batches2 . BatchesCount ; i ++ )
169161 {
170162 int h = batches2 . Batches [ i ] . X . GetLength ( 0 ) ;
@@ -204,5 +196,57 @@ public void IdTest2()
204196 set1 . To < IDataset , BatchesCollection > ( ) . CrossShuffle ( ) ;
205197 Assert . IsTrue ( set1 . Id == set2 . Id ) ;
206198 }
199+
/// <summary>
/// Computes an order-dependent hash code over the values of the target vector
/// </summary>
/// <param name="v">The vector whose values are folded into the hash</param>
/// <returns>The combined hash code (17 for an empty vector). Different vectors can still collide, so the value is not guaranteed to be unique</returns>
private static int GetUid(float[] v)
{
    // The multiply-and-add fold is expected to overflow, so let it wrap around
    unchecked
    {
        int hash = 17;
        foreach (float value in v)
        {
            hash = hash * 23 + value.GetHashCode();
        }
        return hash;
    }
}
214+
[TestMethod]
public void DatasetPartition()
{
    // Random source matrices: x is 20000x784 and y is 20000x10
    const int rows = 20000, xColumns = 784, yColumns = 10;
    float[,] x = Enumerable.Range(0, rows * xColumns).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(rows, xColumns);
    float[,] y = Enumerable.Range(0, rows * yColumns).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(rows, yColumns);

    // Load the data as a training dataset and split it into a training and a test part
    ITrainingDataset sourceDataset = DatasetLoader.Training((x, y), 1000);
    (ITrainingDataset training, ITestDataset test) = sourceDataset.PartitionWithTest(0.33f);

    // Hash of every (x, y) row pair in the source matrices
    HashSet<int> expected = new HashSet<int>();
    for (int row = 0; row < rows; row++)
    {
        int xHash = GetUid(x, row);
        int yHash = GetUid(y, row);
        expected.Add(xHash ^ yHash);
    }

    // Hash of every sample found in either of the two partitions
    HashSet<int> actual = new HashSet<int>();
    for (int index = 0; index < training.Count; index++)
    {
        DatasetSample sample = training[index];
        int xHash = GetUid(sample.X.ToArray());
        int yHash = GetUid(sample.Y.ToArray());
        actual.Add(xHash ^ yHash);
    }
    for (int index = 0; index < test.Count; index++)
    {
        DatasetSample sample = test[index];
        int xHash = GetUid(sample.X.ToArray());
        int yHash = GetUid(sample.Y.ToArray());
        actual.Add(xHash ^ yHash);
    }

    // The two partitions together must contain exactly the hashes of the source rows
    Assert.IsTrue(expected.SetEquals(actual));
}
241+
[TestMethod]
public void DatasetPartitionException()
{
    // Tiny random source matrices: x is 15x784 and y is 15x10
    const int rows = 15, xColumns = 784, yColumns = 10;
    float[,] x = Enumerable.Range(0, rows * xColumns).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(rows, xColumns);
    float[,] y = Enumerable.Range(0, rows * yColumns).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(rows, yColumns);
    ITrainingDataset sourceDataset = DatasetLoader.Training((x, y), 1000);

    // Partitioning this dataset with a 0.33 ratio is expected to be rejected
    // NOTE(review): presumably because 15 samples are too few to split - confirm against PartitionWithTest
    const float ratio = 0.33f;
    Assert.ThrowsException<ArgumentOutOfRangeException>(() => sourceDataset.PartitionWithTest(ratio));
}
207251 }
208252}
0 commit comments