Skip to content

Commit f790b66

Browse files
Merge pull request #89 from roboflow/pipGenerate
Code Actions V1
2 parents f5578f7 + 2422168 commit f790b66

File tree

6 files changed

+275
-27
lines changed

6 files changed

+275
-27
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
certifi==2021.5.30
1+
certifi==2022.12.7
22
chardet==4.0.0
33
cycler==0.10.0
44
idna==2.10

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from roboflow.core.project import Project
99
from roboflow.core.workspace import Workspace
1010

11-
__version__ = "0.2.21"
11+
__version__ = "0.2.22"
1212

1313

1414
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/project.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
import json
44
import os
5+
import sys
56
import urllib
67
import warnings
78

@@ -96,6 +97,119 @@ def versions(self):
9697
version_array.append(version_object)
9798
return version_array
9899

100+
def generate_version(self, settings):
101+
102+
"""
103+
Settings, a python dict with augmentation and preprocessing keys and specifications for generation.
104+
These settings mirror capabilities available via the Roboflow UI.
105+
For example:
106+
{
107+
"augmentation": {
108+
"bbblur": { "pixels": 1.5 },
109+
"bbbrightness": { "brighten": true, "darken": false, "percent": 91 },
110+
"bbcrop": { "min": 12, "max": 71 },
111+
"bbexposure": { "percent": 30 },
112+
"bbflip": { "horizontal": true, "vertical": false },
113+
"bbnoise": { "percent": 50 },
114+
"bbninety": { "clockwise": true, "counter-clockwise": false, "upside-down": false },
115+
"bbrotate": { "degrees": 45 },
116+
"bbshear": { "horizontal": 45, "vertical": 45 },
117+
"blur": { "pixels": 1.5 },
118+
"brightness": { "brighten": true, "darken": false, "percent": 91 },
119+
"crop": { "min": 12, "max": 71 },
120+
"cutout": { "count": 26, "percent": 71 },
121+
"exposure": { "percent": 30 },
122+
"flip": { "horizontal": true, "vertical": false },
123+
"hue": { "degrees": 180 },
124+
"image": { "versions": 32 },
125+
"mosaic": true,
126+
"ninety": { "clockwise": true, "counter-clockwise": false, "upside-down": false },
127+
"noise": { "percent": 50 },
128+
"rgrayscale": { "percent": 50 },
129+
"rotate": { "degrees": 45 },
130+
"saturation": { "percent": 50 },
131+
"shear": { "horizontal": 45, "vertical": 45 }
132+
},
133+
"preprocessing": {
134+
"auto-orient": true,
135+
"contrast": { "type": "Contrast Stretching" },
136+
"filter-null": { "percent": 50 },
137+
"grayscale": true,
138+
"isolate": true,
139+
"remap": { "original_class_name": "new_class_name" },
140+
"resize": { "width": 200, "height": 200, "format": "Stretch to" },
141+
"static-crop": { "x_min": 10, "x_max": 90, "y_min": 10, "y_max": 90 },
142+
"tile": { "rows": 2, "columns": 2 }
143+
}
144+
}
145+
146+
Returns: The version number that is being generated
147+
"""
148+
149+
if not {"augmentation", "preprocessing"} <= settings.keys():
150+
raise (
151+
RuntimeError(
152+
"augmentation and preprocessing keys are required to generate. If none are desired specify empty dict associated with that key."
153+
)
154+
)
155+
156+
r = requests.post(
157+
f"{API_URL}/{self.__workspace}/{self.__project_name}/generate?api_key={self.__api_key}",
158+
json=settings,
159+
)
160+
161+
try:
162+
r_json = r.json()
163+
except:
164+
raise ("Error when requesting to generate a new version for project.")
165+
166+
# if the generation succeeds, return the version that is being generated
167+
if r.status_code == 200:
168+
sys.stdout.write(
169+
"\r"
170+
+ r_json["message"]
171+
+ " for new version "
172+
+ str(r_json["version"])
173+
+ "."
174+
)
175+
sys.stdout.write("\n")
176+
sys.stdout.flush()
177+
return int(r_json["version"])
178+
else:
179+
if "error" in r_json.keys():
180+
raise RuntimeError(r_json["error"])
181+
else:
182+
raise RuntimeError(json.dumps(r_json))
183+
184+
def train(
185+
self,
186+
new_version_settings={
187+
"preprocessing": {
188+
"auto-orient": True,
189+
"resize": {"width": 640, "height": 640, "format": "Stretch to"},
190+
},
191+
"augmentation": {},
192+
},
193+
speed=None,
194+
checkpoint=None,
195+
) -> bool:
196+
"""
197+
Ask the Roboflow API to train a previously exported version's dataset.
198+
Args:
199+
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
200+
checkpoint: A string representing the checkpoint to use while training
201+
Returns:
202+
True
203+
RuntimeError: If the Roboflow API returns an error with a helpful JSON body
204+
HTTPError: If the Network/Roboflow API fails and does not return JSON
205+
"""
206+
207+
new_version = self.generate_version(settings=new_version_settings)
208+
new_version = self.version(new_version)
209+
new_version.train(speed=speed, checkpoint=checkpoint)
210+
211+
return True
212+
99213
def version(self, version_number, local=None):
100214
"""Retrieves information about a specific version, and throws it into an object.
101215
:param version_number: the version number that you want to retrieve

roboflow/core/version.py

Lines changed: 130 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
import sys
4+
import time
45
import zipfile
56

67
import requests
@@ -69,6 +70,10 @@ def __init__(
6970
self.model_format = model_format
7071
self.workspace = workspace
7172
self.project = project
73+
if "exports" in version_dict.keys():
74+
self.exports = version_dict["exports"]
75+
else:
76+
self.exports = []
7277

7378
version_without_workspace = os.path.basename(str(version))
7479

@@ -101,6 +106,42 @@ def __init__(
101106
else:
102107
self.model = None
103108

109+
def __check_if_generating(self):
110+
# check Roboflow API to see if this version is still generating
111+
112+
url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
113+
response = requests.get(url, params={"api_key": self.__api_key})
114+
response.raise_for_status()
115+
116+
if response.json()["version"]["progress"] == None:
117+
progress = 0.0
118+
else:
119+
progress = float(response.json()["version"]["progress"])
120+
121+
return response.json()["version"]["generating"], progress
122+
123+
def __wait_if_generating(self, recurse=False):
124+
# checks if a given version is still in the progress of generating
125+
126+
still_generating, progress = self.__check_if_generating()
127+
128+
if still_generating:
129+
progress_message = (
130+
"Generating version still in progress. Progress: "
131+
+ str(round(progress * 100, 2))
132+
+ "%"
133+
)
134+
sys.stdout.write("\r" + progress_message)
135+
sys.stdout.flush()
136+
time.sleep(5)
137+
return self.__wait_if_generating(recurse=True)
138+
139+
else:
140+
if recurse:
141+
sys.stdout.write("\n")
142+
sys.stdout.flush()
143+
return
144+
104145
def download(self, model_format=None, location=None):
105146
"""
106147
Download and extract a ZIP of a version's dataset in a given format
@@ -110,11 +151,19 @@ def download(self, model_format=None, location=None):
110151
111152
:return: Dataset
112153
"""
113-
if location is None:
114-
location = self.__get_download_location()
154+
155+
self.__wait_if_generating()
115156

116157
model_format = self.__get_format_identifier(model_format)
117158

159+
if model_format not in self.exports:
160+
self.export(model_format)
161+
162+
# if model_format is not in
163+
164+
if location is None:
165+
location = self.__get_download_location()
166+
118167
if self.__api_key == "coco-128-sample":
119168
link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
120169
else:
@@ -144,14 +193,92 @@ def export(self, model_format=None):
144193
:return: True
145194
:raises RuntimeError / HTTPError:
146195
"""
196+
197+
model_format = self.__get_format_identifier(model_format)
198+
199+
self.__wait_if_generating()
200+
147201
url = self.__get_download_url(model_format)
148-
response = requests.post(url, params={"api_key": self.__api_key})
202+
response = requests.get(url, params={"api_key": self.__api_key})
149203
if not response.ok:
150204
try:
151205
raise RuntimeError(response.json())
152206
except requests.exceptions.JSONDecodeError:
153207
response.raise_for_status()
154208

209+
# the rest api returns 202 if the export is still in progress
210+
if response.status_code == 202:
211+
status_code_check = 202
212+
while status_code_check == 202:
213+
time.sleep(1)
214+
response = requests.get(url, params={"api_key": self.__api_key})
215+
status_code_check = response.status_code
216+
if status_code_check == 202:
217+
progress = response.json()["progress"]
218+
progress_message = (
219+
"Exporting format "
220+
+ model_format
221+
+ " in progress : "
222+
+ str(round(progress * 100, 2))
223+
+ "%"
224+
)
225+
sys.stdout.write("\r" + progress_message)
226+
sys.stdout.flush()
227+
228+
if response.status_code == 200:
229+
sys.stdout.write("\n")
230+
print("\r" + "Version export complete for " + model_format + " format")
231+
sys.stdout.flush()
232+
return True
233+
else:
234+
try:
235+
raise RuntimeError(response.json())
236+
except requests.exceptions.JSONDecodeError:
237+
response.raise_for_status()
238+
239+
def train(self, speed=None, checkpoint=None) -> bool:
240+
"""
241+
Ask the Roboflow API to train a previously exported version's dataset.
242+
Args:
243+
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
244+
checkpoint: A string representing the checkpoint to use while training
245+
Returns:
246+
True
247+
RuntimeError: If the Roboflow API returns an error with a helpful JSON body
248+
HTTPError: If the Network/Roboflow API fails and does not return JSON
249+
"""
250+
251+
self.__wait_if_generating()
252+
253+
train_model_format = "yolov5pytorch"
254+
255+
if train_model_format not in self.exports:
256+
self.export(train_model_format)
257+
258+
workspace, project, *_ = self.id.rsplit("/")
259+
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
260+
261+
data = {}
262+
if speed:
263+
data["speed"] = speed
264+
265+
if checkpoint:
266+
data["checkpoint"] = checkpoint
267+
268+
sys.stdout.write("\r" + "Reaching out to Roboflow to start training...")
269+
sys.stdout.write("\n")
270+
sys.stdout.flush()
271+
272+
response = requests.post(url, json=data, params={"api_key": self.__api_key})
273+
if not response.ok:
274+
try:
275+
raise RuntimeError(response.json())
276+
except requests.exceptions.JSONDecodeError:
277+
response.raise_for_status()
278+
279+
sys.stdout.write("\r" + "Training model in progress...")
280+
sys.stdout.flush()
281+
155282
return True
156283

157284
def upload_model(self, model_path: str) -> None:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
long_description_content_type="text/markdown",
2222
url="https://github.com/roboflow-ai/roboflow-python",
2323
install_requires=[
24-
"certifi==2021.5.30",
24+
"certifi==2022.12.7",
2525
"chardet==4.0.0",
2626
"cycler==0.10.0",
2727
"glob2",

0 commit comments

Comments
 (0)