diff --git a/mean_average_precision/detection_map.py b/mean_average_precision/detection_map.py index 0928452..7f21359 100644 --- a/mean_average_precision/detection_map.py +++ b/mean_average_precision/detection_map.py @@ -158,3 +158,21 @@ def plot(self, interpolated=True, class_names=None): plt.suptitle("Mean average precision : {:0.2f}".format(sum(mean_average_precision)/len(mean_average_precision))) fig.tight_layout() + + + def get_result(self,interpolated=True, class_names=None): + """ + return result. As result is calculated for a batch, it needs to be utilized for entire dataset. + """ + + mean_average_precision = [] + # TODO: data structure not optimal for this operation... + for i in range (self.n_class): + if i > self.n_class - 1: + break + precisions, recalls = self.compute_precision_recall_(i, interpolated) + average_precision = self.compute_ap(precisions, recalls) + class_name = class_names[i] if class_names else "Class {}".format(i) + mean_average_precision.append(average_precision) + + return sum(mean_average_precision)/len(mean_average_precision)