66
77from .. import detection_in_world
88from .. import object_in_world
9- from ..cluster_estimation import cluster_estimation
109from ..common .modules .logger import logger
10+ from . import cluster_estimation
1111
1212
1313class ClusterEstimationByLabel :
@@ -19,11 +19,14 @@ class ClusterEstimationByLabel:
1919 ATTRIBUTES
2020 ----------
2121 min_activation_threshold: int
22- Minimum total data points before model runs.
22+ Minimum total data points before model runs. Must be at least max_num_components.
2323
2424 min_new_points_to_run: int
2525 Minimum number of new data points that must be collected before running model.
2626
27+ max_num_components: int
28+ Max number of real landing pads. Must be at least 1.
29+
2730 random_state: int
2831 Seed for randomizer, to get consistent results.
2932
@@ -47,6 +50,7 @@ def create(
4750 cls ,
4851 min_activation_threshold : int ,
4952 min_new_points_to_run : int ,
53+ max_num_components : int ,
5054 random_state : int ,
5155 local_logger : logger .Logger ,
5256 ) -> "tuple[bool, ClusterEstimationByLabel | None]" :
@@ -55,13 +59,23 @@ def create(
5559 """
5660
5761 # At least 1 point for model to fit
58- if min_activation_threshold < 1 :
62+ if min_activation_threshold < max_num_components :
63+ return False , None
64+
65+ if min_new_points_to_run < 0 :
66+ return False , None
67+
68+ if max_num_components < 1 :
69+ return False , None
70+
71+ if random_state < 0 :
5972 return False , None
6073
6174 return True , ClusterEstimationByLabel (
6275 cls .__create_key ,
6376 min_activation_threshold ,
6477 min_new_points_to_run ,
78+ max_num_components ,
6579 random_state ,
6680 local_logger ,
6781 )
@@ -71,6 +85,7 @@ def __init__(
7185 class_private_create_key : object ,
7286 min_activation_threshold : int ,
7387 min_new_points_to_run : int ,
88+ max_num_components : int ,
7489 random_state : int ,
7590 local_logger : logger .Logger ,
7691 ) -> None :
@@ -84,10 +99,12 @@ def __init__(
8499 # Requirements to decide to run
85100 self .__min_activation_threshold = min_activation_threshold
86101 self .__min_new_points_to_run = min_new_points_to_run
102+ self .__max_num_components = max_num_components
87103 self .__random_state = random_state
88104 self .__local_logger = local_logger
89105
90- # cluster model corresponding to each label
106+ # Cluster model corresponding to each label
107+ # Each cluster estimation object stores the detections given to in its __all_points bucket across runs
91108 self .__label_to_cluster_estimation_model : dict [
92109 int , cluster_estimation .ClusterEstimation
93110 ] = {}
@@ -120,17 +137,20 @@ def run(
120137 Dictionary where the key is a label and the value is a list of all cluster detections with that label
121138 """
122139 label_to_detections : dict [int , list [detection_in_world .DetectionInWorld ]] = {}
140+ # Sorting detections by label
123141 for detection in input_detections :
124142 if not detection .label in label_to_detections :
125143 label_to_detections [detection .label ] = []
126144 label_to_detections [detection .label ].append (detection )
127145
128146 labels_to_object_clusters : dict [int , list [object_in_world .ObjectInWorld ]] = {}
129147 for label , detections in label_to_detections .items ():
148+ # create cluster estimation for label if it doesn't exist
130149 if not label in self .__label_to_cluster_estimation_model :
131150 result , cluster_model = cluster_estimation .ClusterEstimation .create (
132151 self .__min_activation_threshold ,
133152 self .__min_new_points_to_run ,
153+ self .__max_num_components ,
134154 self .__random_state ,
135155 self .__local_logger ,
136156 label ,
@@ -141,6 +161,7 @@ def run(
141161 )
142162 return False , None
143163 self .__label_to_cluster_estimation_model [label ] = cluster_model
164+ # runs cluster estimation for specific label
144165 result , clusters = self .__label_to_cluster_estimation_model [label ].run (
145166 detections ,
146167 run_override ,
0 commit comments