Skip to content

Commit 067c1ff

Browse files
committed
Add metrics of F1Score and FBetaScore.
1 parent 77eb6f5 commit 067c1ff

File tree

8 files changed

+214
-6
lines changed

8 files changed

+214
-6
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ public Tensor log(Tensor x, string name = null)
3636
public Tensor erf(Tensor x, string name = null)
3737
=> math_ops.erf(x, name);
3838

39+
public Tensor multiply(Tensor x, Tensor y, string name = null)
40+
=> math_ops.multiply(x, y, name: name);
41+
42+
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
43+
=> math_ops.div_no_nan(a, b);
44+
45+
public Tensor square(Tensor x, string name = null)
46+
=> math_ops.square(x, name: name);
47+
3948
public Tensor sum(Tensor x, Axis? axis = null, string name = null)
4049
=> math_ops.reduce_sum(x, axis: axis, name: name);
4150

src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,27 @@ IMetricFunc CosineSimilarity(string name = "cosine_similarity",
7171
TF_DataType dtype = TF_DataType.TF_FLOAT,
7272
Axis? axis = null);
7373

74+
/// <summary>
75+
/// Computes F-1 Score.
76+
/// </summary>
77+
/// <returns></returns>
78+
IMetricFunc F1Score(int num_classes,
79+
string? average = null,
80+
float threshold = -1f,
81+
string name = "fbeta_score",
82+
TF_DataType dtype = TF_DataType.TF_FLOAT);
83+
84+
/// <summary>
85+
/// Computes F-Beta score.
86+
/// </summary>
87+
/// <returns></returns>
88+
IMetricFunc FBetaScore(int num_classes,
89+
string? average = null,
90+
float beta = 0.1f,
91+
float threshold = -1f,
92+
string name = "fbeta_score",
93+
TF_DataType dtype = TF_DataType.TF_FLOAT);
94+
7495
/// <summary>
7596
/// Computes how often targets are in the top K predictions.
7697
/// </summary>

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
<AssemblyName>Tensorflow.Binding</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>2.10.0</TargetTensorFlow>
8-
<Version>0.100.3</Version>
8+
<Version>0.100.4</Version>
99
<LangVersion>10.0</LangVersion>
1010
<Nullable>enable</Nullable>
1111
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
@@ -20,7 +20,7 @@
2020
<Description>Google's TensorFlow full binding in .NET Standard.
2121
Building, training and infering deep learning models.
2222
https://tensorflownet.readthedocs.io</Description>
23-
<AssemblyVersion>0.100.3.0</AssemblyVersion>
23+
<AssemblyVersion>0.100.4.0</AssemblyVersion>
2424
<PackageReleaseNotes>
2525
tf.net 0.100.x and above are based on tensorflow native 2.10.0
2626

@@ -38,7 +38,7 @@ https://tensorflownet.readthedocs.io</Description>
3838
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
3939
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
4040
</PackageReleaseNotes>
41-
<FileVersion>0.100.3.0</FileVersion>
41+
<FileVersion>0.100.4.0</FileVersion>
4242
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4343
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
4444
<SignAssembly>true</SignAssembly>
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class F1Score : FBetaScore
4+
{
5+
public F1Score(int num_classes,
6+
string? average = null,
7+
float? threshold = -1f,
8+
string name = "f1_score",
9+
TF_DataType dtype = TF_DataType.TF_FLOAT)
10+
: base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype)
11+
{
12+
}
13+
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class FBetaScore : Metric
4+
{
5+
int _num_classes;
6+
string? _average;
7+
Tensor _beta;
8+
Tensor _threshold;
9+
Axis _axis;
10+
int[] _init_shape;
11+
12+
IVariableV1 true_positives;
13+
IVariableV1 false_positives;
14+
IVariableV1 false_negatives;
15+
IVariableV1 weights_intermediate;
16+
17+
public FBetaScore(int num_classes,
18+
string? average = null,
19+
float beta = 0.1f,
20+
float? threshold = -1f,
21+
string name = "fbeta_score",
22+
TF_DataType dtype = TF_DataType.TF_FLOAT)
23+
: base(name: name, dtype: dtype)
24+
{
25+
_num_classes = num_classes;
26+
_average = average;
27+
_beta = constant_op.constant(beta);
28+
_dtype = dtype;
29+
30+
if (threshold.HasValue)
31+
{
32+
_threshold = constant_op.constant(threshold);
33+
}
34+
35+
_init_shape = new int[0];
36+
37+
if (average != "micro")
38+
{
39+
_axis = 0;
40+
_init_shape = new int[] { num_classes };
41+
}
42+
43+
true_positives = add_weight("true_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
44+
false_positives = add_weight("false_positives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
45+
false_negatives = add_weight("false_negatives", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
46+
weights_intermediate = add_weight("weights_intermediate", shape: _init_shape, initializer: tf.initializers.zeros_initializer());
47+
}
48+
49+
public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
50+
{
51+
if (_threshold == null)
52+
{
53+
_threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true);
54+
// make sure [0, 0, 0] doesn't become [1, 1, 1]
55+
// Use abs(x) > eps, instead of x != 0 to check for zero
56+
y_pred = tf.logical_and(y_pred >= _threshold, tf.abs(y_pred) > 1e-12);
57+
}
58+
else
59+
{
60+
y_pred = y_pred > _threshold;
61+
}
62+
63+
y_true = tf.cast(y_true, _dtype);
64+
y_pred = tf.cast(y_pred, _dtype);
65+
66+
true_positives.assign_add(_weighted_sum(y_pred * y_true, sample_weight));
67+
false_positives.assign_add(
68+
_weighted_sum(y_pred * (1 - y_true), sample_weight)
69+
);
70+
false_negatives.assign_add(
71+
_weighted_sum((1 - y_pred) * y_true, sample_weight)
72+
);
73+
weights_intermediate.assign_add(_weighted_sum(y_true, sample_weight));
74+
75+
return weights_intermediate.AsTensor();
76+
}
77+
78+
Tensor _weighted_sum(Tensor val, Tensor? sample_weight = null)
79+
{
80+
if (sample_weight != null)
81+
{
82+
val = tf.math.multiply(val, tf.expand_dims(sample_weight, 1));
83+
}
84+
85+
return tf.reduce_sum(val, axis: _axis);
86+
}
87+
88+
public override Tensor result()
89+
{
90+
var precision = tf.math.divide_no_nan(
91+
true_positives.AsTensor(), true_positives.AsTensor() + false_positives.AsTensor()
92+
);
93+
var recall = tf.math.divide_no_nan(
94+
true_positives.AsTensor(), true_positives.AsTensor() + false_negatives.AsTensor()
95+
);
96+
97+
var mul_value = precision * recall;
98+
var add_value = (tf.math.square(_beta) * precision) + recall;
99+
var mean = tf.math.divide_no_nan(mul_value, add_value);
100+
var f1_score = mean * (1 + tf.math.square(_beta));
101+
102+
Tensor weights;
103+
if (_average == "weighted")
104+
{
105+
weights = tf.math.divide_no_nan(
106+
weights_intermediate.AsTensor(), tf.reduce_sum(weights_intermediate.AsTensor())
107+
);
108+
f1_score = tf.reduce_sum(f1_score * weights);
109+
}
110+
// micro, macro
111+
else if (_average != null)
112+
{
113+
f1_score = tf.reduce_mean(f1_score);
114+
}
115+
116+
return f1_score;
117+
}
118+
119+
public override void reset_states()
120+
{
121+
var reset_value = np.zeros(_init_shape, dtype: _dtype);
122+
keras.backend.batch_set_value(
123+
new List<(IVariableV1, NDArray)>
124+
{
125+
(true_positives, reset_value),
126+
(false_positives, reset_value),
127+
(false_negatives, reset_value),
128+
(weights_intermediate, reset_value)
129+
});
130+
}
131+
}

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentro
8686
public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
8787
=> new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1);
8888

89+
public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
90+
=> new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype);
91+
92+
public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
93+
=> new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype);
94+
8995
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
9096
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
9197

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<Nullable>enable</Nullable>
88
<RootNamespace>Tensorflow.Keras</RootNamespace>
99
<Platforms>AnyCPU;x64</Platforms>
10-
<Version>0.10.3</Version>
10+
<Version>0.10.4</Version>
1111
<Authors>Haiping Chen</Authors>
1212
<Product>Keras for .NET</Product>
1313
<Copyright>Apache 2.0, Haiping Chen 2023</Copyright>
@@ -37,8 +37,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
3737
<RepositoryType>Git</RepositoryType>
3838
<SignAssembly>true</SignAssembly>
3939
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
40-
<AssemblyVersion>0.10.3.0</AssemblyVersion>
41-
<FileVersion>0.10.3.0</FileVersion>
40+
<AssemblyVersion>0.10.4.0</AssemblyVersion>
41+
<FileVersion>0.10.4.0</FileVersion>
4242
<PackageLicenseFile>LICENSE</PackageLicenseFile>
4343
<Configurations>Debug;Release;GPU</Configurations>
4444
</PropertyGroup>

test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,34 @@ public void CosineSimilarity()
114114
Assert.AreEqual(r, 0.6999999f);
115115
}
116116

117+
/// <summary>
118+
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score
119+
/// </summary>
120+
[TestMethod]
121+
public void F1Score()
122+
{
123+
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
124+
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
125+
var m = tf.keras.metrics.F1Score(num_classes: 3, threshold: 0.5f);
126+
m.update_state(y_true, y_pred);
127+
var r = m.result().numpy();
128+
Assert.AreEqual(r, new[] { 0.5f, 0.8f, 0.6666667f });
129+
}
130+
131+
/// <summary>
132+
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/FBetaScore
133+
/// </summary>
134+
[TestMethod]
135+
public void FBetaScore()
136+
{
137+
var y_true = np.array(new[,] { { 1, 1, 1 }, { 1, 0, 0 }, { 1, 1, 0 } });
138+
var y_pred = np.array(new[,] { { 0.2f, 0.6f, 0.7f }, { 0.2f, 0.6f, 0.6f }, { 0.6f, 0.8f, 0f } });
139+
var m = tf.keras.metrics.FBetaScore(num_classes: 3, beta: 2.0f, threshold: 0.5f);
140+
m.update_state(y_true, y_pred);
141+
var r = m.result().numpy();
142+
Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f });
143+
}
144+
117145
/// <summary>
118146
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
119147
/// </summary>

0 commit comments

Comments
 (0)