Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions trapdata/ml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class InferenceBaseClass:
weights = None
labels_path = None
category_map = {}
class_masking_list = None
num_classes: Union[int, None] = None # Will use len(category_map) if None
lookup_gbif_names: bool = False
default_taxon_rank: str = "SPECIES"
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
f"Loading {self.type} model (stage: {self.stage}) for {self.name} with {len(self.category_map or [])} categories"
)
self.model = self.get_model()
self.class_masking_list = self.get_class_masking_list()

@classmethod
def get_key(cls):
Expand Down Expand Up @@ -184,6 +186,30 @@ def fetch_gbif_ids(labels):
else:
return {}

def get_class_masking_list(self) -> list[str]:
"""
This must be implemented by a subclass

"""
raise NotImplementedError

def _mask_classes(self, predictions: torch.Tensor):
"""Class mask function to include specific output classes and exclude the rest"""

# Create a mask for the classes to prune
mask = torch.zeros(
predictions.size(1), dtype=torch.bool, device=predictions.device
)

# Get species keys that needs to be removed
for taxon_to_keep in self.class_masking_list:
id_to_keep = self.name_to_id_map[taxon_to_keep]
mask[id_to_keep] = True

# Apply the mask to zero out unwanted nodes
predictions[:, ~mask] = float("-inf") # Set to -inf to ignore during softmax
return predictions

def get_model(self) -> torch.nn.Module:
"""
This method must be implemented by a subclass.
Expand Down
10 changes: 10 additions & 0 deletions trapdata/ml/models/classification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
import pickle

import timm
import torch
import torch.utils.data
Expand Down Expand Up @@ -241,6 +244,8 @@ def get_transforms(self):
)

def post_process_batch(self, output):
# Mask classes if requested

predictions = torch.nn.functional.softmax(output, dim=1)
predictions = predictions.cpu().numpy()

Expand Down Expand Up @@ -372,6 +377,11 @@ def save_results(self, object_ids, batch_output, *args, **kwargs):
]
save_classified_objects(self.db_path, object_ids, classified_objects_data)

def get_class_masking_list(self, masking_list: str) -> list[str]:
with open(os.getenv(masking_list, "masking_list.pkl"), "rb") as f:
class_masking_list = pickle.load(f)
return class_masking_list


class QuebecVermontMothSpeciesClassifierMixedResolution(
SpeciesClassifier, Resnet50ClassifierLowRes
Expand Down
Loading