diff --git a/trapdata/ml/models/base.py b/trapdata/ml/models/base.py index cee4033..2398f6c 100644 --- a/trapdata/ml/models/base.py +++ b/trapdata/ml/models/base.py @@ -74,6 +74,7 @@ class InferenceBaseClass: weights = None labels_path = None category_map = {} + class_masking_list = None num_classes: Union[int, None] = None # Will use len(category_map) if None lookup_gbif_names: bool = False default_taxon_rank: str = "SPECIES" @@ -116,6 +117,7 @@ def __init__( f"Loading {self.type} model (stage: {self.stage}) for {self.name} with {len(self.category_map or [])} categories" ) self.model = self.get_model() + self.class_masking_list = self.get_class_masking_list() @classmethod def get_key(cls): @@ -184,6 +186,30 @@ def fetch_gbif_ids(labels): else: return {} + def get_class_masking_list(self) -> list[str]: + """ + This must be implemented by a subclass + + """ + raise NotImplementedError + + def _mask_classes(self, predictions: torch.Tensor): + """Class mask function to include specific output classes and exclude the rest""" + + # Create a mask for the classes to prune + mask = torch.zeros( + predictions.size(1), dtype=torch.bool, device=predictions.device + ) + + # Get species keys that needs to be removed + for taxon_to_keep in self.class_masking_list: + id_to_keep = self.name_to_id_map[taxon_to_keep] + mask[id_to_keep] = True + + # Apply the mask to zero out unwanted nodes + predictions[:, ~mask] = float("-inf") # Set to -inf to ignore during softmax + return predictions + def get_model(self) -> torch.nn.Module: """ This method must be implemented by a subclass. 
# Replacement for the `get_class_masking_list` override that the patch adds to
# trapdata/ml/models/classification.py, together with the imports it needs.
import os
import pickle
from typing import Optional


def get_class_masking_list(
    self, masking_list: Optional[str] = None
) -> Optional[list[str]]:
    """Load the taxon names to keep during class masking from a pickle file.

    Args:
        masking_list: Name of an environment variable that holds the path of
            the pickle file. It is optional because the base-class
            constructor calls this hook without arguments. When it is
            omitted, or the variable is unset, the file ``masking_list.pkl``
            in the working directory is used.

    Returns:
        The unpickled names as a list, or None when no path was configured
        through the environment and the default file does not exist
        (masking was not requested).

    Raises:
        FileNotFoundError: If the environment variable points to a file that
            does not exist.
        TypeError: If the pickle holds a single string instead of a
            collection of names.
    """
    default_path = "masking_list.pkl"
    configured_path = os.environ.get(masking_list) if masking_list else None

    if configured_path is None:
        if not os.path.isfile(default_path):
            # Nothing was requested: leave masking off instead of failing
            # model construction on a missing optional file.
            return None
        path = default_path
    else:
        path = configured_path

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at trusted files; a plain-text or JSON list would be a
    # safer format for a list of names.
    with open(path, "rb") as f:
        loaded = pickle.load(f)

    if isinstance(loaded, (str, bytes)):
        raise TypeError(
            f"Expected a collection of taxon names in {path!r}, "
            f"got {type(loaded).__name__}"
        )
    return list(loaded)