11import json
22import os
33import sys
4+ import time
45import zipfile
56
67import requests
@@ -69,6 +70,10 @@ def __init__(
6970 self .model_format = model_format
7071 self .workspace = workspace
7172 self .project = project
73+ if "exports" in version_dict .keys ():
74+ self .exports = version_dict ["exports" ]
75+ else :
76+ self .exports = []
7277
7378 version_without_workspace = os .path .basename (str (version ))
7479
@@ -101,6 +106,42 @@ def __init__(
101106 else :
102107 self .model = None
103108
109+ def __check_if_generating (self ):
110+ # check Roboflow API to see if this version is still generating
111+
112+ url = f"{ API_URL } /{ self .workspace } /{ self .project } /{ self .version } ?nocache=true"
113+ response = requests .get (url , params = {"api_key" : self .__api_key })
114+ response .raise_for_status ()
115+
116+ if response .json ()["version" ]["progress" ] == None :
117+ progress = 0.0
118+ else :
119+ progress = float (response .json ()["version" ]["progress" ])
120+
121+ return response .json ()["version" ]["generating" ], progress
122+
123+ def __wait_if_generating (self , recurse = False ):
124+ # checks if a given version is still in the progress of generating
125+
126+ still_generating , progress = self .__check_if_generating ()
127+
128+ if still_generating :
129+ progress_message = (
130+ "Generating version still in progress. Progress: "
131+ + str (round (progress * 100 , 2 ))
132+ + "%"
133+ )
134+ sys .stdout .write ("\r " + progress_message )
135+ sys .stdout .flush ()
136+ time .sleep (5 )
137+ return self .__wait_if_generating (recurse = True )
138+
139+ else :
140+ if recurse :
141+ sys .stdout .write ("\n " )
142+ sys .stdout .flush ()
143+ return
144+
104145 def download (self , model_format = None , location = None ):
105146 """
106147 Download and extract a ZIP of a version's dataset in a given format
@@ -110,11 +151,19 @@ def download(self, model_format=None, location=None):
110151
111152 :return: Dataset
112153 """
113- if location is None :
114- location = self .__get_download_location ()
154+
155+ self .__wait_if_generating ()
115156
116157 model_format = self .__get_format_identifier (model_format )
117158
159+ if model_format not in self .exports :
160+ self .export (model_format )
161+
162+ # if model_format is not in
163+
164+ if location is None :
165+ location = self .__get_download_location ()
166+
118167 if self .__api_key == "coco-128-sample" :
119168 link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
120169 else :
@@ -144,14 +193,92 @@ def export(self, model_format=None):
144193 :return: True
145194 :raises RuntimeError / HTTPError:
146195 """
196+
197+ model_format = self .__get_format_identifier (model_format )
198+
199+ self .__wait_if_generating ()
200+
147201 url = self .__get_download_url (model_format )
148- response = requests .post (url , params = {"api_key" : self .__api_key })
202+ response = requests .get (url , params = {"api_key" : self .__api_key })
149203 if not response .ok :
150204 try :
151205 raise RuntimeError (response .json ())
152206 except requests .exceptions .JSONDecodeError :
153207 response .raise_for_status ()
154208
209+ # the rest api returns 202 if the export is still in progress
210+ if response .status_code == 202 :
211+ status_code_check = 202
212+ while status_code_check == 202 :
213+ time .sleep (1 )
214+ response = requests .get (url , params = {"api_key" : self .__api_key })
215+ status_code_check = response .status_code
216+ if status_code_check == 202 :
217+ progress = response .json ()["progress" ]
218+ progress_message = (
219+ "Exporting format "
220+ + model_format
221+ + " in progress : "
222+ + str (round (progress * 100 , 2 ))
223+ + "%"
224+ )
225+ sys .stdout .write ("\r " + progress_message )
226+ sys .stdout .flush ()
227+
228+ if response .status_code == 200 :
229+ sys .stdout .write ("\n " )
230+ print ("\r " + "Version export complete for " + model_format + " format" )
231+ sys .stdout .flush ()
232+ return True
233+ else :
234+ try :
235+ raise RuntimeError (response .json ())
236+ except requests .exceptions .JSONDecodeError :
237+ response .raise_for_status ()
238+
239+ def train (self , speed = None , checkpoint = None ) -> bool :
240+ """
241+ Ask the Roboflow API to train a previously exported version's dataset.
242+ Args:
243+ speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
244+ checkpoint: A string representing the checkpoint to use while training
245+ Returns:
246+ True
247+ RuntimeError: If the Roboflow API returns an error with a helpful JSON body
248+ HTTPError: If the Network/Roboflow API fails and does not return JSON
249+ """
250+
251+ self .__wait_if_generating ()
252+
253+ train_model_format = "yolov5pytorch"
254+
255+ if train_model_format not in self .exports :
256+ self .export (train_model_format )
257+
258+ workspace , project , * _ = self .id .rsplit ("/" )
259+ url = f"{ API_URL } /{ workspace } /{ project } /{ self .version } /train"
260+
261+ data = {}
262+ if speed :
263+ data ["speed" ] = speed
264+
265+ if checkpoint :
266+ data ["checkpoint" ] = checkpoint
267+
268+ sys .stdout .write ("\r " + "Reaching out to Roboflow to start training..." )
269+ sys .stdout .write ("\n " )
270+ sys .stdout .flush ()
271+
272+ response = requests .post (url , json = data , params = {"api_key" : self .__api_key })
273+ if not response .ok :
274+ try :
275+ raise RuntimeError (response .json ())
276+ except requests .exceptions .JSONDecodeError :
277+ response .raise_for_status ()
278+
279+ sys .stdout .write ("\r " + "Training model in progress..." )
280+ sys .stdout .flush ()
281+
155282 return True
156283
157284 def upload_model (self , model_path : str ) -> None :
0 commit comments