11using System ;
22using System . IO ;
3+ using System . Runtime . CompilerServices ;
34using JetBrains . Annotations ;
45using NeuralNetworkNET . APIs . Enums ;
6+ using NeuralNetworkNET . APIs . Interfaces ;
57using NeuralNetworkNET . APIs . Structs ;
68using NeuralNetworkNET . Extensions ;
79using NeuralNetworkNET . Helpers ;
@@ -30,14 +32,20 @@ internal abstract class BatchNormalizationLayerBase : WeightedLayerBase
3032 [ NotNull ]
3133 public float [ ] Sigma2 { get ; }
3234
33- // The current iteration number (for the Cumulative Moving Average)
34- private int _Iteration ;
35+ /// <summary>
36+ /// Gets the current iteration number (for the Cumulative Moving Average)
37+ /// </summary>
38+ public int Iteration { get ; private set ; }
3539
3640 /// <summary>
3741 /// Gets the current CMA factor used to update the <see cref="Mu"/> and <see cref="Sigma2"/> tensors
3842 /// </summary>
3943 [ JsonProperty ( nameof ( CumulativeMovingAverageFactor ) , Order = 6 ) ]
40- public float CumulativeMovingAverageFactor => 1f / ( 1 + _Iteration ) ;
44+ public float CumulativeMovingAverageFactor
45+ {
46+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
47+ get => 1f / ( 1 + Iteration ) ;
48+ }
4149
4250 /// <inheritdoc/>
4351 public override String Hash => Convert . ToBase64String ( Sha256 . Hash ( Weights , Biases , Mu , Sigma2 ) ) ;
@@ -74,24 +82,26 @@ protected BatchNormalizationLayerBase(in TensorInfo shape, NormalizationMode mod
7482 NormalizationMode = mode ;
7583 }
7684
77- protected BatchNormalizationLayerBase ( in TensorInfo shape , NormalizationMode mode , [ NotNull ] float [ ] w , [ NotNull ] float [ ] b , [ NotNull ] float [ ] mu , [ NotNull ] float [ ] sigma2 , ActivationType activation )
85+ protected BatchNormalizationLayerBase ( in TensorInfo shape , NormalizationMode mode , [ NotNull ] float [ ] w , [ NotNull ] float [ ] b , int iteration , [ NotNull ] float [ ] mu , [ NotNull ] float [ ] sigma2 , ActivationType activation )
7886 : base ( shape , shape , w , b , activation )
7987 {
8088 if ( w . Length != b . Length ) throw new ArgumentException ( "The size for both gamme and beta paarameters must be the same" ) ;
8189 if ( mode == NormalizationMode . Spatial && w . Length != shape . Channels ||
8290 mode == NormalizationMode . PerActivation && w . Length != shape . Size )
8391 throw new ArgumentException ( "Invalid parameters size for the selected normalization mode" ) ;
92+ if ( iteration < 0 ) throw new ArgumentOutOfRangeException ( nameof ( iteration ) , "The iteration value must be aat least equal to 0" ) ;
8493 if ( mu . Length != w . Length || sigma2 . Length != w . Length )
8594 throw new ArgumentException ( "The mu and sigma2 parameters must match the shape of the gamma and beta parameters" ) ;
8695 NormalizationMode = mode ;
96+ Iteration = iteration ;
8797 Mu = mu ;
8898 Sigma2 = sigma2 ;
8999 }
90100
91101 /// <inheritdoc/>
92102 public override void Forward ( in Tensor x , out Tensor z , out Tensor a )
93103 {
94- if ( NetworkTrainer . BackpropagationInProgress ) ForwardTraining ( 1f / ( 1 + _Iteration ++ ) , x , out z , out a ) ;
104+ if ( NetworkTrainer . BackpropagationInProgress ) ForwardTraining ( 1f / ( 1 + Iteration ++ ) , x , out z , out a ) ;
95105 else ForwardInference ( x , out z , out a ) ;
96106 }
97107
@@ -112,11 +122,22 @@ public override void Forward(in Tensor x, out Tensor z, out Tensor a)
112122 /// <param name="a">The output activation on the current layer</param>
113123 public abstract void ForwardTraining ( float factor , in Tensor x , out Tensor z , out Tensor a ) ;
114124
125+ /// <inheritdoc/>
126+ public override bool Equals ( INetworkLayer other )
127+ {
128+ if ( ! base . Equals ( other ) ) return false ;
129+ return other is BatchNormalizationLayerBase layer &&
130+ Iteration == layer . Iteration &&
131+ Mu . ContentEquals ( layer . Mu ) &&
132+ Sigma2 . ContentEquals ( layer . Sigma2 ) ;
133+ }
134+
115135 /// <inheritdoc/>
116136 public override void Serialize ( Stream stream )
117137 {
118138 base . Serialize ( stream ) ;
119139 stream . Write ( NormalizationMode ) ;
140+ stream . Write ( Iteration ) ;
120141 stream . Write ( Mu . Length ) ;
121142 stream . WriteShuffled ( Mu ) ;
122143 stream . Write ( Sigma2 . Length ) ;
0 commit comments