Skip to content

Commit c72f479

Browse files
committed
Add metric of HammingLoss.
1 parent 32a3e48 commit c72f479

File tree

8 files changed

+122
-2
lines changed

8 files changed

+122
-2
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ public class MathApi
2424
public Tensor argmax(Tensor input, Axis axis = null, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
2525
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
2626

27+
public Tensor count_nonzero(Tensor input, Axis? axis = null, bool? keepdims = null, TF_DataType dtype = TF_DataType.TF_INT64, string name = null)
28+
=> math_ops.count_nonzero_v2(input, axis: axis, keepdims: keepdims ?? false, dtype: dtype);
2729
public Tensor log(Tensor x, string name = null)
2830
=> gen_math_ops.log(x, name);
2931

src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,20 @@ IMetricFunc FBetaScore(int num_classes,
9191
float? threshold = null,
9292
string name = "fbeta_score",
9393
TF_DataType dtype = TF_DataType.TF_FLOAT);
94-
94+
95+
/// <summary>
96+
/// Computes hamming loss.
97+
/// </summary>
98+
/// <param name="mode">multiclass or multilabel</param>
99+
/// <param name="threshold"></param>
100+
/// <param name="name"></param>
101+
/// <param name="dtype"></param>
102+
/// <returns></returns>
103+
IMetricFunc HammingLoss(string mode,
104+
float? threshold = null,
105+
string name = "hamming_loss",
106+
TF_DataType dtype = TF_DataType.TF_FLOAT);
107+
95108
/// <summary>
96109
/// Computes how often targets are in the top K predictions.
97110
/// </summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,18 @@ public static Tensor batch_matmul(Tensor x, Tensor y,
821821
.SetAttributes(new { adj_x, adj_y }));
822822
});
823823

824+
public static Tensor count_nonzero_v2(Tensor input,
825+
Axis? axis,
826+
bool keepdims = false,
827+
string name = null,
828+
TF_DataType dtype = TF_DataType.TF_INT64)
829+
=> tf_with(ops.name_scope(name, "count_nonzero", input), scope =>
830+
{
831+
name = scope;
832+
var zero = array_ops.zeros(Shape.Scalar, dtype: input.dtype);
833+
return reduce_sum(cast(gen_math_ops.not_equal(input, zero), dtype), axis: axis, keepdims: keepdims);
834+
});
835+
824836
public static Tensor bincount(Tensor arr, Tensor weights = null,
825837
Tensor minlength = null,
826838
Tensor maxlength = null,

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ https://tensorflownet.readthedocs.io</Description>
109109
<ItemGroup>
110110
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
111111
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
112-
<PackageReference Include="Protobuf.Text" Version="0.6.0" />
112+
<PackageReference Include="Protobuf.Text" Version="0.6.1" />
113113
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
114114
</ItemGroup>
115115
</Project>
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
namespace Tensorflow.Keras.Metrics;
2+
3+
public class HammingLoss : MeanMetricWrapper
4+
{
5+
public HammingLoss(string mode,
6+
NDArray threshold = null,
7+
string name = "hamming_loss",
8+
TF_DataType dtype = TF_DataType.TF_FLOAT)
9+
: base((yt, yp) => metrics_utils.hamming_loss_fn(yt, yp, threshold, mode),
10+
name: name,
11+
dtype: dtype)
12+
{
13+
_dtype = dtype;
14+
}
15+
}

src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ public IMetricFunc F1Score(int num_classes, string? average = null, float? thres
9292
public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float? threshold = null, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT)
9393
=> new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype);
9494

95+
public IMetricFunc HammingLoss(string mode, float? threshold = null, string name = "hamming_loss", TF_DataType dtype = TF_DataType.TF_FLOAT)
96+
=> new HammingLoss(mode, threshold: threshold, name: name, dtype: dtype);
97+
9598
public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)
9699
=> new TopKCategoricalAccuracy(k: k, name: name, dtype: dtype);
97100

src/TensorFlowNET.Keras/Metrics/metrics_utils.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@ public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, Axis? axis
2424
return tf.reduce_sum(y_true * y_pred, axis: axis ?? -1);
2525
}
2626

27+
public static Tensor hamming_loss_fn(Tensor y_true, Tensor y_pred, Tensor threshold, string mode)
28+
{
29+
if (threshold == null)
30+
{
31+
threshold = tf.reduce_max(y_pred, axis: -1, keepdims: true);
32+
// make sure [0, 0, 0] doesn't become [1, 1, 1]
33+
// Use abs(x) > eps, instead of x != 0 to check for zero
34+
y_pred = tf.logical_and(y_pred >= threshold, tf.abs(y_pred) > 1e-12);
35+
}
36+
else
37+
{
38+
y_pred = y_pred > threshold;
39+
}
40+
41+
42+
y_true = tf.cast(y_true, tf.int32);
43+
y_pred = tf.cast(y_pred, tf.int32);
44+
45+
if (mode == "multiclass")
46+
{
47+
var nonzero = tf.cast(tf.math.count_nonzero(y_true * y_pred, axis: -1), tf.float32);
48+
return 1.0 - nonzero;
49+
}
50+
else
51+
{
52+
var nonzero = tf.cast(tf.math.count_nonzero(y_true - y_pred, axis: -1), tf.float32);
53+
return nonzero / y_true.shape[-1];
54+
}
55+
}
56+
2757
/// <summary>
2858
/// Creates float Tensor, 1.0 for label-prediction match, 0.0 for mismatch.
2959
/// </summary>

test/TensorFlowNET.Keras.UnitTest/Metrics/MetricsTest.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,51 @@ public void FBetaScore()
142142
Assert.AreEqual(r, new[] { 0.3846154f, 0.90909094f, 0.8333334f });
143143
}
144144

145+
/// <summary>
146+
/// https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss
147+
/// </summary>
148+
[TestMethod]
149+
public void HammingLoss()
150+
{
151+
// multi-class hamming loss
152+
var y_true = np.array(new[,]
153+
{
154+
{ 1, 0, 0, 0 },
155+
{ 0, 0, 1, 0 },
156+
{ 0, 0, 0, 1 },
157+
{ 0, 1, 0, 0 }
158+
});
159+
var y_pred = np.array(new[,]
160+
{
161+
{ 0.8f, 0.1f, 0.1f, 0.0f },
162+
{ 0.2f, 0.0f, 0.8f, 0.0f },
163+
{ 0.05f, 0.05f, 0.1f, 0.8f },
164+
{ 1.0f, 0.0f, 0.0f, 0.0f }
165+
});
166+
var m = tf.keras.metrics.HammingLoss(mode: "multiclass", threshold: 0.6f);
167+
m.update_state(y_true, y_pred);
168+
var r = m.result().numpy();
169+
Assert.AreEqual(r, 0.25f);
170+
171+
// multi-label hamming loss
172+
y_true = np.array(new[,]
173+
{
174+
{ 1, 0, 1, 0 },
175+
{ 0, 1, 0, 1 },
176+
{ 0, 0, 0, 1 }
177+
});
178+
y_pred = np.array(new[,]
179+
{
180+
{ 0.82f, 0.5f, 0.9f, 0.0f },
181+
{ 0f, 1f, 0.4f, 0.98f },
182+
{ 0.89f, 0.79f, 0f, 0.3f }
183+
});
184+
m = tf.keras.metrics.HammingLoss(mode: "multilabel", threshold: 0.8f);
185+
m.update_state(y_true, y_pred);
186+
r = m.result().numpy();
187+
Assert.AreEqual(r, 0.16666667f);
188+
}
189+
145190
/// <summary>
146191
/// https://www.tensorflow.org/api_docs/python/tf/keras/metrics/TopKCategoricalAccuracy
147192
/// </summary>

0 commit comments

Comments
 (0)