Skip to content

Commit 77eb6f5

Browse files
committed
Add metrics of Accuracy and CosineSimilarity.
1 parent 9e877d1 commit 77eb6f5

File tree

7 files changed

+106
-3
lines changed

7 files changed

+106
-3
lines changed

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ public Tensor inv(Tensor input, bool adjoint = false, string name = null)
5454
public Tensor global_norm(Tensor[] t_list, string name = null)
5555
=> clip_ops.global_norm(t_list, name: name);
5656

57+
public Tensor l2_normalize(Tensor x,
58+
int axis = 0,
59+
float epsilon = 1e-12f,
60+
string name = null)
61+
=> nn_impl.l2_normalize(x, axis: axis, epsilon: constant_op.constant(epsilon), name: name);
62+
5763
public Tensor lstsq(Tensor matrix, Tensor rhs,
5864
NDArray l2_regularizer = null, bool fast = true, string name = null)
5965
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);

src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred,
3131
/// <returns></returns>
3232
Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5);
3333

34+
/// <summary>
35+
/// Calculates how often predictions equal labels.
36+
/// </summary>
37+
/// <returns></returns>
38+
IMetricFunc Accuracy(string name = "accuracy",
39+
TF_DataType dtype = TF_DataType.TF_FLOAT);
40+
3441
/// <summary>
3542
/// Calculates how often predictions match binary labels.
3643
/// </summary>
@@ -56,6 +63,14 @@ IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy",
5663
IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy",
5764
TF_DataType dtype = TF_DataType.TF_FLOAT);
5865

66+
/// <summary>
67+
/// Computes the cosine similarity between the labels and predictions.
68+
/// </summary>
69+
/// <returns></returns>
70+
IMetricFunc CosineSimilarity(string name = "cosine_similarity",
71+
TF_DataType dtype = TF_DataType.TF_FLOAT,
72+
Axis? axis = null);
73+
5974
/// <summary>
6075
/// Computes how often targets are in the top K predictions.
6176
/// </summary>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class Accuracy : MeanMetricWrapper
4+
{
5+
public Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
6+
: base((yt, yp) => metrics_utils.accuracy(yt, yp),
7+
name: name,
8+
dtype: dtype)
9+
{
10+
}
11+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class CosineSimilarity : MeanMetricWrapper
4+
{
5+
public CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
6+
: base((yt, yp) => metrics_utils.cosine_similarity(yt, yp, axis: axis ?? -1),
7+
name: name,
8+
dtype: dtype)
9+
{
10+
}
11+
}

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5
7171
);
7272
}
7373

74+
public IMetricFunc Accuracy(string name = "accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
75+
=> new Accuracy(name: name, dtype: dtype);
76+
7477
public IMetricFunc BinaryAccuracy(string name = "binary_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT, float threshold = 5)
7578
=> new BinaryAccuracy();
7679

@@ -80,6 +83,9 @@ public IMetricFunc CategoricalAccuracy(string name = "categorical_accuracy", TF_
8083
public IMetricFunc CategoricalCrossentropy(string name = "categorical_crossentropy", TF_DataType dtype = TF_DataType.TF_FLOAT, bool from_logits = false, float label_smoothing = 0, Axis? axis = null)
8184
=> new CategoricalCrossentropy();
8285

86+
public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null)
87+
=> new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1);
88+
8389
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
8490
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
8591

src/TensorFlowNET.Keras/Metrics/metrics_utils.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,26 @@ namespace Tensorflow.Keras.Metrics;
44

55
public class metrics_utils
66
{
7+
public static Tensor accuracy(Tensor y_true, Tensor y_pred)
8+
{
9+
if (y_true.dtype != y_pred.dtype)
10+
y_pred = tf.cast(y_pred, y_true.dtype);
11+
return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
12+
}
13+
714
public static Tensor binary_matches(Tensor y_true, Tensor y_pred, float threshold = 0.5f)
815
{
916
y_pred = tf.cast(y_pred > threshold, y_pred.dtype);
1017
return tf.cast(tf.equal(y_true, y_pred), keras.backend.floatx());
1118
}
1219

20+
public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, Axis? axis = null)
21+
{
22+
y_true = tf.linalg.l2_normalize(y_true, axis: axis ?? -1);
23+
y_pred = tf.linalg.l2_normalize(y_pred, axis: axis ?? -1);
24+
return tf.reduce_sum(y_true * y_pred, axis: axis ?? -1);
25+
}
26+
1327
/// <summary>
1428
/// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
1529
/// </summary>

test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,26 @@ namespace TensorFlowNET.Keras.UnitTest;
1414
[TestClass]
1515
public class MetricsTest : EagerModeTestBase
1616
{
17+
/// <summary>
18+
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Accuracy
19+
/// </summary>
20+
[TestMethod]
21+
public void Accuracy()
22+
{
23+
var y_true = np.array(new[,] { { 1 }, { 2 }, { 3 }, { 4 } });
24+
var y_pred = np.array(new[,] { { 0f }, { 2f }, { 3f }, { 4f } });
25+
var m = tf.keras.metrics.Accuracy();
26+
m.update_state(y_true, y_pred);
27+
var r = m.result().numpy();
28+
Assert.AreEqual(r, 0.75f);
29+
30+
m.reset_states();
31+
var weights = np.array(new[] { 1f, 1f, 0f, 0f });
32+
m.update_state(y_true, y_pred, sample_weight: weights);
33+
r = m.result().numpy();
34+
Assert.AreEqual(r, 0.5f);
35+
}
36+
1737
/// <summary>
1838
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/BinaryAccuracy
1939
/// </summary>
@@ -23,14 +43,14 @@ public void BinaryAccuracy()
2343
var y_true = np.array(new[,] { { 1 }, { 1 },{ 0 }, { 0 } });
2444
var y_pred = np.array(new[,] { { 0.98f }, { 1f }, { 0f }, { 0.6f } });
2545
var m = tf.keras.metrics.BinaryAccuracy();
26-
/*m.update_state(y_true, y_pred);
46+
m.update_state(y_true, y_pred);
2747
var r = m.result().numpy();
2848
Assert.AreEqual(r, 0.75f);
2949

30-
m.reset_states();*/
50+
m.reset_states();
3151
var weights = np.array(new[] { 1f, 0f, 0f, 1f });
3252
m.update_state(y_true, y_pred, sample_weight: weights);
33-
var r = m.result().numpy();
53+
r = m.result().numpy();
3454
Assert.AreEqual(r, 0.5f);
3555
}
3656

@@ -74,6 +94,26 @@ public void CategoricalCrossentropy()
7494
Assert.AreEqual(r, 1.6271976f);
7595
}
7696

97+
/// <summary>
98+
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/CosineSimilarity
99+
/// </summary>
100+
[TestMethod]
101+
public void CosineSimilarity()
102+
{
103+
var y_true = np.array(new[,] { { 0, 1 }, { 1, 1 } });
104+
var y_pred = np.array(new[,] { { 1f, 0f }, { 1f, 1f } });
105+
var m = tf.keras.metrics.CosineSimilarity(axis: 1);
106+
m.update_state(y_true, y_pred);
107+
var r = m.result().numpy();
108+
Assert.AreEqual(r, 0.49999997f);
109+
110+
m.reset_states();
111+
var weights = np.array(new[] { 0.3f, 0.7f });
112+
m.update_state(y_true, y_pred, sample_weight: weights);
113+
r = m.result().numpy();
114+
Assert.AreEqual(r, 0.6999999f);
115+
}
116+
77117
/// <summary>
78118
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
79119
/// </summary>

0 commit comments

Comments
 (0)