Skip to content

Commit f5578f7

Browse files
Merge pull request #90 from roboflow/rstp_stream_support
Rstp stream support
2 parents a408242 + ac27d3c commit f5578f7

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

roboflow/core/workspace.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55

66
import requests
7+
from numpy import ndarray
78
from PIL import Image
89

910
from roboflow.config import API_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
@@ -240,9 +241,9 @@ def active_learning(
240241
conditionals: (dict) = dictionary of upload conditions
241242
use_localhost: (bool) = determines if local http format used or remote endpoint
242243
"""
244+
prediction_results = []
243245

244246
# ensure that all fields of conditionals have a key:value pair
245-
246247
conditionals["target_classes"] = (
247248
[]
248249
if "target_classes" not in conditionals
@@ -292,7 +293,11 @@ def active_learning(
292293
print("inference reference point: ", inference_model)
293294
print("upload destination: ", upload_project)
294295

295-
globbed_files = glob.glob(raw_data_location + "/*" + raw_data_extension)
296+
# check if raw data type is cv2 frame
297+
if type(raw_data_location is type(ndarray)):
298+
globbed_files = [raw_data_location]
299+
else:
300+
globbed_files = glob.glob(raw_data_location + "/*" + raw_data_extension)
296301

297302
image1 = globbed_files[0]
298303
similarity_timeout_counter = 0
@@ -326,6 +331,8 @@ def active_learning(
326331
continue # skip this image if too similar or counter hits limit
327332

328333
predictions = inference_model.predict(image).json()["predictions"]
334+
# collect all predictions to return to user at end
335+
prediction_results.append({"image": image, "predictions": predictions})
329336

330337
# compare object and class count of predictions if enabled, continue if not enough occurances
331338
if not count_comparisons(
@@ -372,7 +379,12 @@ def active_learning(
372379
upload_project.upload(image, num_retry_uploads=3)
373380
break
374381

375-
return "complete"
382+
# return predictions with filenames if globbed images from dir, otherwise return latest prediction result
383+
return (
384+
prediction_results
385+
if type(raw_data_location) is not ndarray
386+
else prediction_results[-1]["predictions"]
387+
)
376388

377389
def __str__(self):
378390
projects = self.projects()

0 commit comments

Comments
 (0)