Skip to content

Commit 9891998

Browse files
committed
Add metrics of Precision.
1 parent 217cfd2 commit 9891998

File tree

7 files changed

+141
-3
lines changed

7 files changed

+141
-3
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ public Tensor erf(Tensor x, string name = null)
3939
public Tensor sum(Tensor x, Axis? axis = null, string name = null)
4040
=> math_ops.reduce_sum(x, axis: axis, name: name);
4141

42+
/// <summary>
43+
/// Finds values and indices of the `k` largest entries for the last dimension.
44+
/// </summary>
45+
/// <param name="input"></param>
46+
/// <param name="k"></param>
47+
/// <param name="sorted"></param>
48+
/// <param name="name"></param>
49+
/// <returns></returns>
50+
public Tensors top_k(Tensor input, int k, bool sorted = true, string name = null)
51+
=> nn_ops.top_kv2(input, k, sorted: sorted, name: name);
52+
4253
public Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = "InTopK")
4354
=> nn_ops.in_top_k(predictions, targets, k, name);
4455

src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,17 @@ public interface IMetricsApi
3636
/// <returns></returns>
3737
IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT);
3838

39+
/// <summary>
40+
/// Computes the precision of the predictions with respect to the labels.
41+
/// </summary>
42+
/// <param name="thresholds"></param>
43+
/// <param name="top_k"></param>
44+
/// <param name="class_id"></param>
45+
/// <param name="name"></param>
46+
/// <param name="dtype"></param>
47+
/// <returns></returns>
48+
IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
49+
3950
/// <summary>
4051
/// Computes the recall of the predictions with respect to the labels.
4152
/// </summary>
@@ -45,5 +56,5 @@ public interface IMetricsApi
4556
/// <param name="name"></param>
4657
/// <param name="dtype"></param>
4758
/// <returns></returns>
48-
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
59+
IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT);
4960
}

src/TensorFlowNET.Core/Operations/nn_ops.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ private static Tensor _get_noise_shape(Tensor x, Tensor noise_shape)
109109
return noise_shape;
110110
}
111111

112+
public static Tensors top_kv2(Tensor input, int k, bool sorted = true, string name = null)
113+
=> tf.Context.ExecuteOp("TopKV2", name, new ExecuteOpArgs(input, k)
114+
.SetAttributes(new { sorted }));
115+
112116
public static Tensor in_top_k(Tensor predictions, Tensor targets, int k, string name = null)
113117
{
114118
return tf_with(ops.name_scope(name, "in_top_k"), delegate

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ public Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5
6262
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
6363
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
6464

65-
public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 1, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
65+
public IMetricFunc Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "precision", TF_DataType dtype = TF_DataType.TF_FLOAT)
66+
=> new Precision(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
67+
68+
public IMetricFunc Recall(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
6669
=> new Recall(thresholds: thresholds, top_k: top_k, class_id: class_id, name: name, dtype: dtype);
6770
}
6871
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class Precision : Metric
4+
{
5+
Tensor _thresholds;
6+
int _top_k;
7+
int _class_id;
8+
IVariableV1 true_positives;
9+
IVariableV1 false_positives;
10+
bool _thresholds_distributed_evenly;
11+
12+
public Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
13+
: base(name: name, dtype: dtype)
14+
{
15+
_thresholds = constant_op.constant(new float[] { thresholds });
16+
_top_k = top_k;
17+
_class_id = class_id;
18+
true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
19+
false_positives = add_weight("false_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
20+
}
21+
22+
public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
23+
{
24+
return metrics_utils.update_confusion_matrix_variables(
25+
new Dictionary<string, IVariableV1>
26+
{
27+
{ "tp", true_positives },
28+
{ "fp", false_positives },
29+
},
30+
y_true,
31+
y_pred,
32+
thresholds: _thresholds,
33+
thresholds_distributed_evenly: _thresholds_distributed_evenly,
34+
top_k: _top_k,
35+
class_id: _class_id,
36+
sample_weight: sample_weight);
37+
}
38+
39+
public override Tensor result()
40+
{
41+
var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_positives));
42+
return _thresholds.size == 1 ? result[0] : result;
43+
}
44+
45+
public override void reset_states()
46+
{
47+
var num_thresholds = (int)_thresholds.size;
48+
keras.backend.batch_set_value(
49+
new List<(IVariableV1, NDArray)>
50+
{
51+
(true_positives, np.zeros(num_thresholds)),
52+
(false_positives, np.zeros(num_thresholds))
53+
});
54+
}
55+
}

src/TensorFlowNET.Keras/Metrics/metrics_utils.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@ public static Tensor update_confusion_matrix_variables(Dictionary<string, IVaria
7878
sample_weight: sample_weight);
7979
}
8080

81+
if (top_k > 0)
82+
{
83+
y_pred = _filter_top_k(y_pred, top_k);
84+
}
85+
86+
if (class_id > 0)
87+
{
88+
y_true = y_true[Slice.All, class_id];
89+
y_pred = y_pred[Slice.All, class_id];
90+
}
91+
8192
if (thresholds_distributed_evenly)
8293
{
8394
throw new NotImplementedException();
@@ -204,5 +215,14 @@ Func<Tensor, Tensor, Tensor, IVariableV1, ITensorOrOperation> weighted_assign_ad
204215

205216
tf.group(update_ops.ToArray());
206217
return null;
207-
}
218+
}
219+
220+
private static Tensor _filter_top_k(Tensor x, int k)
221+
{
222+
var NEG_INF = -1e10;
223+
var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false);
224+
var top_k_mask = tf.reduce_sum(
225+
tf.one_hot(top_k_idx, (int)x.shape[-1], axis: -1), axis: -2);
226+
return x * top_k_mask + NEG_INF * (1 - top_k_mask);
227+
}
208228
}

test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,40 @@ public void top_k_categorical_accuracy()
4646
Assert.AreEqual(m.numpy(), new[] { 1f, 1f });
4747
}
4848

49+
/// <summary>
50+
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision
51+
/// </summary>
52+
[TestMethod]
53+
public void Precision()
54+
{
55+
var y_true = np.array(new[] { 0, 1, 1, 1 });
56+
var y_pred = np.array(new[] { 1, 0, 1, 1 });
57+
var m = tf.keras.metrics.Precision();
58+
m.update_state(y_true, y_pred);
59+
var r = m.result().numpy();
60+
Assert.AreEqual(r, 0.6666667f);
61+
62+
m.reset_states();
63+
var weights = np.array(new[] { 0f, 0f, 1f, 0f });
64+
m.update_state(y_true, y_pred, sample_weight: weights);
65+
r = m.result().numpy();
66+
Assert.AreEqual(r, 1f);
67+
68+
// With top_k=2, it will calculate precision over y_true[:2]
69+
// and y_pred[:2]
70+
m = tf.keras.metrics.Precision(top_k: 2);
71+
m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
72+
r = m.result().numpy();
73+
Assert.AreEqual(r, 0f);
74+
75+
// With top_k=4, it will calculate precision over y_true[:4]
76+
// and y_pred[:4]
77+
m = tf.keras.metrics.Precision(top_k: 4);
78+
m.update_state(np.array(new[] { 0, 0, 1, 1 }), np.array(new[] { 1, 1, 1, 1 }));
79+
r = m.result().numpy();
80+
Assert.AreEqual(r, 0.5f);
81+
}
82+
4983
/// <summary>
5084
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Recall
5185
/// </summary>

0 commit comments

Comments
 (0)