Skip to content

Commit 9e877d1

Browse files
committed
Add metrics of BinaryAccuracy, CategoricalAccuracy, CategoricalCrossentropy.
1 parent 9891998 commit 9e877d1

File tree

12 files changed

+244
-20
lines changed

12 files changed

+244
-20
lines changed

src/TensorFlowNET.Core/Keras/Metrics/IMetricFunc.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
public interface IMetricFunc
44
{
5+
string Name { get; }
56
/// <summary>
67
/// Accumulates metric statistics.
78
/// </summary>

src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ public interface IMetricsApi
55
Tensor binary_accuracy(Tensor y_true, Tensor y_pred);
66

77
Tensor categorical_accuracy(Tensor y_true, Tensor y_pred);
8+
Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred,
9+
bool from_logits = false,
10+
float label_smoothing = 0f,
11+
Axis? axis = null);
812

913
Tensor mean_absolute_error(Tensor y_true, Tensor y_pred);
1014

@@ -27,14 +31,39 @@ public interface IMetricsApi
2731
/// <returns></returns>
2832
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
2933

34+
/// <summary>
35+
/// Calculates how often predictions match binary labels.
36+
/// </summary>
37+
/// <returns></returns>
38+
IMetricFunc BinaryAccuracy(string name = "binary_accuracy",
39+
TF_DataType dtype = TF_DataType.TF_FLOAT,
40+
float threshold = 05f);
41+
42+
/// <summary>
43+
/// Calculates how often predictions match one-hot labels.
44+
/// </summary>
45+
/// <returns></returns>
46+
IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy",
47+
TF_DataType dtype = TF_DataType.TF_FLOAT,
48+
bool from_logits = false,
49+
float label_smoothing = 0f,
50+
Axis? axis = null);
51+
52+
/// <summary>
53+
/// Computes the crossentropy metric between the labels and predictions.
54+
/// </summary>
55+
/// <returns></returns>
56+
IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
57+
TF_DataType dtype = TF_DataType.TF_FLOAT);
58+
3059
/// <summary>
3160
/// Computes how often targets are in the top K predictions.
3261
/// </summary>
33-
/// <param name="y_true"></param>
34-
/// <param name="y_pred"></param>
3562
/// <param name="k"></param>
3663
/// <returns></returns>
37-
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
64+
IMetricFunc TopKCategoricalAccuracy(int k = 5,
65+
string name = "top_k_categorical_accuracy",
66+
TF_DataType dtype = TF_DataType.TF_FLOAT);
3867

3968
/// <summary>
4069
/// Computes the precision of the predictions with respect to the labels.
@@ -45,7 +74,11 @@ public interface IMetricsApi
4574
/// <param name="name"></param>
4675
/// <param name="dtype"></param>
4776
/// <returns></returns>
48-
IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
77+
IMetricFunc Precision(float thresholds = 0.5f,
78+
int top_k = 0,
79+
int class_id = 0,
80+
string name = "recall",
81+
TF_DataType dtype = TF_DataType.TF_FLOAT);
4982

5083
/// <summary>
5184
/// Computes the recall of the predictions with respect to the labels.
@@ -56,5 +89,9 @@ public interface IMetricsApi
5689
/// <param name="name"></param>
5790
/// <param name="dtype"></param>
5891
/// <returns></returns>
59-
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
92+
IMetricFunc Recall(float thresholds = 0.5f,
93+
int top_k = 0,
94+
int class_id = 0,
95+
string name = "recall",
96+
TF_DataType dtype = TF_DataType.TF_FLOAT);
6097
}

src/TensorFlowNET.Keras/Engine/MetricsContainer.cs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,21 @@ namespace Tensorflow.Keras.Engine
99
{
1010
public class MetricsContainer : Container
1111
{
12-
string[] _user_metrics;
13-
string[] _metric_names;
14-
Metric[] _metrics;
15-
List<Metric> _metrics_in_order;
12+
IMetricFunc[] _user_metrics = new IMetricFunc[0];
13+
string[] _metric_names = new string[0];
14+
Metric[] _metrics = new Metric[0];
15+
List<IMetricFunc> _metrics_in_order = new List<IMetricFunc>();
1616

17-
public MetricsContainer(string[] metrics, string[] output_names = null)
17+
public MetricsContainer(IMetricFunc[] metrics, string[] output_names = null)
1818
: base(output_names)
1919
{
2020
_user_metrics = metrics;
21+
_built = false;
22+
}
23+
24+
public MetricsContainer(string[] metrics, string[] output_names = null)
25+
: base(output_names)
26+
{
2127
_metric_names = metrics;
2228
_built = false;
2329
}
@@ -46,17 +52,19 @@ void _set_metric_names()
4652

4753
void _create_ordered_metrics()
4854
{
49-
_metrics_in_order = new List<Metric>();
5055
foreach (var m in _metrics)
5156
_metrics_in_order.append(m);
57+
58+
foreach(var m in _user_metrics)
59+
_metrics_in_order.append(m);
5260
}
5361

5462
Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p)
5563
{
5664
return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray();
5765
}
5866

59-
Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
67+
public Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
6068
{
6169
Func<Tensor, Tensor, Tensor> metric_obj = null;
6270
if (metric == "accuracy" || metric == "acc")
@@ -94,7 +102,7 @@ Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
94102
return new MeanMetricWrapper(metric_obj, metric);
95103
}
96104

97-
public IEnumerable<Metric> metrics
105+
public IEnumerable<IMetricFunc> metrics
98106
{
99107
get
100108
{

src/TensorFlowNET.Keras/Engine/Model.Compile.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using Tensorflow.Keras.ArgsDefinition;
33
using Tensorflow.Keras.Losses;
4+
using Tensorflow.Keras.Metrics;
45
using Tensorflow.Keras.Optimizers;
56

67
namespace Tensorflow.Keras.Engine
@@ -31,6 +32,27 @@ public void compile(OptimizerV2 optimizer = null,
3132
_is_compiled = true;
3233
}
3334

35+
public void compile(OptimizerV2 optimizer = null,
36+
ILossFunc loss = null,
37+
IMetricFunc[] metrics = null)
38+
{
39+
this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs
40+
{
41+
});
42+
43+
this.loss = loss ?? new MeanSquaredError();
44+
45+
compiled_loss = new LossesContainer(loss, output_names: output_names);
46+
compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
47+
48+
int experimental_steps_per_execution = 1;
49+
_configure_steps_per_execution(experimental_steps_per_execution);
50+
51+
// Initialize cache attrs.
52+
_reset_compile_cache();
53+
_is_compiled = true;
54+
}
55+
3456
public void compile(string optimizer, string loss, string[] metrics)
3557
{
3658
var _optimizer = optimizer switch

src/TensorFlowNET.Keras/Engine/Model.Metrics.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ namespace Tensorflow.Keras.Engine
55
{
66
public partial class Model
77
{
8-
public IEnumerable<Metric> metrics
8+
public IEnumerable<IMetricFunc> metrics
99
{
1010
get
1111
{
12-
var _metrics = new List<Metric>();
12+
var _metrics = new List<IMetricFunc>();
1313

1414
if (_is_compiled)
1515
{
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class BinaryAccuracy : MeanMetricWrapper
4+
{
5+
public BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 0.5f)
6+
: base((yt, yp) => metrics_utils.binary_matches(yt, yp),
7+
name: name,
8+
dtype: dtype)
9+
{
10+
}
11+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class CategoricalAccuracy : MeanMetricWrapper
4+
{
5+
public CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
6+
: base((yt, yp) => metrics_utils.sparse_categorical_matches(
7+
tf.math.argmax(yt, axis: -1), yp),
8+
name: name,
9+
dtype: dtype)
10+
{
11+
}
12+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class CategoricalCrossentropy : MeanMetricWrapper
4+
{
5+
public CategoricalCrossentropy(string name = "categorical_crossentropy",
6+
TF_DataType dtype = TF_DataType.TF_FLOAT,
7+
bool from_logits = false,
8+
float label_smoothing = 0f,
9+
Axis? axis = null)
10+
: base((yt, yp) => keras.metrics.categorical_crossentropy(
11+
yt, yp, from_logits: from_logits, label_smoothing: label_smoothing, axis: axis ?? -1),
12+
name: name,
13+
dtype: dtype)
14+
{
15+
}
16+
}

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
1515
return math_ops.cast(eql, TF_DataType.TF_FLOAT);
1616
}
1717

18+
public Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
19+
{
20+
y_true = tf.cast(y_true, y_pred.dtype);
21+
// var label_smoothing_tensor = tf.convert_to_tensor(label_smoothing, dtype: y_pred.dtype);
22+
if (label_smoothing > 0)
23+
{
24+
var num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype);
25+
y_true = y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes);
26+
}
27+
return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits, axis: axis);
28+
}
29+
1830
/// <summary>
1931
/// Calculates how often predictions matches integer labels.
2032
/// </summary>
@@ -59,6 +71,15 @@ public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5
5971
);
6072
}
6173

74+
public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5)
75+
=> new BinaryAccuracy();
76+
77+
public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
78+
=> new CategoricalAccuracy(name: name, dtype: dtype);
79+
80+
public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
81+
=> new CategoricalCrossentropy();
82+
6283
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
6384
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
6485

src/TensorFlowNET.Keras/Metrics/metrics_utils.cs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,48 @@
11
using Tensorflow.Keras.Utils;
2-
using Tensorflow.NumPy;
32

43
namespace Tensorflow.Keras.Metrics;
54

65
public class metrics_utils
76
{
7+
public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f)
8+
{
9+
y_pred = tf.cast(y_pred > threshold, y_pred.dtype);
10+
return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
11+
}
12+
13+
/// <summary>
14+
/// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
15+
/// </summary>
16+
/// <param name="y_true"></param>
17+
/// <param name="y_pred"></param>
18+
/// <returns></returns>
19+
public static Tensor sparse_categorical_matches(Tensor y_true, Tensor y_pred)
20+
{
21+
var reshape_matches = false;
22+
var y_true_rank = y_true.shape.ndim;
23+
var y_pred_rank = y_pred.shape.ndim;
24+
var y_true_org_shape = tf.shape(y_true);
25+
26+
if (y_true_rank > -1 && y_pred_rank > -1 && y_true.ndim == y_pred.ndim )
27+
{
28+
reshape_matches = true;
29+
y_true = tf.squeeze(y_true, new Shape(-1));
30+
}
31+
y_pred = tf.math.argmax(y_pred, axis: -1);
32+
33+
var matches = tf.cast(
34+
tf.equal(y_true, y_pred),
35+
dtype: keras.backend.floatx()
36+
);
37+
38+
if (reshape_matches)
39+
{
40+
return tf.reshape(matches, shape: y_true_org_shape);
41+
}
42+
43+
return matches;
44+
}
45+
846
public static Tensor sparse_top_k_categorical_matches(Tensor y_true, Tensor y_pred, int k = 5)
947
{
1048
var reshape_matches = false;

0 commit comments

Comments
 (0)