Skip to content

Commit 922139f

Browse files
committed
Set mean of all classes in F1 Score.
1 parent 067c1ff commit 922139f

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/TensorFlowNET.Keras/Engine/Model.Train.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@ Dictionary<string, float> train_step(DataHandler data_handler, Tensor x, Tensor
3939
compiled_metrics.update_state(y, y_pred);
4040

4141
var dict = new Dictionary<string, float>();
42-
metrics.ToList().ForEach(x => dict[x.Name] = (float)x.result());
42+
metrics.ToList().ForEach(x =>
43+
{
44+
var r = x.result();
45+
if (r.ndim > 0)
46+
{
47+
r = tf.reduce_mean(r);
48+
}
49+
dict[x.Name] = (float)r;
50+
});
4351
return dict;
4452
}
4553

0 commit comments

Comments
 (0)