@@ -5,6 +5,10 @@ public interface IMetricsApi
55 Tensor binary_accuracy ( Tensor y_true , Tensor y_pred ) ;
66
77 Tensor categorical_accuracy ( Tensor y_true , Tensor y_pred ) ;
8+ Tensor categorical_crossentropy ( Tensor y_true , Tensor y_pred ,
9+ bool from_logits = false ,
10+ float label_smoothing = 0f ,
11+ Axis ? axis = null ) ;
812
913 Tensor mean_absolute_error ( Tensor y_true , Tensor y_pred ) ;
1014
@@ -27,14 +31,39 @@ public interface IMetricsApi
2731 /// <returns></returns>
2832 Tensor top_k_categorical_accuracy ( Tensor y_true , Tensor y_pred , int k = 5 ) ;
2933
34+ /// <summary>
35+ /// Calculates how often predictions match binary labels.
36+ /// </summary>
37+ /// <returns></returns>
38+ IMetricFunc BinaryAccuracy ( string name = "binary_accuracy" ,
39+ TF_DataType dtype = TF_DataType . TF_FLOAT ,
40+ float threshold = 05f ) ;
41+
42+ /// <summary>
43+ /// Calculates how often predictions match one-hot labels.
44+ /// </summary>
45+ /// <returns></returns>
46+ IMetricFunc CategoricalCrossentropy ( string name = "categorical_crossentropy" ,
47+ TF_DataType dtype = TF_DataType . TF_FLOAT ,
48+ bool from_logits = false ,
49+ float label_smoothing = 0f ,
50+ Axis ? axis = null ) ;
51+
52+ /// <summary>
53+ /// Computes the crossentropy metric between the labels and predictions.
54+ /// </summary>
55+ /// <returns></returns>
56+ IMetricFunc CategoricalAccuracy ( string name = "categorical_accuracy" ,
57+ TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
58+
3059 /// <summary>
3160 /// Computes how often targets are in the top K predictions.
3261 /// </summary>
33- /// <param name="y_true"></param>
34- /// <param name="y_pred"></param>
3562 /// <param name="k"></param>
3663 /// <returns></returns>
37- IMetricFunc TopKCategoricalAccuracy ( int k = 5 , string name = "top_k_categorical_accuracy" , TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
64+ IMetricFunc TopKCategoricalAccuracy ( int k = 5 ,
65+ string name = "top_k_categorical_accuracy" ,
66+ TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
3867
3968 /// <summary>
4069 /// Computes the precision of the predictions with respect to the labels.
@@ -45,7 +74,11 @@ public interface IMetricsApi
4574 /// <param name="name"></param>
4675 /// <param name="dtype"></param>
4776 /// <returns></returns>
48- IMetricFunc Precision ( float thresholds = 0.5f , int top_k = 0 , int class_id = 0 , string name = "recall" , TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
77+ IMetricFunc Precision ( float thresholds = 0.5f ,
78+ int top_k = 0 ,
79+ int class_id = 0 ,
80+ string name = "recall" ,
81+ TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
4982
5083 /// <summary>
5184 /// Computes the recall of the predictions with respect to the labels.
@@ -56,5 +89,9 @@ public interface IMetricsApi
5689 /// <param name="name"></param>
5790 /// <param name="dtype"></param>
5891 /// <returns></returns>
59- IMetricFunc Recall ( float thresholds = 0.5f , int top_k = 0 , int class_id = 0 , string name = "recall" , TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
92+ IMetricFunc Recall ( float thresholds = 0.5f ,
93+ int top_k = 0 ,
94+ int class_id = 0 ,
95+ string name = "recall" ,
96+ TF_DataType dtype = TF_DataType . TF_FLOAT ) ;
6097}
0 commit comments