Skip to content

Commit c0be6aa

Browse files
authored
Merge pull request #67 from Sergio0694/dev
Bug fixes and adjustments
2 parents 84328d6 + 92f09cb commit c0be6aa

File tree

8 files changed

+128
-53
lines changed

8 files changed

+128
-53
lines changed

NeuralNetwork.NET/APIs/NetworkManager.cs

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ public static INeuralNetwork NewSequential(TensorInfo input, [NotNull, ItemNotNu
4141
}).ToArray());
4242
}
4343

44-
#region Training
44+
#region Training APIs
4545

4646
/// <summary>
4747
/// Trains a neural network with the given parameters
@@ -70,17 +70,7 @@ public static TrainingSessionResult TrainNetwork(
7070
[CanBeNull] ITestDataset testDataset = null,
7171
CancellationToken token = default)
7272
{
73-
// Preliminary checks
74-
if (dropout < 0 || dropout >= 1) throw new ArgumentOutOfRangeException(nameof(dropout), "The dropout probability is invalid");
75-
76-
// Start the training
77-
return NetworkTrainer.TrainNetwork(
78-
network as SequentialNetwork ?? throw new ArgumentException("The input network instance isn't valid", nameof(network)),
79-
dataset as BatchesCollection ?? throw new ArgumentException("The input dataset instance isn't valid", nameof(dataset)),
80-
epochs, dropout, algorithm, batchCallback.AsIProgress(), trainingCallback.AsIProgress(),
81-
validationDataset as ValidationDataset,
82-
testDataset as TestDataset,
83-
token);
73+
return TrainNetworkCore(network, dataset, algorithm, epochs, dropout, batchCallback.AsIProgress(), trainingCallback.AsIProgress(), validationDataset, testDataset, token);
8474
}
8575

8676
/// <summary>
@@ -110,9 +100,38 @@ public static Task<TrainingSessionResult> TrainNetworkAsync(
110100
[CanBeNull] ITestDataset testDataset = null,
111101
CancellationToken token = default)
112102
{
113-
return Task.Run(() => TrainNetwork(network, dataset, algorithm, epochs, dropout, batchCallback, trainingCallback, validationDataset, testDataset, token), token);
103+
IProgress<BatchProgress> batchProgress = batchCallback.AsIProgress();
104+
IProgress<TrainingProgressEventArgs> trainingProgress = trainingCallback.AsIProgress(); // Capture the synchronization contexts
105+
return Task.Run(() => TrainNetworkCore(network, dataset, algorithm, epochs, dropout, batchProgress, trainingProgress, validationDataset, testDataset, token), token);
114106
}
115107

116108
#endregion
109+
110+
// Core trainer method with additional checks
111+
[NotNull]
112+
private static TrainingSessionResult TrainNetworkCore(
113+
[NotNull] INeuralNetwork network,
114+
[NotNull] ITrainingDataset dataset,
115+
[NotNull] ITrainingAlgorithmInfo algorithm,
116+
int epochs, float dropout,
117+
[CanBeNull] IProgress<BatchProgress> batchProgress,
118+
[CanBeNull] IProgress<TrainingProgressEventArgs> trainingProgress,
119+
[CanBeNull] IValidationDataset validationDataset,
120+
[CanBeNull] ITestDataset testDataset,
121+
CancellationToken token)
122+
{
123+
// Preliminary checks
124+
if (epochs < 1) throw new ArgumentOutOfRangeException(nameof(epochs), "The number of epochs must at be at least equal to 1");
125+
if (dropout < 0 || dropout >= 1) throw new ArgumentOutOfRangeException(nameof(dropout), "The dropout probability is invalid");
126+
127+
// Start the training
128+
return NetworkTrainer.TrainNetwork(
129+
network as SequentialNetwork ?? throw new ArgumentException("The input network instance isn't valid", nameof(network)),
130+
dataset as BatchesCollection ?? throw new ArgumentException("The input dataset instance isn't valid", nameof(dataset)),
131+
epochs, dropout, algorithm, batchProgress, trainingProgress,
132+
validationDataset as ValidationDataset,
133+
testDataset as TestDataset,
134+
token);
135+
}
117136
}
118137
}

NeuralNetwork.NET/Extensions/MiscExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ internal static void AssertCompleted(in this ParallelLoopResult result)
155155
[Pure, NotNull]
156156
internal static String TrimVerbatim([NotNull] this String text)
157157
{
158-
String[] lines = text.Split(new[] { Environment.NewLine }, StringSplitOptions.None);
158+
String[] lines = text.Split(new[] { "\r\n", "\n" }, StringSplitOptions.RemoveEmptyEntries);
159159
return lines.Aggregate(new StringBuilder(), (b, s) =>
160160
{
161161
b.AppendLine(s.Trim());

NeuralNetwork.NET/SupervisedLearning/Data/BatchesCollection.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,14 @@ public int BatchesCount
4040
/// <inheritdoc/>
4141
public (ITrainingDataset, ITestDataset) PartitionWithTest(float ratio, Action<TrainingProgressEventArgs> progress = null)
4242
{
43-
if (ratio <= 0 || ratio >= 1) throw new ArgumentOutOfRangeException(nameof(ratio), "The ratio must be in the (0,1) range");
44-
int left = ((int)(Count * (1 - ratio))).Max(1); // Ensure there's at least one element
43+
int left = CalculatePartitionSize(ratio);
4544
return (DatasetLoader.Training(Take(0, left), BatchSize), DatasetLoader.Test(Take(left, Count), progress));
4645
}
4746

4847
/// <inheritdoc/>
4948
public (ITrainingDataset, IValidationDataset) PartitionWithValidation(float ratio, float tolerance = 1e-2f, int epochs = 5)
5049
{
51-
if (ratio <= 0 || ratio >= 1) throw new ArgumentOutOfRangeException(nameof(ratio), "The ratio must be in the (0,1) range");
52-
int left = ((int)(Count * (1 - ratio))).Max(1); // Ensure there's at least one element
50+
int left = CalculatePartitionSize(ratio);
5351
return (DatasetLoader.Training(Take(0, left), BatchSize), DatasetLoader.Validation(Take(left, Count), tolerance, epochs));
5452
}
5553

@@ -231,6 +229,16 @@ from i in Enumerable.Range(0, batch.X.GetLength(0))
231229
}
232230
}
233231

232+
// Computes the size of the first dataset partition given a partition ratio
233+
[Pure]
234+
private int CalculatePartitionSize(float ratio)
235+
{
236+
if (ratio <= 0 || ratio >= 1) throw new ArgumentOutOfRangeException(nameof(ratio), "The ratio must be in the (0,1) range");
237+
int left = ((int)(Count * (1 - ratio))).Max(10); // Ensure there are at least 10 elements
238+
if (Count - left < 10) throw new ArgumentOutOfRangeException(nameof(ratio), "Each partition must have at least 10 samples");
239+
return left;
240+
}
241+
234242
#endregion
235243

236244
#region Shuffling

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
![](http://i.pi.gy/8ZDDE.png)
2-
[![NuGet](https://img.shields.io/nuget/v/NeuralNetwork.NET.svg)](https://www.nuget.org/packages/NeuralNetwork.NET/) [![NuGet](https://img.shields.io/nuget/dt/NeuralNetwork.NET.svg)](https://www.nuget.org/stats/packages/NeuralNetwork.NET?groupby=Version) [![Twitter Follow](https://img.shields.io/twitter/follow/Sergio0694.svg?style=social&label=Follow)](https://twitter.com/SergioPedri)
2+
[![NuGet](https://img.shields.io/nuget/v/NeuralNetwork.NET.svg)](https://www.nuget.org/packages/NeuralNetwork.NET/) [![NuGet](https://img.shields.io/nuget/dt/NeuralNetwork.NET.svg)](https://www.nuget.org/stats/packages/NeuralNetwork.NET?groupby=Version) [![AppVeyor](https://img.shields.io/appveyor/ci/Sergio0694/neuralnetwork-net.svg)](https://ci.appveyor.com/project/Sergio0694/neuralnetwork-net) [![AppVeyor tests](https://img.shields.io/appveyor/tests/Sergio0694/neuralnetwork-net.svg)](https://ci.appveyor.com/project/Sergio0694/neuralnetwork-net) [![Twitter Follow](https://img.shields.io/twitter/follow/Sergio0694.svg?style=social&label=Follow)](https://twitter.com/SergioPedri)
33

44
# What is it?
55

Unit/NeuralNetwork.NET.Unit/DatasetLoadingTest.cs

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Microsoft.VisualStudio.TestTools.UnitTesting;
55
using NeuralNetworkNET.APIs;
66
using NeuralNetworkNET.APIs.Interfaces.Data;
7+
using NeuralNetworkNET.APIs.Structs;
78
using NeuralNetworkNET.Extensions;
89
using NeuralNetworkNET.Helpers;
910
using NeuralNetworkNET.SupervisedLearning.Data;
@@ -44,14 +45,12 @@ public void BatchDivisionTest1()
4445
x = Enumerable.Range(0, 20000 * 784).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 784),
4546
y = Enumerable.Range(0, 20000 * 10).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 10);
4647
BatchesCollection batches = BatchesCollection.From((x, y), 1000);
47-
HashSet<int>
48-
set1 = new HashSet<int>();
48+
HashSet<int> set1 = new HashSet<int>();
4949
for (int i = 0; i < 20000; i++)
5050
{
5151
set1.Add(GetUid(x, i) ^ GetUid(y, i));
5252
}
53-
HashSet<int>
54-
set2 = new HashSet<int>();
53+
HashSet<int> set2 = new HashSet<int>();
5554
for (int i = 0; i < batches.BatchesCount; i++)
5655
{
5756
int h = batches.Batches[i].X.GetLength(0);
@@ -62,8 +61,7 @@ public void BatchDivisionTest1()
6261
}
6362
Assert.IsTrue(set1.OrderBy(h => h).SequenceEqual(set2.OrderBy(h => h)));
6463
batches.CrossShuffle();
65-
HashSet<int>
66-
set3 = new HashSet<int>();
64+
HashSet<int> set3 = new HashSet<int>();
6765
for (int i = 0; i < batches.BatchesCount; i++)
6866
{
6967
int h = batches.Batches[i].X.GetLength(0);
@@ -83,14 +81,12 @@ public void BatchDivisionTest2()
8381
x = Enumerable.Range(0, 20000 * 784).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 784),
8482
y = Enumerable.Range(0, 20000 * 10).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 10);
8583
BatchesCollection batches = BatchesCollection.From((x, y), 1547);
86-
HashSet<int>
87-
set1 = new HashSet<int>();
84+
HashSet<int> set1 = new HashSet<int>();
8885
for (int i = 0; i < 20000; i++)
8986
{
9087
set1.Add(GetUid(x, i) ^ GetUid(y, i));
9188
}
92-
HashSet<int>
93-
set2 = new HashSet<int>();
89+
HashSet<int> set2 = new HashSet<int>();
9490
for (int i = 0; i < batches.BatchesCount; i++)
9591
{
9692
int h = batches.Batches[i].X.GetLength(0);
@@ -101,8 +97,7 @@ public void BatchDivisionTest2()
10197
}
10298
Assert.IsTrue(set1.OrderBy(h => h).SequenceEqual(set2.OrderBy(h => h)));
10399
batches.CrossShuffle();
104-
HashSet<int>
105-
set3 = new HashSet<int>();
100+
HashSet<int> set3 = new HashSet<int>();
106101
for (int i = 0; i < batches.BatchesCount; i++)
107102
{
108103
int h = batches.Batches[i].X.GetLength(0);
@@ -145,14 +140,12 @@ public void ReshapeTest()
145140
BatchesCollection
146141
batches1 = BatchesCollection.From((x, y), 1000),
147142
batches2 = BatchesCollection.From((x, y), 1000);
148-
HashSet<int>
149-
set = new HashSet<int>();
143+
HashSet<int> set = new HashSet<int>();
150144
for (int i = 0; i < 20000; i++)
151145
{
152146
set.Add(GetUid(x, i) ^ GetUid(y, i));
153147
}
154-
HashSet<int>
155-
set1 = new HashSet<int>();
148+
HashSet<int> set1 = new HashSet<int>();
156149
for (int i = 0; i < batches1.BatchesCount; i++)
157150
{
158151
int h = batches1.Batches[i].X.GetLength(0);
@@ -163,8 +156,7 @@ public void ReshapeTest()
163156
}
164157
Assert.IsTrue(set.OrderBy(h => h).SequenceEqual(set1.OrderBy(h => h)));
165158
batches2.BatchSize = 1437;
166-
HashSet<int>
167-
set2 = new HashSet<int>();
159+
HashSet<int> set2 = new HashSet<int>();
168160
for (int i = 0; i < batches2.BatchesCount; i++)
169161
{
170162
int h = batches2.Batches[i].X.GetLength(0);
@@ -204,5 +196,57 @@ public void IdTest2()
204196
set1.To<IDataset, BatchesCollection>().CrossShuffle();
205197
Assert.IsTrue(set1.Id == set2.Id);
206198
}
199+
200+
// Calculates a unique hash code for the target vector
201+
private static unsafe int GetUid(float[] v)
202+
{
203+
fixed (float* pv = v)
204+
{
205+
int hash = 17;
206+
unchecked
207+
{
208+
for (int i = 0; i < v.Length; i++)
209+
hash = hash * 23 + pv[i].GetHashCode();
210+
return hash;
211+
}
212+
}
213+
}
214+
215+
[TestMethod]
216+
public void DatasetPartition()
217+
{
218+
float[,]
219+
x = Enumerable.Range(0, 20000 * 784).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 784),
220+
y = Enumerable.Range(0, 20000 * 10).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 10);
221+
ITrainingDataset sourceDataset = DatasetLoader.Training((x, y), 1000);
222+
(ITrainingDataset training, ITestDataset test) = sourceDataset.PartitionWithTest(0.33f);
223+
HashSet<int> set = new HashSet<int>();
224+
for (int i = 0; i < 20000; i++)
225+
{
226+
set.Add(GetUid(x, i) ^ GetUid(y, i));
227+
}
228+
HashSet<int> set1 = new HashSet<int>();
229+
for (int i = 0; i < training.Count; i++)
230+
{
231+
DatasetSample sample = training[i];
232+
set1.Add(GetUid(sample.X.ToArray()) ^ GetUid(sample.Y.ToArray()));
233+
}
234+
for (int i = 0; i < test.Count; i++)
235+
{
236+
DatasetSample sample = test[i];
237+
set1.Add(GetUid(sample.X.ToArray()) ^ GetUid(sample.Y.ToArray()));
238+
}
239+
Assert.IsTrue(set.OrderBy(h => h).SequenceEqual(set1.OrderBy(h => h)));
240+
}
241+
242+
[TestMethod]
243+
public void DatasetPartitionException()
244+
{
245+
float[,]
246+
x = Enumerable.Range(0, 15 * 784).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(15, 784),
247+
y = Enumerable.Range(0, 15 * 10).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(15, 10);
248+
ITrainingDataset sourceDataset = DatasetLoader.Training((x, y), 1000);
249+
Assert.ThrowsException<ArgumentOutOfRangeException>(() => sourceDataset.PartitionWithTest(0.33f));
250+
}
207251
}
208252
}

Unit/NeuralNetwork.NET.Unit/MiscTest.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ public void TrimVerbatim()
7575
plt.xlabel(""Epoch"")
7676
plt.plot(x)
7777
plt.show()";
78-
const String expected = @"import matplotlib.pyplot as plt
79-
x = [$VALUES$]
80-
plt.grid(linestyle=""dashed"")
81-
plt.ylabel(""$YLABEL$"")
82-
plt.xlabel(""Epoch"")
83-
plt.plot(x)
84-
plt.show()
85-
";
78+
String[] lines =
79+
{
80+
"import matplotlib.pyplot as plt",
81+
"x = [$VALUES$]",
82+
"plt.grid(linestyle=\"dashed\")",
83+
"plt.ylabel(\"$YLABEL$\")",
84+
"plt.xlabel(\"Epoch\")",
85+
"plt.plot(x)",
86+
"plt.show()"
87+
};
88+
String expected = lines.Skip(1).Aggregate(lines[0], (s, l) => $"{s}{Environment.NewLine}{l}") + Environment.NewLine;
8689
Assert.IsTrue(text.TrimVerbatim().Equals(expected));
8790
}
8891
}

Unit/NeuralNetwork.NET.Unit/NetworkTest.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ private static ((float[,] X, float[,] Y) TrainingData, (float[,] X, float[,] Y)
8585
ParseSamples(Path.Combine(path, TestSetValuesFilename), Path.Combine(path, TestSetLabelsFilename), test));
8686
}
8787

88-
private static bool TestTrainingMethod(ITrainingAlgorithmInfo info)
88+
private static bool TestTrainingMethod(ITrainingAlgorithmInfo info, int epochs)
8989
{
9090
(var trainingSet, var testSet) = ParseMnistDataset();
9191
BatchesCollection batches = BatchesCollection.From(trainingSet, 100);
9292
SequentialNetwork network = NetworkManager.NewSequential(TensorInfo.Image<Alpha8>(28, 28),
9393
NetworkLayers.FullyConnected(100, ActivationFunctionType.Sigmoid),
9494
NetworkLayers.Softmax(10)).To<INeuralNetwork, SequentialNetwork>();
95-
TrainingSessionResult result = NetworkTrainer.TrainNetwork(network, batches, 2, 0, info, null, null, null, null, default);
95+
TrainingSessionResult result = NetworkTrainer.TrainNetwork(network, batches, epochs, 0, info, null, null, null, null, default);
9696
Assert.IsTrue(result.StopReason == TrainingStopReason.EpochsCompleted);
9797
(_, _, float accuracy) = network.Evaluate(testSet);
9898
if (accuracy < 80)
@@ -106,19 +106,19 @@ private static bool TestTrainingMethod(ITrainingAlgorithmInfo info)
106106
}
107107

108108
[TestMethod]
109-
public void GradientDescentTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.StochasticGradientDescent()));
109+
public void GradientDescentTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.StochasticGradientDescent(0.1f), 1));
110110

111111
[TestMethod]
112-
public void MomentumTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.Momentum()));
112+
public void MomentumTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.Momentum(0.1f), 1));
113113

114114
[TestMethod]
115-
public void AdaGradTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.AdaGrad(0.1f)));
115+
public void AdaGradTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.AdaGrad(0.1f), 2));
116116

117117
[TestMethod]
118-
public void AdaDeltaTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.AdaDelta()));
118+
public void AdaDeltaTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.AdaDelta(), 1));
119119

120120
[TestMethod]
121-
public void AdamTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.Adam()));
121+
public void AdamTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.Adam(), 1));
122122

123123
[TestMethod]
124124
public void AdaMaxTest()
@@ -140,6 +140,6 @@ public void AdaMaxTest()
140140
}
141141

142142
[TestMethod]
143-
public void RMSPropTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.RMSProp()));
143+
public void RMSPropTest() => Assert.IsTrue(TestTrainingMethod(TrainingAlgorithms.RMSProp(), 1));
144144
}
145145
}

Unit/NeuralNetwork.NET.Unit/NeuralNetwork.NET.Unit.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
</ItemGroup>
3535

3636
<ItemGroup>
37+
<PackageReference Include="Appveyor.TestLogger" Version="2.0.0" />
3738
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.5.0" />
3839
<PackageReference Include="MSTest.TestAdapter" Version="1.2.0" />
3940
<PackageReference Include="MSTest.TestFramework" Version="1.2.0" />

0 commit comments

Comments
 (0)