From b58e770728d1f3f4c27ce161465284650a73e0c7 Mon Sep 17 00:00:00 2001 From: Christian Kleineidam Date: Fri, 20 Jun 2025 16:42:58 +0200 Subject: [PATCH 1/2] Add LogLoss metric --- .gitignore | 1 + CHANGELOG.md | 4 +++ .../classification/log_loss_metric.dart | 28 +++++++++++++++++++ lib/src/metric/metric_factory_impl.dart | 4 +++ lib/src/metric/metric_type.dart | 3 ++ pubspec.yaml | 2 +- .../classification/log_loss_metric_test.dart | 28 +++++++++++++++++++ test/metric/metric_factory_impl_test.dart | 5 ++++ 8 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 lib/src/metric/classification/log_loss_metric.dart create mode 100644 test/metric/classification/log_loss_metric_test.dart diff --git a/.gitignore b/.gitignore index 31623535..17e25045 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ test/.test_coverage.dart .DS_Store */**/.DS_Store +flutter-sdk/ diff --git a/CHANGELOG.md b/CHANGELOG.md index d0179b9a..83e60a5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 16.17.14 +- Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic + binary classifiers + ## 16.17.13 - Added Decision Tree web demo using Web Assembly diff --git a/lib/src/metric/classification/log_loss_metric.dart b/lib/src/metric/classification/log_loss_metric.dart new file mode 100644 index 00000000..718e9b8a --- /dev/null +++ b/lib/src/metric/classification/log_loss_metric.dart @@ -0,0 +1,28 @@ +import 'package:ml_algo/src/helpers/validate_matrix_columns.dart'; +import 'package:ml_algo/src/metric/metric.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'dart:math' as math; + +class LogLossMetric implements Metric { + const LogLossMetric({this.eps = 1e-15}); + + final double eps; + + double _clip(double p) => p < eps ? eps : (p > 1.0 - eps ? 1.0 - eps : p); + + @override + double getScore(Matrix predictedLabels, Matrix origLabels) { + validateMatrixColumns([predictedLabels, origLabels]); + + final preds = predictedLabels.toVector(); + final orig = origLabels.toVector(); + + var sum = 0.0; + for (var i = 0; i < preds.length; i++) { + final p = _clip(preds[i]); + final y = orig[i]; + sum += y == 1 ? -math.log(p) : -math.log(1.0 - p); + } + return sum / preds.length; + } +} diff --git a/lib/src/metric/metric_factory_impl.dart b/lib/src/metric/metric_factory_impl.dart index 2944975f..95c88353 100644 --- a/lib/src/metric/metric_factory_impl.dart +++ b/lib/src/metric/metric_factory_impl.dart @@ -1,6 +1,7 @@ import 'package:ml_algo/src/metric/classification/accuracy.dart'; import 'package:ml_algo/src/metric/classification/precision.dart'; import 'package:ml_algo/src/metric/classification/recall.dart'; +import 'package:ml_algo/src/metric/classification/log_loss_metric.dart'; import 'package:ml_algo/src/metric/metric.dart'; import 'package:ml_algo/src/metric/metric_factory.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; @@ -28,6 +29,9 @@ class MetricFactoryImpl implements MetricFactory { case MetricType.recall: return const RecallMetric(); + case MetricType.logLoss: + return const LogLossMetric(); + default: throw UnsupportedError('Unsupported metric type $type'); } diff --git a/lib/src/metric/metric_type.dart b/lib/src/metric/metric_type.dart index 7a8eb971..babf6917 100644 --- a/lib/src/metric/metric_type.dart +++ b/lib/src/metric/metric_type.dart @@ -100,4 +100,7 @@ enum MetricType { /// better the prediction's quality is. The metric produces scores within the /// range [0, 1] recall, + + /// Binary cross-entropy (a.k.a. log-loss) + logLoss, } diff --git a/pubspec.yaml b/pubspec.yaml index 9bd76b99..b1674a6a 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms, Machine learning models performance evaluation functionality -version: 16.17.13 +version: 16.17.14 homepage: https://github.com/gyrdym/ml_algo environment: diff --git a/test/metric/classification/log_loss_metric_test.dart b/test/metric/classification/log_loss_metric_test.dart new file mode 100644 index 00000000..23e28a92 --- /dev/null +++ b/test/metric/classification/log_loss_metric_test.dart @@ -0,0 +1,28 @@ +import 'package:ml_algo/src/metric/classification/log_loss_metric.dart'; +import 'package:ml_linalg/matrix.dart'; +import 'package:test/test.dart'; + +void main() { + group('LogLossMetric', () { + const metric = LogLossMetric(); + + test('perfect predictions → loss ≈ 0', () { + final yTrue = Matrix.column([1, 0, 1, 0]); + final yPred = Matrix.column([1.0, 0.0, 1.0, 0.0]); + expect(metric.getScore(yPred, yTrue), closeTo(0.0, 1e-12)); + }); + + test('typical predictions', () { + final yTrue = Matrix.column([1, 0]); + final yPred = Matrix.column([0.9, 0.1]); + expect(metric.getScore(yPred, yTrue), + closeTo(0.10536051565782628, 1e-6)); // -ln(0.9) + }); + + test('probabilities are clipped', () { + final yTrue = Matrix.column([1, 0]); + final yPred = Matrix.column([0.0, 1.0]); + expect(metric.getScore(yPred, yTrue).isFinite, isTrue); + }); + }); +} diff --git a/test/metric/metric_factory_impl_test.dart b/test/metric/metric_factory_impl_test.dart index 418d5333..2dcf4c34 100644 --- a/test/metric/metric_factory_impl_test.dart +++ b/test/metric/metric_factory_impl_test.dart @@ -1,6 +1,7 @@ import 'package:ml_algo/src/metric/classification/accuracy.dart'; import 'package:ml_algo/src/metric/classification/precision.dart'; import 'package:ml_algo/src/metric/classification/recall.dart'; +import 'package:ml_algo/src/metric/classification/log_loss_metric.dart'; import 'package:ml_algo/src/metric/metric_factory_impl.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/metric/regression/mape.dart'; @@ -31,5 +32,9 @@ void main() { test('should create RecallMetric instance', () { expect(factory.createByType(MetricType.recall), isA()); }); + + test('should create LogLossMetric instance', () { + expect(factory.createByType(MetricType.logLoss), isA()); + }); }); } From 06be91324c15ab8e1adb1380d0fd851e190666b7 Mon Sep 17 00:00:00 2001 From: Christian Kleineidam Date: Wed, 25 Jun 2025 12:38:41 +0200 Subject: [PATCH 2/2] Bump version to 16.17.15 --- CHANGELOG.md | 4 ++-- pubspec.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83e60a5c..0aa90c7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,7 @@ # Changelog -## 16.17.14 -- Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic +## 16.17.15 + - Added `MetricType.logLoss` and `LogLossMetric` for evaluating probabilistic binary classifiers ## 16.17.13 diff --git a/pubspec.yaml b/pubspec.yaml index b1674a6a..7ece0879 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -1,6 +1,6 @@ name: ml_algo description: Machine learning algorithms, Machine learning models performance evaluation functionality -version: 16.17.14 +version: 16.17.15 homepage: https://github.com/gyrdym/ml_algo environment: