Skip to content

Commit 4aae8f9

Browse files
committed
Add lazy loading option for models.
1 parent 364767d commit 4aae8f9

3 files changed

Lines changed: 38 additions & 2 deletions

File tree

src/ChartExtractor/extraction/extraction.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,46 +64,53 @@
6464
),
6565
MODEL_CONFIG["intraoperative_document_landmarks"]["imgsz"],
6666
MODEL_CONFIG["intraoperative_document_landmarks"]["imgsz"],
67+
lazy_loading=True,
6768
)
6869
PREOP_POSTOP_DOC_MODEL = OnnxYolov11Detection(
6970
PATH_TO_MODELS / MODEL_CONFIG["preop_postop_document_landmarks"]["name"],
7071
PATH_TO_MODEL_METADATA
7172
/ MODEL_CONFIG["preop_postop_document_landmarks"]["name"].replace(".onnx", ".json"),
7273
MODEL_CONFIG["preop_postop_document_landmarks"]["imgsz"],
7374
MODEL_CONFIG["preop_postop_document_landmarks"]["imgsz"],
75+
lazy_loading=True,
7476
)
7577
NUMBERS_MODEL = OnnxYolov11Detection(
7678
PATH_TO_MODELS / MODEL_CONFIG["numbers"]["name"],
7779
PATH_TO_MODEL_METADATA / MODEL_CONFIG["numbers"]["name"].replace(".onnx", ".json"),
7880
MODEL_CONFIG["numbers"]["imgsz"],
7981
MODEL_CONFIG["numbers"]["imgsz"],
82+
lazy_loading=True,
8083
)
8184
SYSTOLIC_MODEL = OnnxYolov11PoseSingle(
8285
PATH_TO_MODELS / MODEL_CONFIG["systolic"]["name"],
8386
PATH_TO_MODEL_METADATA / MODEL_CONFIG["systolic"]["name"].replace(".onnx", ".json"),
8487
MODEL_CONFIG["systolic"]["imgsz"],
8588
MODEL_CONFIG["systolic"]["imgsz"],
89+
lazy_loading=True,
8690
)
8791
DIASTOLIC_MODEL = OnnxYolov11PoseSingle(
8892
PATH_TO_MODELS / MODEL_CONFIG["diastolic"]["name"],
8993
PATH_TO_MODEL_METADATA
9094
/ MODEL_CONFIG["diastolic"]["name"].replace(".onnx", ".json"),
9195
MODEL_CONFIG["diastolic"]["imgsz"],
9296
MODEL_CONFIG["diastolic"]["imgsz"],
97+
lazy_loading=True,
9398
)
9499
HEART_RATE_MODEL = OnnxYolov11PoseSingle(
95100
PATH_TO_MODELS / MODEL_CONFIG["heart_rate"]["name"],
96101
PATH_TO_MODEL_METADATA
97102
/ MODEL_CONFIG["heart_rate"]["name"].replace(".onnx", ".json"),
98103
MODEL_CONFIG["heart_rate"]["imgsz"],
99104
MODEL_CONFIG["heart_rate"]["imgsz"],
105+
lazy_loading=True,
100106
)
101107
CHECKBOXES_MODEL = OnnxYolov11Detection(
102108
PATH_TO_MODELS / MODEL_CONFIG["checkboxes"]["name"],
103109
PATH_TO_MODEL_METADATA
104110
/ MODEL_CONFIG["checkboxes"]["name"].replace(".onnx", ".json"),
105111
MODEL_CONFIG["checkboxes"]["imgsz"],
106112
MODEL_CONFIG["checkboxes"]["imgsz"],
113+
lazy_loading=True
107114
)
108115

109116

src/ChartExtractor/object_detection_models/onnx_yolov11_detection.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
model_classes_filepath: Path,
4949
input_im_width: int = 640,
5050
input_im_height: int = 640,
51+
lazy_loading: bool = False
5152
):
5253
"""Initializes the onnx model.
5354
@@ -62,11 +63,17 @@ def __init__(
6263
input_im_height (int):
6364
The image height that the model accepts.
6465
Defaults to 640.
66+
lazy_loading (bool):
67+
Whether or not to load the model only when it is called for detection.
68+
Defaults to False.
6569
"""
66-
self.model = ort.InferenceSession(model_weights_filepath)
70+
self.model_weights_filepath = model_weights_filepath
6771
self.input_im_width = input_im_width
6872
self.input_im_height = input_im_height
6973
self.classes = self.load_classes(model_classes_filepath)
74+
self.model_is_loaded = False
75+
if not lazy_loading:
76+
self.load_model()
7077

7178
@staticmethod
7279
def load_classes(model_metadata_filepath: Path) -> Dict:
@@ -94,6 +101,11 @@ def load_classes(model_metadata_filepath: Path) -> Dict:
94101
print(potential_err_msg)
95102
print(e)
96103
return classes
104+
105+
def load_model(self):
106+
"""Loads the model."""
107+
self.model = ort.InferenceSession(self.model_weights_filepath)
108+
self.model_is_loaded = True
97109

98110
def __call__(
99111
self,
@@ -116,6 +128,8 @@ def __call__(
116128
Returns:
117129
A list of detections for each image.
118130
"""
131+
if not model_is_loaded:
132+
self.load_model()
119133
if not isinstance(images, list):
120134
images = [images]
121135
detections: List[List[Detection]] = [

src/ChartExtractor/object_detection_models/onnx_yolov11_pose_single.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(
5252
model_classes_filepath: Path,
5353
input_im_width: int = 640,
5454
input_im_height: int = 640,
55+
lazy_loading: bool = False
5556
):
5657
"""Initializes the onnx model.
5758
@@ -66,11 +67,18 @@ def __init__(
6667
input_im_height (int):
6768
The image height that the model accepts.
6869
Defaults to 640.
70+
lazy_loading (bool):
71+
Whether or not to load the model only when it is called for detection.
72+
Defaults to False.
6973
"""
70-
self.model = ort.InferenceSession(model_weights_filepath)
74+
self.model_weights_filepath = model_weights_filepath
7175
self.input_im_width = input_im_width
7276
self.input_im_height = input_im_height
7377
self.classes = self.load_classes(model_classes_filepath)
78+
self.model_is_loaded = False
79+
if not lazy_loading:
80+
self.load_model()
81+
7482

7583
@staticmethod
7684
def load_classes(model_metadata_filepath: Path) -> Dict:
@@ -98,6 +106,11 @@ def load_classes(model_metadata_filepath: Path) -> Dict:
98106
print(potential_err_msg)
99107
print(e)
100108
return classes
109+
110+
def load_model(self):
111+
"""Loads the model."""
112+
self.model = ort.InferenceSession(self.model_weights_filepath)
113+
self.model_is_loaded = True
101114

102115
def __call__(
103116
self,
@@ -120,6 +133,8 @@ def __call__(
120133
Returns:
121134
A list of detections for each image.
122135
"""
136+
if not model_is_loaded:
137+
self.load_model()
123138
if not isinstance(images, list):
124139
images = [images]
125140
detections: List[List[Detection]] = [

0 commit comments

Comments
 (0)