Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions src/scriptworker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def format_json(data):
# is omitted we don't actually ever return None (because on failure we raise an Exception)
@overload
def load_json_or_yaml(
string: str, is_path: Optional[bool] = ..., file_type: Optional[str] = ..., exception: Type[BaseException] = ..., message: str = ...
string: str, is_path: Optional[bool] = ..., file_type: Optional[str] = ..., exception: Type[BaseException] = ..., message: Optional[str] = ...
) -> Dict[str, Any]: # pragma: no cover
...

Expand All @@ -532,7 +532,7 @@ def load_json_or_yaml(
is_path: Optional[bool] = False,
file_type: Optional[str] = "json",
exception: Optional[Type[BaseException]] = ScriptWorkerTaskException,
message: str = "Failed to load %(file_type)s: %(exc)s",
message: Optional[str] = None,
) -> Optional[Dict[str, Any]]: # pragma: no cover
...

Expand All @@ -542,7 +542,7 @@ def load_json_or_yaml(
is_path: Optional[bool] = False,
file_type: Optional[str] = "json",
exception: Optional[Type[BaseException]] = ScriptWorkerTaskException,
message: str = "Failed to load %(file_type)s: %(exc)s",
message: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Load json or yaml from a filehandle or string, and raise a custom exception on failure.

Expand All @@ -552,8 +552,10 @@ def load_json_or_yaml(
file_type (str, optional): either "json" or "yaml". Defaults to "json".
exception (exception, optional): the exception to raise on failure.
If None, don't raise an exception. Defaults to ScriptWorkerTaskException.
message (str, optional): the message to use for the exception.
Defaults to "Failed to load %(file_type)s: %(exc)s"
message (str, optional): override the exception message. Supports the
``%(file_type)s``, ``%(exc)s``, and ``%(path)s`` placeholders (the
latter is empty when ``is_path=False``). Defaults to a message that
includes the path when ``is_path=True``, otherwise one that omits it.

Returns:
dict: the data from the string.
Expand All @@ -578,7 +580,9 @@ def load_json_or_yaml(
return contents
except (OSError, ValueError, yaml.scanner.ScannerError) as exc:
if exception is not None:
repl_dict = {"exc": str(exc), "file_type": file_type}
if message is None:
raise exception(f"Failed to load {file_type} from {string}: {exc}" if is_path else f"Failed to load {file_type}: {exc}")
repl_dict = {"exc": str(exc), "file_type": file_type, "path": string if is_path else ""}
raise exception(message % repl_dict)
return None

Expand Down Expand Up @@ -652,7 +656,7 @@ async def _log_download_error(resp, msg):
log.debug("Redirect history %s: %s; body=%s", get_loggable_url(str(h.url)), h.status, (await h.text())[:1000])


async def download_file(context, url, abs_filename, session=None, chunk_size=128, auth=None):
async def download_file(context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None):
"""Download a file, async.

Args:
Expand All @@ -663,6 +667,16 @@ async def download_file(context, url, abs_filename, session=None, chunk_size=128
None, use context.session. Defaults to None.
chunk_size (int, optional): the chunk size to read from the response
at a time. Default is 128.
expected_content_type (str, optional): if set, raise ``DownloadError``
when the server returns an ``HTML`` response and ``expected_content_type``
is something other than HTML. Narrow by design — servers vary too
much for a strict match — but catches the common
"error page instead of the JSON/YAML/artifact we asked for" case.

Raises:
DownloadError: on non-200 status, or an HTML response when
``expected_content_type`` was not HTML.
Download404: on 404 status.

"""
session = session or context.session
Expand All @@ -675,10 +689,15 @@ async def download_file(context, url, abs_filename, session=None, chunk_size=128
async with session.get(url, auth=auth) as resp:
if resp.status == 404:
await _log_download_error(resp, "404 downloading %(url)s: %(status)s; body=%(body)s")
raise Download404("{} status {}!".format(loggable_url, resp.status))
raise Download404(f"{loggable_url} status {resp.status}!")
elif resp.status != 200:
await _log_download_error(resp, "Failed to download %(url)s: %(status)s; body=%(body)s")
raise DownloadError("{} status {} is not 200!".format(loggable_url, resp.status))
raise DownloadError(f"{loggable_url} status {resp.status} is not 200!")
if expected_content_type:
actual_content_type = (resp.headers.get("Content-Type") or "").split(";", 1)[0].strip().lower()
if actual_content_type == "text/html" and "html" not in expected_content_type.lower():
await _log_download_error(resp, "HTML response for %(url)s (expected non-HTML): %(status)s; body=%(body)s")
raise DownloadError(f"{loggable_url}: expected Content-Type {expected_content_type!r} but got HTML; treating as an error page")
makedirs(parent_dir)
with open(abs_filename, "wb") as fd:
while True:
Expand Down Expand Up @@ -729,6 +748,11 @@ def get_parts_of_url_path(url):
async def load_json_or_yaml_from_url(context: Context, url: str, path: str, overwrite: bool = True, auth: Optional[str] = None) -> Dict[str, Any]:
"""Retry a json/yaml file download, load it, then return its data.

Download and parse are combined into a single retry unit: if parsing the
downloaded file fails (e.g. truncated body, an HTML error page, a Cloud
Storage transcoding glitch), the cached file is deleted and the download
is retried.

Args:
context (scriptworker.context.Context): the scriptworker context.
url (str): the url to download
Expand All @@ -745,15 +769,42 @@ async def load_json_or_yaml_from_url(context: Context, url: str, path: str, over
"""
if path.endswith("json"):
file_type = "json"
expected_content_type = "application/json"
else:
file_type = "yaml"
expected_content_type = "application/yaml"

kwargs = {}
download_kwargs = {"expected_content_type": expected_content_type}
if auth:
kwargs = {"auth": auth}
if not overwrite or not os.path.exists(path):
await retry_async(download_file, args=(context, url, path), kwargs=kwargs, retry_exceptions=(DownloadError, aiohttp.ClientError, asyncio.TimeoutError))
return load_json_or_yaml(path, is_path=True, file_type=file_type)
download_kwargs["auth"] = auth
loggable_url = get_loggable_url(url)

async def _download_and_parse():
# Pre-existing cache semantics (despite the misleading parameter
# name): ``overwrite=True`` uses an existing file when present;
# ``overwrite=False`` always (re)downloads.
if not overwrite or not os.path.exists(path):
await download_file(context, url, path, **download_kwargs)
try:
return load_json_or_yaml(path, is_path=True, file_type=file_type)
except ScriptWorkerTaskException as exc:
log.warning(
"Failed to parse %s from %s (cached at %s); invalidating cache and retrying: %s",
file_type,
loggable_url,
path,
exc,
)
try:
os.remove(path)
except OSError:
pass
raise DownloadError(f"parse failure for {loggable_url}: {exc}")

return await retry_async(
_download_and_parse,
retry_exceptions=(DownloadError, aiohttp.ClientError, asyncio.TimeoutError),
)


# match_url_path_callback {{{1
Expand Down
45 changes: 43 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,47 @@ async def test_download_file_404(rw_context, fake_session_404, tmpdir, auth):
await utils.download_file(rw_context, "url", path, session=fake_session_404, auth=auth)


@pytest.mark.asyncio
async def test_download_file_rejects_html_when_non_html_expected(rw_context, fake_session, tmpdir):
    """An HTML response raises DownloadError when a non-HTML content-type was expected."""

    async def html_request(method, url, *args, **kwargs):
        # Simulate a 200 response whose body is an HTML error page.
        response = FakeResponse(method, url, status=200)
        response._headers = {"Content-Type": "text/html; charset=utf-8"}
        return response

    fake_session._request = html_request
    target = os.path.join(tmpdir, "foo.json")
    with pytest.raises(DownloadError, match="HTML"):
        await utils.download_file(rw_context, "url", target, session=fake_session, expected_content_type="application/json")
    # The download must abort before anything is written to disk.
    assert not os.path.exists(target)


@pytest.mark.asyncio
async def test_load_json_or_yaml_from_url_retries_on_parse_failure(rw_context, mocker, tmpdir):
    """A parse failure invalidates the cache and triggers a re-download."""
    target = os.path.join(tmpdir, "out.json")
    attempts = []

    async def flaky_download(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None):
        attempts.append(url)
        # First attempt writes unparseable JSON; every later attempt writes valid JSON.
        body = "not valid json {" if len(attempts) == 1 else '{"ok": true}'
        with open(abs_filename, "w") as fh:
            fh.write(body)

    # Zero out the retry backoff so the test runs fast.
    mocker.patch.object(utils, "calculate_sleep_time", return_value=0)
    mocker.patch.object(utils, "download_file", new=flaky_download)

    result = await utils.load_json_or_yaml_from_url(rw_context, "url", target)
    assert result == {"ok": True}
    assert len(attempts) == 2


# format_json {{{1
def test_format_json():
expected = "\n".join(["{", ' "a": 1,', ' "b": [', " 4,", " 3,", " 2", " ],", ' "c": {', ' "d": 5', " }", "}"])
Expand Down Expand Up @@ -474,7 +515,7 @@ def test_load_json_or_yaml(string, is_path, exception, raises, result):
async def test_load_json_or_yaml_from_url(rw_context, mocker, overwrite, file_type, tmpdir):
called_with_auth = []

async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None):
async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None):
called_with_auth.append(auth == "someAuth")
return

Expand All @@ -495,7 +536,7 @@ async def mocked_download_file(rw_context, url, abs_filename, session=None, chun
async def test_load_json_or_yaml_from_url_auth(rw_context, mocker, overwrite, file_type, tmpdir):
called_with_auth = []

async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None):
async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None):
called_with_auth.append(auth == "someAuth")
return

Expand Down