|
4 | 4 | import sys |
5 | 5 |
|
6 | 6 | import requests |
| 7 | +from numpy import ndarray |
7 | 8 | from PIL import Image |
8 | 9 |
|
9 | 10 | from roboflow.config import API_URL, CLIP_FEATURIZE_URL, DEMO_KEYS |
@@ -240,9 +241,9 @@ def active_learning( |
240 | 241 | conditionals: (dict) = dictionary of upload conditions |
241 | 242 | use_localhost: (bool) = determines if local http format used or remote endpoint |
242 | 243 | """ |
| 244 | + prediction_results = [] |
243 | 245 |
|
244 | 246 | # ensure that all fields of conditionals have a key:value pair |
245 | | - |
246 | 247 | conditionals["target_classes"] = ( |
247 | 248 | [] |
248 | 249 | if "target_classes" not in conditionals |
@@ -292,7 +293,11 @@ def active_learning( |
292 | 293 | print("inference reference point: ", inference_model) |
293 | 294 | print("upload destination: ", upload_project) |
294 | 295 |
|
295 | | - globbed_files = glob.glob(raw_data_location + "/*" + raw_data_extension) |
| 296 | + # check if raw data type is cv2 frame |
| 297 | + if type(raw_data_location is type(ndarray)): |
| 298 | + globbed_files = [raw_data_location] |
| 299 | + else: |
| 300 | + globbed_files = glob.glob(raw_data_location + "/*" + raw_data_extension) |
296 | 301 |
|
297 | 302 | image1 = globbed_files[0] |
298 | 303 | similarity_timeout_counter = 0 |
@@ -326,6 +331,8 @@ def active_learning( |
326 | 331 | continue # skip this image if too similar or counter hits limit |
327 | 332 |
|
328 | 333 | predictions = inference_model.predict(image).json()["predictions"] |
| 334 | + # collect all predictions to return to user at end |
| 335 | + prediction_results.append({"image": image, "predictions": predictions}) |
329 | 336 |
|
330 | 337 | # compare object and class count of predictions if enabled, continue if not enough occurances |
331 | 338 | if not count_comparisons( |
@@ -372,7 +379,12 @@ def active_learning( |
372 | 379 | upload_project.upload(image, num_retry_uploads=3) |
373 | 380 | break |
374 | 381 |
|
375 | | - return "complete" |
| 382 | + # return predictions with filenames if globbed images from dir, otherwise return latest prediction result |
| 383 | + return ( |
| 384 | + prediction_results |
| 385 | + if type(raw_data_location) is not ndarray |
| 386 | + else prediction_results[-1]["predictions"] |
| 387 | + ) |
376 | 388 |
|
377 | 389 | def __str__(self): |
378 | 390 | projects = self.projects() |
|
0 commit comments