From 775e324d392dcc899def9550c7b2a88d3156c431 Mon Sep 17 00:00:00 2001 From: Dale Myers Date: Thu, 9 Apr 2026 23:27:23 +0100 Subject: [PATCH] Add async/await support via httpx migration and code generation [LLM] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add full async API support to simple_ado in a backwards-compatible way. The async code in simple_ado/_async/ is the single source of truth; synchronous code is auto-generated from it by scripts/generate_sync.py. Existing sync imports and usage are unchanged. Architecture: - simple_ado/_async/ contains the hand-written async source (24 files) - scripts/generate_sync.py transforms async → sync via text replacement (strip async/await, ADOAsync* → ADO*, AsyncIterator → Iterator, etc.) and formats output with black in-memory for idempotent generation - Generated sync files overwrite the top-level simple_ado/ modules and carry a "DO NOT EDIT" header - Shared modules (models, comments, exceptions, ado_types) live at the top level and are imported by both async and sync code HTTP layer changes: - Replace requests with httpx (sync httpx.Client / async httpx.AsyncClient) - Add stream_get() / stream_post() async context managers for streaming - get(stream=True) / post(stream=True) still work (with deprecation warning) for backwards compatibility - Narrow retryable status codes to {400,408,429,500,502,503,504} instead of the full 4xx range — deterministic failures like 401/403/404 no longer retry - Add 300s default timeout via httpx.Timeout - ADOHTTPClient gains close(), __enter__/__exit__, and __del__ for proper httpx.Client lifecycle management Auth changes: - Add ADOAsyncAuth base class with async get_authorization_header() - Add ADOAsyncTokenAuth, ADOAsyncBasicAuth, ADOAsyncAzIDAuth using azure.identity.aio.DefaultAzureCredential - Auth classes gain close() for resource cleanup ADOWorkItem changes: - Sync __getitem__ retains auto-refresh on missing fields - Async __getitem__ raises KeyError immediately (can't be async); use await work_item.get_field(key) for auto-refresh behavior - Add set() convenience method for both sync and async Test infrastructure: - Replace responses with respx for httpx mocking - Add pytest-asyncio for async test support - Async tests live in tests/unit/_async/ and are auto-generated into tests/unit/ by the same transform script - Add tests for HTTP client rate limiting, retryable status codes, and work item field access patterns Usage: # Existing sync — unchanged from simple_ado import ADOClient, ADOTokenAuth client = ADOClient(tenant="org", auth=ADOTokenAuth("token")) # New async from simple_ado._async import ADOAsyncClient, ADOAsyncTokenAuth async with ADOAsyncClient(tenant="org", auth=ADOAsyncTokenAuth("token")) as client: await client.verify_access() Breaking changes: - ADOHTTPException.response is now httpx.Response (was requests.Response) wrapped in a _CompatResponse that adds .ok for backwards compatibility - requests is no longer a dependency; httpx is required instead Co-Authored-By: Claude Opus 4.6 (1M context) --- poetry.lock | 145 ++- pylintrc | 3 +- pyproject.toml | 6 +- scripts/generate_sync.py | 845 ++++++++++++++++++ simple_ado/__init__.py | 34 +- simple_ado/_async/__init__.py | 442 +++++++++ simple_ado/_async/audit.py | 94 ++ simple_ado/_async/auth/__init__.py | 14 + simple_ado/_async/auth/ado_auth.py | 21 + simple_ado/_async/auth/ado_azid_auth.py | 36 + simple_ado/_async/auth/ado_basic_auth.py | 30 + simple_ado/_async/auth/ado_token_auth.py | 19 + simple_ado/_async/base_client.py | 28 + simple_ado/_async/builds.py | 495 ++++++++++ simple_ado/_async/endpoints.py | 95 ++ simple_ado/_async/git.py | 821 +++++++++++++++++ simple_ado/_async/governance.py | 397 ++++++++ simple_ado/_async/graph.py | 166 ++++ simple_ado/_async/http_client.py | 622 +++++++++++++ simple_ado/_async/identities.py | 65 ++ simple_ado/_async/pipelines.py | 216 +++++ simple_ado/_async/pools.py | 171 ++++ simple_ado/_async/pull_requests.py | 495 ++++++++++ simple_ado/_async/security.py | 649 ++++++++++++++ simple_ado/_async/user.py | 22 + simple_ado/_async/utilities.py | 59 ++ simple_ado/_async/wiki.py | 85 ++ simple_ado/_async/work_item.py | 210 +++++ simple_ado/_async/workitems.py | 630 +++++++++++++ simple_ado/ado_types.py | 10 + simple_ado/audit.py | 2 + simple_ado/auth/__init__.py | 2 + simple_ado/auth/ado_auth.py | 9 + simple_ado/auth/ado_azid_auth.py | 10 +- simple_ado/auth/ado_basic_auth.py | 17 +- simple_ado/auth/ado_token_auth.py | 2 + simple_ado/base_client.py | 2 + simple_ado/builds.py | 77 +- simple_ado/endpoints.py | 2 + simple_ado/exceptions.py | 36 +- simple_ado/git.py | 9 +- simple_ado/governance.py | 6 +- simple_ado/graph.py | 2 + simple_ado/http_client.py | 236 ++++- simple_ado/identities.py | 4 +- simple_ado/pipelines.py | 4 +- simple_ado/pools.py | 2 + simple_ado/pull_requests.py | 12 +- simple_ado/security.py | 4 +- simple_ado/types.py | 16 +- simple_ado/user.py | 2 + simple_ado/utilities.py | 8 +- simple_ado/wiki.py | 2 + simple_ado/work_item.py | 53 +- simple_ado/workitems.py | 55 +- tests/conftest.py | 26 +- tests/unit/_async/__init__.py | 0 tests/unit/_async/test_builds.py | 224 +++++ tests/unit/_async/test_client.py | 105 +++ tests/unit/_async/test_http_client.py | 271 ++++++ tests/unit/_async/test_work_item.py | 249 ++++++ tests/unit/_async/test_work_item_get_field.py | 106 +++ tests/unit/test_builds.py | 152 ++-- tests/unit/test_client.py | 76 +- tests/unit/test_http_client.py | 251 ++++++ tests/unit/test_work_item.py | 139 +-- tests/unit/test_work_item_get_field.py | 104 +++ 67 files changed, 8828 insertions(+), 374 deletions(-) create mode 100644 scripts/generate_sync.py create mode 100644 simple_ado/_async/__init__.py create mode 100644 simple_ado/_async/audit.py create mode 100644 simple_ado/_async/auth/__init__.py create mode 100644 simple_ado/_async/auth/ado_auth.py create mode 100644 simple_ado/_async/auth/ado_azid_auth.py create mode 100644 simple_ado/_async/auth/ado_basic_auth.py create mode 100644 simple_ado/_async/auth/ado_token_auth.py create mode 100644 simple_ado/_async/base_client.py create mode 100644 simple_ado/_async/builds.py create mode 100644 simple_ado/_async/endpoints.py create mode 100644 simple_ado/_async/git.py create mode 100644 simple_ado/_async/governance.py create mode 100644 simple_ado/_async/graph.py create mode 100644 simple_ado/_async/http_client.py create mode 100644 simple_ado/_async/identities.py create mode 100644 simple_ado/_async/pipelines.py create mode 100644 simple_ado/_async/pools.py create mode 100644 simple_ado/_async/pull_requests.py create mode 100644 simple_ado/_async/security.py create mode 100644 simple_ado/_async/user.py create mode 100644 simple_ado/_async/utilities.py create mode 100644 simple_ado/_async/wiki.py create mode 100644 simple_ado/_async/work_item.py create mode 100644 simple_ado/_async/workitems.py create mode 100644 simple_ado/ado_types.py create mode 100644 tests/unit/_async/__init__.py create mode 100644 tests/unit/_async/test_builds.py create mode 100644 tests/unit/_async/test_client.py create mode 100644 tests/unit/_async/test_http_client.py create mode 100644 tests/unit/_async/test_work_item.py create mode 100644 tests/unit/_async/test_work_item_get_field.py create mode 100644 tests/unit/test_http_client.py create mode 100644 tests/unit/test_work_item_get_field.py diff --git a/poetry.lock b/poetry.lock index aefac14..2268046 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,24 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.3 and should not be changed by hand. + +[[package]] +name = "anyio" +version = "4.13.0" +description = "High-level concurrency and networking framework on top of asyncio or Trio" +optional = false +python-versions = ">=3.10" +groups = ["main", "dev"] +files = [ + {file = "anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708"}, + {file = "anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} + +[package.extras] +trio = ["trio (>=0.32.0)"] [[package]] name = "astroid" @@ -223,7 +243,7 @@ version = "3.4.4" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d"}, {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8"}, @@ -583,7 +603,7 @@ version = "1.3.1" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" -groups = ["dev"] +groups = ["main", "dev"] markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598"}, @@ -596,6 +616,65 @@ typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.13\""} [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "h11" +version = "0.16.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.8" +groups = ["main", "dev"] +files = [ + {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"}, + {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"}, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main", "dev"] +files = [ + {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"}, + {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.16" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<1.0)"] + +[[package]] +name = "httpx" +version = "0.28.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +groups = ["main", "dev"] +files = [ + {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"}, + {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" + +[package.extras] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +zstd = ["zstandard (>=0.18.0)"] + [[package]] name = "idna" version = "3.11" @@ -1049,6 +1128,25 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.25.3" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3"}, + {file = "pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "7.0.0" @@ -1246,7 +1344,7 @@ version = "2.32.5" description = "Python HTTP for Humans." optional = false python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6"}, {file = "requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf"}, @@ -1263,25 +1361,19 @@ socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] -name = "responses" -version = "0.23.3" -description = "A utility library for mocking out the `requests` Python library." +name = "respx" +version = "0.22.0" +description = "A utility for mocking out the Python HTTPX and HTTP Core libraries." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "responses-0.23.3-py3-none-any.whl", hash = "sha256:e6fbcf5d82172fecc0aa1860fd91e58cbfd96cee5e96da5b63fa6eb3caa10dd3"}, - {file = "responses-0.23.3.tar.gz", hash = "sha256:205029e1cb334c21cb4ec64fc7599be48b859a0fd381a42443cdd600bfe8b16a"}, + {file = "respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0"}, + {file = "respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91"}, ] [package.dependencies] -pyyaml = "*" -requests = ">=2.30.0,<3.0" -types-PyYAML = "*" -urllib3 = ">=1.25.10,<3.0" - -[package.extras] -tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli ; python_version < \"3.11\"", "tomli-w", "types-requests"] +httpx = ">=0.25.0" [[package]] name = "tenacity" @@ -1393,21 +1485,6 @@ files = [ {file = "types_pyyaml-6.0.12.20250915.tar.gz", hash = "sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3"}, ] -[[package]] -name = "types-requests" -version = "2.32.4.20260107" -description = "Typing stubs for requests" -optional = false -python-versions = ">=3.9" -groups = ["dev"] -files = [ - {file = "types_requests-2.32.4.20260107-py3-none-any.whl", hash = "sha256:b703fe72f8ce5b31ef031264fe9395cac8f46a04661a79f7ed31a80fb308730d"}, - {file = "types_requests-2.32.4.20260107.tar.gz", hash = "sha256:018a11ac158f801bfa84857ddec1650750e393df8a004a8a9ae2a9bec6fcb24f"}, -] - -[package.dependencies] -urllib3 = ">=2" - [[package]] name = "types-toml" version = "0.10.8.20240310" @@ -1438,7 +1515,7 @@ version = "2.6.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" -groups = ["main", "dev"] +groups = ["main"] files = [ {file = "urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4"}, {file = "urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed"}, @@ -1453,4 +1530,4 @@ zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "f6c95fc5bafea14d20d20447dce5f0a44e77f4b290e3b5617bc56a22a7548f8f" +content-hash = "4617193a2bf0eb77533974b76c74ecb9548f2eb74cdbb61ba1b8e958c1a735ce" diff --git a/pylintrc b/pylintrc index d0011ae..eb3f172 100644 --- a/pylintrc +++ b/pylintrc @@ -15,6 +15,7 @@ load-plugins=pylint.extensions.docparams,pylint.extensions.docstyle,pylint.exten # C0413 wrong-import-position: Import "%s" should be placed at the top of the module Used when code and imports are mixed # C1801 len-as-condition: Do not use `len(SEQUENCE)` as condition value +# R0801 duplicate-code: Sync code is auto-generated from _async/ — duplication is by design # R0913 too-many-arguments: Too many arguments for function / method # R0917 too-many-positional-arguments: Too many positional arguments # W0511 fixme: TODO statements @@ -23,7 +24,7 @@ load-plugins=pylint.extensions.docparams,pylint.extensions.docstyle,pylint.exten # W1202 logging-format-interpolation: Use % formatting in logging functions and pass the % parameters as arguments # W1203 logging-fstring-interpolation: Use % formatting in logging functions and pass the % parameters as arguments # W3101 missing-timeout: Missing timeout argument for method 'requests.*' -disable=C0413,C1801,R0913,R0917,W0511,W0703,W1201,W1202,W1203,W3101 +disable=C0413,C1801,R0801,R0913,R0917,W0511,W0703,W1201,W1202,W1203,W3101 [REPORTS] diff --git a/pyproject.toml b/pyproject.toml index 5ea598a..fa57b35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.10" deserialize = "^2.2.0" -requests = "^2.31.0" +httpx = "^0.28.0" tenacity = "^8.2.2" azure-identity = "^1.25.1" @@ -50,10 +50,10 @@ pytest-cov = "^7.0.0" pytest-mock = "^3.11.1" python-dotenv = "^1.1.1" PyYAML = "^6.0.3" -responses = "^0.23.1" +respx = "^0.22.0" toml = "^0.10.2" types-PyYAML = "^6.0.12.11" -types-requests = "^2.31.0.2" +pytest-asyncio = "^0.25.0" types-toml = "^0.10.8.7" [tool.poetry.group.dev.dependencies] diff --git a/scripts/generate_sync.py b/scripts/generate_sync.py new file mode 100644 index 0000000..e63b82b --- /dev/null +++ b/scripts/generate_sync.py @@ -0,0 +1,845 @@ +#!/usr/bin/env python3 + +"""Generate synchronous code from the async source of truth. + +This script transforms the async code in simple_ado/_async/ into synchronous +code at the top level of simple_ado/. The async code is the source of truth; +the sync code is generated and should not be edited by hand. + +Usage: + python scripts/generate_sync.py +""" + +import os +import py_compile +import re +import sys +from typing import Callable + +import black + +# Directories +REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +ASYNC_DIR = os.path.join(REPO_ROOT, "simple_ado", "_async") +SYNC_DIR = os.path.join(REPO_ROOT, "simple_ado") +ASYNC_TEST_DIR = os.path.join(REPO_ROOT, "tests", "unit", "_async") +SYNC_TEST_DIR = os.path.join(REPO_ROOT, "tests", "unit") + +BLACK_MODE = black.Mode(line_length=100) + + +def _format_with_black(source: str) -> str: + """Format source code with black. Returns source unchanged if black fails.""" + try: + return black.format_str(source, mode=BLACK_MODE) + except black.NothingChanged: + return source + except Exception: + return source + + +# Files that are shared (not generated from _async/) and should not be overwritten +SHARED_FILES = { + "ado_types.py", + "comments.py", + "exceptions.py", +} + +# Header to mark generated files +GENERATED_HEADER = "# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/{}. DO NOT EDIT.\n\n" + + +def transform_source(source: str, relative_path: str) -> str: + """Transform async Python source code to synchronous equivalent. + + :param source: The async source code + :param relative_path: The relative path of the source file (for the header) + :returns: The transformed synchronous source code + """ + + result = source + + # --- Import transformations --- + + # azure.identity.aio → azure.identity + result = result.replace("azure.identity.aio", "azure.identity") + + # simple_ado._async.foo → simple_ado.foo (with trailing dot, e.g. in from simple_ado._async.audit) + result = result.replace("simple_ado._async.", "simple_ado.") + # simple_ado._async → simple_ado (without trailing dot, e.g. in from simple_ado._async import) + result = result.replace("simple_ado._async", "simple_ado") + + # --- Class name transformations --- + # ADOAsync* → ADO* everywhere (identifiers, __all__ strings, type annotations). + # The ADOAsync prefix is unique enough that false positives in error messages + # are not a concern. + result = re.sub(r"\bADOAsync(\w+)", r"ADO\1", result) + + # --- async/await removal --- + + # async def → def + result = re.sub(r"\basync def\b", "def", result) + + # await expr — strip the await keyword + # Handle multi-line await (await on its own line followed by continuation) + result = re.sub(r"\bawait ", "", result) + + # async with → with + result = re.sub(r"\basync with\b", "with", result) + + # async for → for + result = re.sub(r"\basync for\b", "for", result) + + # --- Restore yield from --- + # async generators can't use "yield from", so the async source uses + # "for item in x: yield item". In sync generators we can restore the + # more concise "yield from x" — but only when the yield is the ONLY + # statement in the loop body (i.e. the next non-blank line is dedented). + result = re.sub( + r"for (\w+) in (.+):\n(\s+)yield \1\n(?=\S|\s*\n\s*(?!\3\S))", + r"yield from \2\n", + result, + ) + + # --- Type hint transformations --- + + # AsyncIterator → Iterator + result = result.replace("AsyncIterator", "Iterator") + + # AsyncGenerator → Generator + result = result.replace("AsyncGenerator", "Generator") + + # Deduplicate imports that may appear after AsyncIterator → Iterator when the + # source already imported Iterator alongside AsyncIterator. + result = re.sub(r"\bIterator,\s*Iterator\b", "Iterator", result) + + # --- httpx client transformations --- + + # httpx.AsyncClient → httpx.Client + result = result.replace("httpx.AsyncClient", "httpx.Client") + + # --- Async method name transformations --- + + # response.aiter_bytes → response.iter_bytes + result = result.replace(".aiter_bytes(", ".iter_bytes(") + + # client.aclose() → client.close() + result = result.replace(".aclose()", ".close()") + + # response.aread() → response.read() + result = result.replace(".aread()", ".read()") + + # --- asyncio transformations --- + + # asyncio.to_thread(fn, args) → fn(args) — sync doesn't need thread offloading. + # First handle the case with arguments: asyncio.to_thread(fn, arg1, arg2) + result = re.sub(r"asyncio\.to_thread\(([^,]+),\s*", r"\1(", result) + # Then handle the no-args case: asyncio.to_thread(fn) + result = re.sub(r"asyncio\.to_thread\(([^)]+)\)", r"\1()", result) + + # asyncio.sleep → time.sleep + result = result.replace("asyncio.sleep(", "time.sleep(") + + # --- Import cleanup --- + + # import asyncio → import time (only if asyncio is used solely for sleep) + result = result.replace("import asyncio\n", "import time\n") + + # Remove pytest_asyncio import (becomes unused in sync tests) + result = result.replace("import pytest_asyncio\n", "") + + # Remove @pytest.mark.asyncio decorators + result = re.sub(r"\s*@pytest\.mark\.asyncio\n", "\n", result) + + # Remove bare "import pytest" when it's no longer used after removing asyncio markers. + # Only remove if pytest is not referenced elsewhere in the file. + if "import pytest\n" in result: + # Count references to "pytest." excluding the import line itself + without_import = result.replace("import pytest\n", "", 1) + if "pytest." not in without_import and "pytest," not in without_import: + result = result.replace("import pytest\n", "") + + # --- Context manager transformations --- + + # asynccontextmanager → contextmanager + result = result.replace("asynccontextmanager", "contextmanager") + + # --- Dunder method transformations --- + + # __aenter__ → __enter__ + result = result.replace("__aenter__", "__enter__") + + # __aexit__ → __exit__ + result = result.replace("__aexit__", "__exit__") + + # __setitem_async__ → __setitem__ (async workaround for sync __setitem__) + result = result.replace("__setitem_async__", "__setitem__") + + # --- follow_redirects ↔ allow_redirects --- + # The sync generated code uses httpx too, so follow_redirects stays. + # No transformation needed here. + + # --- Docstring fixups --- + + # "(async)" → "" in module docstrings, cleaning up trailing space + result = result.replace(" (async)", "") + result = result.replace("(async)", "") + + # "async " in docstrings where it's a description word + # Only replace in specific patterns to avoid over-matching + result = result.replace("Async wrapper", "Wrapper") + result = result.replace("Async auth", "Auth") + result = result.replace("async auth", "auth") + result = result.replace("async authentication", "authentication") + result = result.replace("An async iterator", "An iterator") + result = result.replace("an async iterator", "an iterator") + + # Remove "In the async version..." sentences from docstrings + result = re.sub(r"\n\s+In the async version[^\n]*", "", result) + + # "async context manager" → "context manager" in docstrings + result = result.replace("an async context manager", "a context manager") + result = result.replace("async context manager", "context manager") + + # --- Import ordering fixup --- + # After asyncio → time replacement, stdlib imports may be out of alphabetical order. + # Fix the specific known case. + result = result.replace("import time\nimport contextlib\n", "import contextlib\nimport time\n") + + # Add generated header after the shebang line (if present) so the shebang remains on line 1 + header = GENERATED_HEADER.format(relative_path) + if result.startswith("#!"): + # Insert after the shebang line + newline_idx = result.index("\n") + result = result[: newline_idx + 1] + header + result[newline_idx + 1 :] + else: + result = header + result + + return result + + +def _restore_sync_getitem(source: str) -> str: + """Restore auto-refresh behavior in sync ADOWorkItem.__getitem__. + + The async __getitem__ cannot auto-refresh because Python does not support + ``async def __getitem__``. The sync version has no such limitation, so we + restore the original behavior where ``work_item["field"]`` transparently + refreshes from the server on a cache miss. + + :param source: The transformed sync work_item.py source + :returns: The source with __getitem__ patched to auto-refresh + :raises ValueError: If the expected pattern is not found (signals that the + async source changed and this transform needs updating) + """ + + old = '''\ + def __getitem__(self, key: str | ADOWorkItemBuiltInFields) -> Any: + """Get a field value from the work item. + + Supports both string field names and ADOWorkItemBuiltInFields enum values. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + + :returns: The field value + + :raises KeyError: If the field is not found + """ + # Convert enum to string value if needed + field_name = key.value if isinstance(key, ADOWorkItemBuiltInFields) else key + + # Try to get from fields dict + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + + raise KeyError(f"Field '{field_name}' not found in work item {self.id}")''' + + new = '''\ + def __getitem__(self, key: str | ADOWorkItemBuiltInFields) -> Any: + """Get a field value from the work item. + + Supports both string field names and ADOWorkItemBuiltInFields enum values. + If the field is not present in the current data, the work item will be + refreshed from the server to try to populate missing fields. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + + :returns: The field value + + :raises KeyError: If the field is not found even after refresh + """ + # Convert enum to string value if needed + field_name = key.value if isinstance(key, ADOWorkItemBuiltInFields) else key + + # Try to get from fields dict + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + + # Field not found — refresh from server (sync only; async must use get_field()) + self._log.debug(f"Field '{field_name}' not found, refreshing work item") + self.refresh() + + # Try again after refresh + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + + raise KeyError(f"Field '{field_name}' not found in work item {self.id}")''' + + result = source.replace(old, new) + if result == source: + raise ValueError( + "Failed to apply sync __getitem__ post-transform in work_item.py — " + "the expected pattern was not found. If the async __getitem__ changed, " + "update _restore_sync_getitem() to match." + ) + return result + + +def _add_sync_del(source: str) -> str: + """Add __del__ to sync ADOHTTPClient for silent garbage-collection cleanup. + + httpx.Client emits ResourceWarning when garbage-collected without being + closed. requests.Session did not, so existing consumers never called + close(). Adding __del__ preserves the old silent-GC behavior. + + :param source: The transformed sync http_client.py source + :returns: The source with __del__ inserted after __exit__ + :raises ValueError: If the expected pattern is not found + """ + + anchor = '''\ + def __exit__(self, *args: Any) -> None: + self.close()''' + + replacement = anchor + ''' + + def __del__(self) -> None: + try: + self._client.close() + except Exception: + pass''' + + result = source.replace(anchor, replacement) + if result == source: + raise ValueError( + "Failed to apply __del__ post-transform in http_client.py — " + "the expected __exit__ pattern was not found. If the async source " + "changed, update _add_sync_del() to match." + ) + return result + + +# Map of (rel_prefix, filename) → list of post-transform functions to apply +# after the standard async→sync transformation. +_POST_TRANSFORMS: dict[str, list[Callable[[str], str]]] = { + "work_item.py": [_restore_sync_getitem], + "http_client.py": [_add_sync_del], +} + + +def _apply_post_transforms(source: str, relative_path: str) -> str: + """Apply any file-specific post-transforms to the generated sync source. + + :param source: The already-transformed sync source + :param relative_path: The relative file path (e.g. "work_item.py") + :returns: The source with post-transforms applied + """ + for fn in _POST_TRANSFORMS.get(relative_path, []): + source = fn(source) + return source + + +def transform_init(source: str) -> str: + """Transform the async __init__.py to be the sync top-level __init__.py. + + This is handled specially because the top-level __init__.py also needs to + provide access to the _async subpackage. + + :param source: The async __init__.py source + :returns: The transformed sync __init__.py + """ + + result = transform_source(source, "__init__.py") + + # Add async re-exports at the bottom + async_exports = ''' + +# Async API access +from simple_ado import _async as aio # noqa: F401 — provides simple_ado.aio namespace + +__all__ += ["aio"] +''' + result = result.rstrip() + "\n" + async_exports + + return result + + +def process_directory(async_dir: str, sync_dir: str, rel_prefix: str = "") -> None: + """Recursively transform all .py files from async_dir to sync_dir. + + :param async_dir: The async source directory + :param sync_dir: The sync output directory + :param rel_prefix: The relative path prefix for logging + """ + + for entry in sorted(os.listdir(async_dir)): + async_path = os.path.join(async_dir, entry) + sync_path = os.path.join(sync_dir, entry) + relative = os.path.join(rel_prefix, entry) if rel_prefix else entry + + if os.path.isdir(async_path): + if entry.startswith("__pycache__"): + continue + os.makedirs(sync_path, exist_ok=True) + process_directory(async_path, sync_path, relative) + + elif entry.endswith(".py"): + with open(async_path, "r") as f: + source = f.read() + + if entry == "__init__.py" and rel_prefix == "": + # Top-level __init__.py gets special treatment + transformed = transform_init(source) + else: + transformed = transform_source(source, relative) + + transformed = _apply_post_transforms(transformed, relative) + transformed = _format_with_black(transformed) + + # Check that we're not overwriting a shared file + if rel_prefix == "" and entry in SHARED_FILES: + print(f" SKIP (shared): {relative}") + continue + + # Check if sync file exists and content is the same + if os.path.exists(sync_path): + with open(sync_path, "r") as f: + existing = f.read() + if existing == transformed: + continue + + print(f" Writing: {relative}") + with open(sync_path, "w") as f: + f.write(transformed) + + +def verify_generated_files(sync_dir: str) -> bool: + """Verify all generated .py files compile successfully. + + :param sync_dir: The directory containing generated files + :returns: True if all files compile, False otherwise + """ + ok = True + for root, _dirs, files in os.walk(sync_dir): + for name in sorted(files): + if not name.endswith(".py"): + continue + path = os.path.join(root, name) + try: + py_compile.compile(path, doraise=True) + except py_compile.PyCompileError as exc: + print(f" COMPILE ERROR: {path}: {exc}", file=sys.stderr) + ok = False + return ok + + +def check_directory(async_dir: str, sync_dir: str, rel_prefix: str = "") -> bool: + """Check if generated files are in sync with async source (without writing). + + :param async_dir: The async source directory + :param sync_dir: The sync output directory + :param rel_prefix: The relative path prefix for logging + :returns: True if all files are in sync + """ + in_sync = True + + for entry in sorted(os.listdir(async_dir)): + async_path = os.path.join(async_dir, entry) + sync_path = os.path.join(sync_dir, entry) + relative = os.path.join(rel_prefix, entry) if rel_prefix else entry + + if os.path.isdir(async_path): + if entry.startswith("__pycache__"): + continue + if not check_directory(async_path, sync_path, relative): + in_sync = False + + elif entry.endswith(".py"): + if rel_prefix == "" and entry in SHARED_FILES: + continue + + with open(async_path, "r") as f: + source = f.read() + + if entry == "__init__.py" and rel_prefix == "": + transformed = transform_init(source) + else: + transformed = transform_source(source, relative) + + transformed = _apply_post_transforms(transformed, relative) + transformed = _format_with_black(transformed) + + if not os.path.exists(sync_path): + print(f" MISSING: {relative}") + in_sync = False + else: + with open(sync_path, "r") as f: + existing = f.read() + if existing != transformed: + print(f" OUT OF SYNC: {relative}") + in_sync = False + + return in_sync + + +# --- Test generation --- + +# Header for generated test files +GENERATED_TEST_HEADER = "# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/{}. DO NOT EDIT.\n\n" + +# Test files that should not be generated (manually maintained in tests/unit/) +SHARED_TEST_FILES = {"__init__.py", "conftest.py"} + + +def transform_test_source(source: str, relative_path: str) -> str: + """Transform an async test file to its synchronous equivalent. + + Applies the standard code transforms first, then test-specific transforms + for pytest markers, fixture names, and imports. + + :param source: The async test source code + :param relative_path: The relative path of the source file (for the header) + :returns: The transformed synchronous test source code + """ + + # Apply all standard code transforms (async def → def, await removal, etc.) + # Use a dummy relative path for the code header — we'll replace it with the test header. + result = transform_source(source, relative_path) + + # Replace the auto-generated code header with a test-specific one + code_header = GENERATED_HEADER.format(relative_path) + test_header = GENERATED_TEST_HEADER.format(relative_path) + result = result.replace(code_header, test_header) + + # --- Test-specific transforms --- + + # Remove @pytest.mark.asyncio lines (including indented ones in test classes) + result = re.sub(r"^\s*@pytest\.mark\.asyncio\n", "", result, flags=re.MULTILINE) + + # @pytest_asyncio.fixture → @pytest.fixture + result = result.replace("@pytest_asyncio.fixture", "@pytest.fixture") + + # Remove import pytest_asyncio lines + result = re.sub(r"^import pytest_asyncio\n", "", result, flags=re.MULTILINE) + + # Fixture name transforms: _async_ → _ (e.g. mock_async_client → mock_client) + result = result.replace("_async_", "_") + + # Remaining async_ prefix in identifiers (e.g. async_http_client → http_client) + result = re.sub(r"\basync_(\w)", r"\1", result) + + return result + + +def _restore_sync_getitem_tests(source: str) -> str: + """Replace the async __getitem__ test with sync auto-refresh tests. + + The async __getitem__ just raises KeyError on missing fields. The sync + version auto-refreshes. This replaces the simple test with two tests that + verify the auto-refresh behavior. + + :param source: The transformed sync test_work_item.py source + :returns: The source with sync-specific __getitem__ tests + :raises ValueError: If the expected pattern is not found + """ + + old = '''\ +def test_work_item_getitem_missing_field_raises( + mock_work_item_data: dict[str, Any], + mock_workitems_client: ADOWorkItemsClient, +) -> None: + """Test that accessing a non-existent field raises KeyError.""" + work_item = ADOWorkItem( + data=mock_work_item_data, + client=mock_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + _ = work_item["NonExistent.Field"]''' + + new = '''\ +@respx.mock +def test_work_item_getitem_missing_field_refreshes( + mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """Test that accessing a missing field auto-refreshes and returns the value.""" + refreshed_data = copy.deepcopy(mock_work_item_data) + refreshed_data["fields"]["System.Reason"] = "Fixed" + + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=refreshed_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + assert work_item["System.Reason"] == "Fixed" + + +@respx.mock +def test_work_item_getitem_missing_field_raises_after_refresh( + mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """Test that accessing a non-existent field raises KeyError after refresh.""" + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=mock_work_item_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + _ = work_item["NonExistent.Field"]''' + + result = source.replace(old, new) + if result == source: + raise ValueError( + "Failed to apply sync __getitem__ test post-transform in test_work_item.py — " + "the expected pattern was not found." + ) + return result + + +def _restore_sync_setitem_test(source: str) -> str: + """Replace the async set() test with sync __setitem__ test. + + The async version tests ``await work_item.set(key, value)`` because async + ``__setitem__`` is not possible. The sync version tests ``work_item[key] = value``. + + :param source: The transformed sync test_work_item.py source + :returns: The source with sync __setitem__ test + :raises ValueError: If the expected pattern is not found + """ + + old = '''\ +@respx.mock +def test_work_item_set( + mock_work_item_data: dict[str, Any], + mock_client: ADOClient, + mock_project_id: str, +) -> None: + """Test setting a field using the async set method.""" + updated_data = copy.deepcopy(mock_work_item_data) + updated_data["fields"]["System.Title"] = "New Title" + + respx.patch( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=updated_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + work_item.set("System.Title", "New Title") + + assert work_item["System.Title"] == "New Title"''' + + new = '''\ +@respx.mock +def test_work_item_setitem( + mock_work_item_data: dict[str, Any], + mock_client: ADOClient, + mock_project_id: str, +) -> None: + """Test setting a field using setitem.""" + updated_data = copy.deepcopy(mock_work_item_data) + updated_data["fields"]["System.Title"] = "New Title" + + respx.patch( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=updated_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + work_item["System.Title"] = "New Title" + + assert work_item["System.Title"] == "New Title"''' + + result = source.replace(old, new) + if result == source: + raise ValueError( + "Failed to apply sync __setitem__ test post-transform in test_work_item.py — " + "the expected pattern was not found." + ) + return result + + +# Post-transforms for generated test files +_TEST_POST_TRANSFORMS: dict[str, list[Callable[[str], str]]] = { + "test_work_item.py": [_restore_sync_getitem_tests, _restore_sync_setitem_test], +} + + +def _apply_test_post_transforms(source: str, relative_path: str) -> str: + """Apply any file-specific post-transforms to the generated sync test source. + + :param source: The already-transformed sync test source + :param relative_path: The relative file path (e.g. "test_work_item.py") + :returns: The source with post-transforms applied + """ + for fn in _TEST_POST_TRANSFORMS.get(relative_path, []): + source = fn(source) + return source + + +def process_test_directory(async_dir: str, sync_dir: str) -> None: + """Transform async test files to sync test files. + + :param async_dir: The async test source directory (tests/unit/_async/) + :param sync_dir: The sync test output directory (tests/unit/) + """ + + for entry in sorted(os.listdir(async_dir)): + async_path = os.path.join(async_dir, entry) + sync_path = os.path.join(sync_dir, entry) + + if not entry.endswith(".py") or entry in SHARED_TEST_FILES: + continue + + if os.path.isdir(async_path): + continue + + with open(async_path, "r") as f: + source = f.read() + + transformed = transform_test_source(source, entry) + transformed = _apply_test_post_transforms(transformed, entry) + transformed = _format_with_black(transformed) + + if os.path.exists(sync_path): + with open(sync_path, "r") as f: + existing = f.read() + if existing == transformed: + continue + + print(f" Writing: tests/unit/{entry}") + with open(sync_path, "w") as f: + f.write(transformed) + + +def check_test_directory(async_dir: str, sync_dir: str) -> bool: + """Check if generated test files are in sync with async test source. + + :param async_dir: The async test source directory + :param sync_dir: The sync test output directory + :returns: True if all files are in sync + """ + in_sync = True + + for entry in sorted(os.listdir(async_dir)): + async_path = os.path.join(async_dir, entry) + + if not entry.endswith(".py") or entry in SHARED_TEST_FILES: + continue + + if os.path.isdir(async_path): + continue + + sync_path = os.path.join(sync_dir, entry) + + with open(async_path, "r") as f: + source = f.read() + + transformed = transform_test_source(source, entry) + transformed = _apply_test_post_transforms(transformed, entry) + transformed = _format_with_black(transformed) + + if not os.path.exists(sync_path): + print(f" MISSING: tests/unit/{entry}") + in_sync = False + else: + with open(sync_path, "r") as f: + existing = f.read() + if existing != transformed: + print(f" OUT OF SYNC: tests/unit/{entry}") + in_sync = False + + return in_sync + + +def main() -> int: + """Main entry point.""" + + check_only = "--check" in sys.argv + + if not os.path.isdir(ASYNC_DIR): + print(f"Error: Async source directory not found: {ASYNC_DIR}", file=sys.stderr) + return 1 + + if check_only: + print("Checking sync code is up to date with async source...") + ok = True + if not check_directory(ASYNC_DIR, SYNC_DIR): + ok = False + if os.path.isdir(ASYNC_TEST_DIR): + if not check_test_directory(ASYNC_TEST_DIR, SYNC_TEST_DIR): + ok = False + if not ok: + print() + print("Error: Generated files are out of sync.", file=sys.stderr) + print("Run 'python scripts/generate_sync.py' to regenerate.", file=sys.stderr) + return 1 + print("All generated files are in sync.") + return 0 + + print("Generating sync code from async source...") + print(f" Source: {ASYNC_DIR}") + print(f" Output: {SYNC_DIR}") + print() + + process_directory(ASYNC_DIR, SYNC_DIR) + + if os.path.isdir(ASYNC_TEST_DIR): + print() + print("Generating sync tests from async test source...") + print(f" Source: {ASYNC_TEST_DIR}") + print(f" Output: {SYNC_TEST_DIR}") + print() + process_test_directory(ASYNC_TEST_DIR, SYNC_TEST_DIR) + + print() + print("Verifying generated files compile...") + if not verify_generated_files(SYNC_DIR): + print("Error: Some generated files failed to compile.", file=sys.stderr) + return 1 + + print("Done.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/simple_ado/__init__.py b/simple_ado/__init__.py index 8c0ebc6..9b9b23d 100755 --- a/simple_ado/__init__.py +++ b/simple_ado/__init__.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/__init__.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -14,20 +16,16 @@ audit, auth as auth_module, builds, - comments, endpoints, - exceptions, git, governance, graph, http_client, identities, - models, pipelines, pools, pull_requests, security, - types, user, wiki, ) @@ -54,6 +52,14 @@ from simple_ado.work_item import ADOWorkItem from simple_ado.workitems import ADOWorkItemsClient +# Re-export submodules +from simple_ado import ( + ado_types, + comments, + exceptions, + models, +) + # Re-export auth_module as auth to maintain public API auth = auth_module @@ -100,7 +106,7 @@ "pools", "pull_requests", "security", - "types", + "ado_types", "user", "wiki", ] @@ -174,6 +180,16 @@ def __init__( self.wiki = ADOWikiClient(self.http_client, self.log) self.workitems = ADOWorkItemsClient(self.http_client, self.log) + def close(self) -> None: + """Close the underlying HTTP client.""" + self.http_client.close() + + def __enter__(self) -> "ADOClient": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + def verify_access(self) -> bool: """Verify that we have access to ADO. @@ -249,7 +265,7 @@ def pull_request( :param project_id: The ID of the project the PR is in :param repository_id: The ID of repository the pull request is on - :returns: A new ADOPullRequest client for the pull request specified + :returns: A new ADOPullRequestClient for the pull request specified """ return ADOPullRequestClient( self.http_client, self.log, pull_request_id, project_id, repository_id @@ -425,3 +441,9 @@ def _canonicalize_branch_name(branch_name: str) -> str: return "refs/heads/" + branch_name return branch_name + + +# Async API access +from simple_ado import _async as aio # noqa: F401 — provides simple_ado.aio namespace + +__all__ += ["aio"] diff --git a/simple_ado/_async/__init__.py b/simple_ado/_async/__init__.py new file mode 100644 index 0000000..4c266bd --- /dev/null +++ b/simple_ado/_async/__init__.py @@ -0,0 +1,442 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO API wrapper (async).""" + +import datetime +import logging +from typing import Any, AsyncIterator +import urllib.parse + +from simple_ado._async import ( + audit, + auth as auth_module, + builds, + endpoints, + git, + governance, + graph, + http_client, + identities, + pipelines, + pools, + pull_requests, + security, + user, + wiki, +) +from simple_ado._async.auth.ado_auth import ADOAsyncAuth +from simple_ado._async.auth.ado_basic_auth import ADOAsyncBasicAuth +from simple_ado._async.auth.ado_token_auth import ADOAsyncTokenAuth +from simple_ado._async.auth.ado_azid_auth import ADOAsyncAzIDAuth +from simple_ado._async.audit import ADOAsyncAuditClient +from simple_ado._async.builds import ADOAsyncBuildClient +from simple_ado._async.endpoints import ADOAsyncEndpointsClient +from simple_ado.exceptions import ADOException, ADOHTTPException +from simple_ado._async.identities import ADOAsyncIdentitiesClient +from simple_ado._async.git import ADOAsyncGitClient +from simple_ado._async.graph import ADOAsyncGraphClient +from simple_ado._async.governance import ADOAsyncGovernanceClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse +from simple_ado.models.pull_requests import ADOPullRequestTimeRangeType +from simple_ado._async.pipelines import ADOAsyncPipelineClient +from simple_ado._async.pools import ADOAsyncPoolsClient +from simple_ado._async.pull_requests import ADOAsyncPullRequestClient, ADOPullRequestStatus +from simple_ado._async.security import ADOAsyncSecurityClient +from simple_ado._async.user import ADOAsyncUserClient +from simple_ado._async.wiki import ADOAsyncWikiClient +from simple_ado._async.work_item import ADOAsyncWorkItem +from simple_ado._async.workitems import ADOAsyncWorkItemsClient + +# Re-export submodules +from simple_ado import ( + ado_types, + comments, + exceptions, + models, +) + +# Re-export auth_module as auth to maintain public API +auth = auth_module + +__all__ = [ + "ADOAsyncAuditClient", + "ADOAsyncAuth", + "ADOAsyncAzIDAuth", + "ADOAsyncBasicAuth", + "ADOAsyncBuildClient", + "ADOAsyncEndpointsClient", + "ADOException", + "ADOAsyncGitClient", + "ADOAsyncGovernanceClient", + "ADOAsyncGraphClient", + "ADOAsyncHTTPClient", + "ADOHTTPException", + "ADOAsyncIdentitiesClient", + "ADOAsyncPipelineClient", + "ADOAsyncPoolsClient", + "ADOAsyncPullRequestClient", + "ADOPullRequestStatus", + "ADOPullRequestTimeRangeType", + "ADOResponse", + "ADOAsyncSecurityClient", + "ADOAsyncTokenAuth", + "ADOAsyncUserClient", + "ADOAsyncWikiClient", + "ADOAsyncWorkItem", + "ADOAsyncWorkItemsClient", + # Submodules + "audit", + "auth", + "builds", + "comments", + "endpoints", + "exceptions", + "git", + "governance", + "graph", + "http_client", + "identities", + "models", + "pipelines", + "pools", + "pull_requests", + "security", + "ado_types", + "user", + "wiki", +] + + +class ADOAsyncClient: + """Async wrapper class around the ADO API. + + :param tenant: The ADO tenant to connect to + :param auth: The auth details to use for the API connection + :param user_agent: The user agent to set + :param extra_headers: Any extra headers which should be sent with the API requests + :param log: The logger to use for logging (a new one will be used if one is not supplied) + """ + + # pylint: disable=too-many-instance-attributes + + log: logging.Logger + + http_client: ADOAsyncHTTPClient + + audit: ADOAsyncAuditClient + builds: ADOAsyncBuildClient + endpoints: ADOAsyncEndpointsClient + git: ADOAsyncGitClient + governance: ADOAsyncGovernanceClient + graph: ADOAsyncGraphClient + identities: ADOAsyncIdentitiesClient + pipelines: ADOAsyncPipelineClient + pools: ADOAsyncPoolsClient + security: ADOAsyncSecurityClient + user: ADOAsyncUserClient + wiki: ADOAsyncWikiClient + workitems: ADOAsyncWorkItemsClient + + def __init__( + self, + *, + tenant: str, + auth: ADOAsyncAuth, # pylint: disable=redefined-outer-name + user_agent: str | None = None, + extra_headers: dict[str, str] | None = None, + log: logging.Logger | None = None, + ) -> None: + """Construct a new client object.""" + + if log is None: + self.log = logging.getLogger("ado") + else: + self.log = log.getChild("ado") + + self.http_client = ADOAsyncHTTPClient( + tenant=tenant, + auth=auth, + user_agent=user_agent if user_agent is not None else tenant, + log=self.log, + extra_headers=extra_headers, + ) + + self.audit = ADOAsyncAuditClient(self.http_client, self.log) + self.builds = ADOAsyncBuildClient(self.http_client, self.log) + self.endpoints = ADOAsyncEndpointsClient(self.http_client, self.log) + self.identities = ADOAsyncIdentitiesClient(self.http_client, self.log) + self.git = ADOAsyncGitClient(self.http_client, self.log) + self.governance = ADOAsyncGovernanceClient(self.http_client, self.log) + self.graph = ADOAsyncGraphClient(self.http_client, self.log) + self.pipelines = ADOAsyncPipelineClient(self.http_client, self.log) + self.pools = ADOAsyncPoolsClient(self.http_client, self.log) + self.security = ADOAsyncSecurityClient(self.http_client, self.log) + self.user = ADOAsyncUserClient(self.http_client, self.log) + self.wiki = ADOAsyncWikiClient(self.http_client, self.log) + self.workitems = ADOAsyncWorkItemsClient(self.http_client, self.log) + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self.http_client.close() + + async def __aenter__(self) -> "ADOAsyncClient": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + async def verify_access(self) -> bool: + """Verify that we have access to ADO. + + :returns: True if we have access, False otherwise + """ + + request_url = ( + self.http_client.api_endpoint(is_default_collection=False) + "/projects?api-version=7.1" + ) + + try: + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + self.http_client.extract_value(response_data) + except ADOException: + return False + + return True + + async def create_pull_request( + self, + *, + source_branch: str, + target_branch: str, + project_id: str, + repository_id: str, + title: str | None = None, + description: str | None = None, + reviewer_ids: list[str] | None = None, + ) -> ADOResponse: + """Creates a pull request with the given information + + :param source_branch: The source branch of the pull request + :param target_branch: The target branch of the pull request + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param title: The title of the pull request + :param description: The description of the pull request + :param reviewer_ids: The reviewer IDs to be added to the pull request + + :returns: The ADO response with the data in it + + :raises ADOException: If we fail to create the pull request + """ + self.log.debug("Creating pull request") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}" + request_url += "/pullRequests?api-version=7.1" + + body: dict[str, Any] = { + "sourceRefName": _canonicalize_branch_name(source_branch), + "targetRefName": _canonicalize_branch_name(target_branch), + } + + if title is not None: + body["title"] = title + + if description is not None: + body["description"] = description + + if reviewer_ids is not None and len(reviewer_ids) > 0: + body["reviewers"] = [{"id": reviewer_id} for reviewer_id in reviewer_ids] + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + def pull_request( + self, pull_request_id: int, project_id: str, repository_id: str + ) -> ADOAsyncPullRequestClient: + """Get an ADOAsyncPullRequestClient for the PR identifier. + + :param pull_request_id: The ID of the pull request to create the client for + :param project_id: The ID of the project the PR is in + :param repository_id: The ID of repository the pull request is on + + :returns: A new ADOAsyncPullRequestClient for the pull request specified + """ + return ADOAsyncPullRequestClient( + self.http_client, self.log, pull_request_id, project_id, repository_id + ) + + # pylint: disable=too-many-locals,too-complex,too-many-branches + async def list_all_pull_requests( + self, + *, + project_id: str, + top: int | None = None, + creator_id: str | None = None, + include_links: bool | None = None, + max_time: datetime.datetime | None = None, + min_time: datetime.datetime | None = None, + query_time_range_type: ADOPullRequestTimeRangeType | None = None, + repository_id: str | None = None, + reviewer_id: str | None = None, + branch_name: str | None = None, # TODO: Rename to source_ref_name + source_repo_id: str | None = None, + pr_status: ADOPullRequestStatus | None = None, + target_ref_name: str | None = None, + title: str | None = None, + ) -> AsyncIterator[Any]: + """Get the pull requests matching the specified criteria from ADO. + + :param project_id: The ID of the project. + :param top: The number of pull requests to retrieve per page. + :param creator_id: If set, search for pull requests that were created by this identity (Team Foundation ID). + :param include_links: Whether to include the _links field on the shallow references. + :param max_time: If specified, filters pull requests that were created/closed before this date based on the + queryTimeRangeType specified. + :param min_time: If specified, filters pull requests that were created/closed after this date based on the + queryTimeRangeType specified. + :param query_time_range_type: The type of time range to use for min_time and max_time filtering. Defaults to + Created if unset. + :param repository_id: If set, search for pull requests whose target branch is in this repository. + :param reviewer_id: If set, search for pull requests that have this identity as a reviewer (Team Foundation ID). + :param branch_name: If set, search for pull requests from this source branch. + :param source_repo_id: If set, search for pull requests whose source branch is in this repository. + :param pr_status: If set, search for pull requests that are in this state. Defaults to Active if unset. + :param target_ref_name: If set, search for pull requests into this target branch. + :param title: If set, filters pull requests that contain the specified text in the title. + + :returns: An async iterator yielding pull request data dictionaries. + """ + + self.log.debug("Fetching PRs") + + offset = 0 + + while True: + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + "/git/pullrequests?" + ) + + parameters: dict[str, Any] = {"$skip": offset, "api-version": "7.2-preview.2"} + + if top: + parameters["$top"] = top + + if creator_id: + parameters["searchCriteria.creatorId"] = creator_id + + if include_links is not None: + parameters["searchCriteria.includeLinks"] = str(include_links).lower() + + if max_time: + parameters["searchCriteria.maxTime"] = max_time.strftime("%Y-%m-%dT%H:%M:%S.000Z") + + if min_time: + parameters["searchCriteria.minTime"] = min_time.strftime("%Y-%m-%dT%H:%M:%S.000Z") + + if query_time_range_type: + parameters["searchCriteria.queryTimeRangeType"] = query_time_range_type.value + + if repository_id: + parameters["searchCriteria.repositoryId"] = repository_id + + if reviewer_id: + parameters["searchCriteria.reviewerId"] = reviewer_id + + if source_repo_id: + parameters["searchCriteria.sourceRepositoryId"] = source_repo_id + + if pr_status: + parameters["searchCriteria.status"] = pr_status.value + + if title: + parameters["searchCriteria.title"] = title + + encoded_parameters = urllib.parse.urlencode(parameters) + + request_url += encoded_parameters + + # ADO doesn't like it if the `/` in branch references are encoded, so we just append them manually + + if branch_name is not None: + request_url += ( + f"&searchCriteria.sourceRefName={_canonicalize_branch_name(branch_name)}" + ) + + if target_ref_name is not None: + request_url += ( + f"&searchCriteria.targetRefName={_canonicalize_branch_name(target_ref_name)}" + ) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + + extracted = self.http_client.extract_value(response_data) + + if len(extracted) == 0: + break + + for item in extracted: + yield item + + offset += len(extracted) + + # pylint: enable=too-many-locals,too-complex,too-many-branches + + async def custom_get( + self, + *, + url_fragment: str, + parameters: dict[str, Any], + is_default_collection: bool = True, + is_internal: bool = False, + subdomain: str | None = None, + project_id: str | None = None, + ) -> ADOResponse: + """Perform a custom GET REST request. + + We don't always expose everything that would be preferred to the end + user, so to make it a little easier, we expose this method which lets + the user perform an arbitrary GET request, but where we supply the base + information. + + We only support GET requests as anything else is too complex to be + exposed in a generic manner. For these cases, the requests should be + built manually. + + :param url_fragment: The part of the URL that comes after `_apis/` + :param parameters: The URL parameters to append + :param is_default_collection: Whether this URL should start with the path "/DefaultCollection" + :param is_internal: Whether this URL should use internal API endpoint "/_api" + :param subdomain: A subdomain that should be used (if any) + :param project_id: The project ID (if required) + + :returns: The raw response + """ + + encoded_parameters = urllib.parse.urlencode(parameters) + request_url = self.http_client.api_endpoint( + is_default_collection=is_default_collection, + is_internal=is_internal, + subdomain=subdomain, + project_id=project_id, + ) + request_url += f"/{url_fragment}?{encoded_parameters}" + + return await self.http_client.get(request_url) + + +def _canonicalize_branch_name(branch_name: str) -> str: + """Cleanup the branch name before sending it via ADO request + + :param branch_name: The name of the branch to cleanup + + :returns: The cleaned up branch name to send via ADO request + """ + if not branch_name.startswith("refs/heads/"): + return "refs/heads/" + branch_name + + return branch_name diff --git a/simple_ado/_async/audit.py b/simple_ado/_async/audit.py new file mode 100644 index 0000000..4e4768b --- /dev/null +++ b/simple_ado/_async/audit.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO audit API wrapper (async).""" + +import datetime +import logging +from typing import Any, AsyncIterator +import urllib.parse + +import deserialize + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient +from simple_ado.models import AuditActionInfo + + +class ADOAsyncAuditClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Audit APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("audit")) + + async def get_actions(self, area_name: str | None = None) -> list[AuditActionInfo]: + """Get the list of audit actions. + + :param area_name: The optional area name to scope down to + + :returns: The ADO response with the data in it + """ + + self.log.debug("Getting audit actions") + + parameters = {"api-version": "7.1-preview.1"} + + if area_name: + parameters["areaName"] = area_name + + request_url = self.http_client.audit_endpoint() + "/audit/actions?" + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + raw_actions = self.http_client.extract_value(response_data) + return deserialize.deserialize(list[AuditActionInfo], raw_actions) + + async def query( + self, + start_time: datetime.datetime | None = None, + end_time: datetime.datetime | None = None, + skip_aggregation: bool | None = None, + ) -> AsyncIterator[dict[str, Any]]: + """Query the audit log. + + :param start_time: The earliest point to query (rounds down to the nearest second) + :param end_time: The latest point to query (rounds down to the nearest second) + :param skip_aggregation: Set to False to avoid aggregating events + + :returns: The queried log + """ + + parameters = {"api-version": "7.1-preview.1"} + + if start_time: + parameters["startTime"] = start_time.strftime("%Y-%m-%dT%H:%M:%S.000Z") + + if end_time: + parameters["endTime"] = end_time.strftime("%Y-%m-%dT%H:%M:%S.000Z") + + if skip_aggregation: + parameters["skipAggregation"] = str(skip_aggregation).lower() + + request_url = f"{self.http_client.audit_endpoint()}/audit/auditlog?" + request_url += urllib.parse.urlencode(parameters) + + url = request_url + + while True: + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + for entry in decoded["decoratedAuditLogEntries"]: + yield entry + + if not decoded.get("hasMore"): + return + + continuation_token = decoded["continuationToken"] + url = request_url + f"&continuationToken={continuation_token}" diff --git a/simple_ado/_async/auth/__init__.py b/simple_ado/_async/auth/__init__.py new file mode 100644 index 0000000..5cbd00f --- /dev/null +++ b/simple_ado/_async/auth/__init__.py @@ -0,0 +1,14 @@ +"""Umbrella module for all async authentication classes.""" + +from .ado_auth import ADOAsyncAuth +from .ado_basic_auth import ADOAsyncBasicAuth +from .ado_token_auth import ADOAsyncTokenAuth +from .ado_azid_auth import ADOAsyncAzIDAuth + +# Set the module's public interface +__all__ = [ + "ADOAsyncAuth", + "ADOAsyncBasicAuth", + "ADOAsyncTokenAuth", + "ADOAsyncAzIDAuth", +] diff --git a/simple_ado/_async/auth/ado_auth.py b/simple_ado/_async/auth/ado_auth.py new file mode 100644 index 0000000..4119168 --- /dev/null +++ b/simple_ado/_async/auth/ado_auth.py @@ -0,0 +1,21 @@ +"""Base async auth class.""" + +import abc + + +class ADOAsyncAuth(abc.ABC): + """Base class for async authentication.""" + + @abc.abstractmethod + async def get_authorization_header(self) -> str: + """Get the header value. + + :return: The header value.""" + raise NotImplementedError() + + async def close(self) -> None: + """Close any resources held by this auth instance. + + Subclasses that hold closeable resources (e.g. credential objects) + should override this method. The default implementation is a no-op. + """ diff --git a/simple_ado/_async/auth/ado_azid_auth.py b/simple_ado/_async/auth/ado_azid_auth.py new file mode 100644 index 0000000..dbf8c6e --- /dev/null +++ b/simple_ado/_async/auth/ado_azid_auth.py @@ -0,0 +1,36 @@ +"""Azure Identity async authentication auth class.""" + +import time + +from azure.identity.aio import DefaultAzureCredential +from azure.core.credentials import AccessToken +from simple_ado._async.auth.ado_auth import ADOAsyncAuth + + +class ADOAsyncAzIDAuth(ADOAsyncAuth): + """Azure Identity async auth.""" + + access_token: AccessToken | None + _credential: DefaultAzureCredential + + def __init__(self) -> None: + self.access_token = None + self._credential = DefaultAzureCredential() + + async def get_authorization_header(self) -> str: + """Get the header value. + + :return: The header value.""" + + # The get_token parameter specifies the Azure DevOps resource and requests a token with + # default permissions for API access. + if self.access_token is None or self.access_token.expires_on <= time.time() + 60: + self.access_token = await self._credential.get_token( + "499b84ac-1321-427f-aa17-267ca6975798/.default" + ) + + return "Bearer " + self.access_token.token + + async def close(self) -> None: + """Close the underlying credential.""" + await self._credential.close() diff --git a/simple_ado/_async/auth/ado_basic_auth.py b/simple_ado/_async/auth/ado_basic_auth.py new file mode 100644 index 0000000..c5cec2e --- /dev/null +++ b/simple_ado/_async/auth/ado_basic_auth.py @@ -0,0 +1,30 @@ +"""Basic authentication async auth class.""" + +import base64 +from simple_ado._async.auth.ado_auth import ADOAsyncAuth + + +class ADOAsyncBasicAuth(ADOAsyncAuth): + """Username/password auth. Also supports PATs.""" + + username: str + password: str + _cached_header: str | None + + def __init__(self, username: str, password: str) -> None: + self.username = username + self.password = password + self._cached_header = None + + async def get_authorization_header(self) -> str: + """Get the header value. + + :return: The header value.""" + + if self._cached_header is None: + username_password_bytes = (self.username + ":" + self.password).encode("utf-8") + self._cached_header = "Basic " + base64.b64encode(username_password_bytes).decode( + "utf-8" + ) + + return self._cached_header diff --git a/simple_ado/_async/auth/ado_token_auth.py b/simple_ado/_async/auth/ado_token_auth.py new file mode 100644 index 0000000..12882ba --- /dev/null +++ b/simple_ado/_async/auth/ado_token_auth.py @@ -0,0 +1,19 @@ +"""Token authentication async auth class.""" + +from simple_ado._async.auth.ado_auth import ADOAsyncAuth + + +class ADOAsyncTokenAuth(ADOAsyncAuth): + """Token auth.""" + + token: str + + def __init__(self, token: str) -> None: + self.token = token + + async def get_authorization_header(self) -> str: + """Get the header value. + + :return: The header value.""" + + return "Bearer " + self.token diff --git a/simple_ado/_async/base_client.py b/simple_ado/_async/base_client.py new file mode 100644 index 0000000..b173fd9 --- /dev/null +++ b/simple_ado/_async/base_client.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Base ADO Client (async).""" + +import logging + +from simple_ado._async.http_client import ADOAsyncHTTPClient + + +class ADOAsyncBaseClient: + """Base client for ADO API. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + log: logging.Logger + + http_client: ADOAsyncHTTPClient + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + """Construct a new base client object.""" + + self.log = log + self.http_client = http_client diff --git a/simple_ado/_async/builds.py b/simple_ado/_async/builds.py new file mode 100644 index 0000000..98d04a6 --- /dev/null +++ b/simple_ado/_async/builds.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO build API wrapper (async).""" + +import enum +import json +import logging +from typing import Any, AsyncIterator, Callable, cast +from urllib.parse import SplitResult +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse +from simple_ado._async.utilities import download_from_response_stream +from simple_ado.exceptions import ADOHTTPException +from simple_ado.ado_types import TeamFoundationId + + +class BuildQueryOrder(enum.Enum): + """The order for the build queries to be returned in.""" + + FINISH_TIME_ASCENDING = "finishTimeAscending" + FINISH_TIME_DESCENDING = "finishTimeDescending" + QUEUE_TIME_ASCENDING = "queueTimeAscending" + QUEUE_TIME_DESCENDING = "queueTimeDescending" + START_TIME_ASCENDING = "startTimeAscending" + START_TIME_DESCENDING = "startTimeDescending" + + +class StageUpdateType(enum.Enum): + """The type of update to perform on a stage.""" + + RETRY = "retry" + CANCEL = "cancel" + + +class ADOAsyncBuildClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Build APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("build")) + + async def queue_build( + self, + *, + project_id: str, + definition_id: int, + source_branch: str, + variables: dict[str, str], + requesting_identity: TeamFoundationId | None = None, + ) -> ADOResponse: + """Queue a new build. + + :param project_id: The ID of the project + :param definition_id: The identity of the build definition to queue (can be a string) + :param source_branch: The source branch for the build + :param variables: A dictionary of variables to pass to the definition + :param requesting_identity: The identity of the user who requested the build be queued + + :returns: The ADO response with the data in it + """ + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/build/builds?api-version=7.1" + ) + variable_json = json.dumps(variables) + + self.log.debug(f"Queueing build ({definition_id}): {variable_json}") + + body: dict[str, Any] = { + "parameters": variable_json, + "definition": {"id": definition_id}, + "sourceBranch": source_branch, + } + + if requesting_identity: + body["requestedFor"] = {"id": requesting_identity} + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def build_info(self, *, project_id: str, build_id: int) -> ADOResponse: + """Get the info for a build. + + :param project_id: The ID of the project + :param build_id: The identifier of the build to get the info for + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/builds/{build_id}?api-version=7.1" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_builds( + self, + *, + project_id: str, + definitions: list[int] | None = None, + order: BuildQueryOrder | None = None, + ) -> AsyncIterator[dict[str, Any]]: + """Get the info for a build. + + :param project_id: The ID of the project + :param definitions: An optional list of build definition IDs to filter on + :param order: The order of the builds to return + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(project_id=project_id) + "/build/builds/?" + + parameters = { + "api-version": "7.1", + } + + if definitions: + parameters["definitions"] = ",".join(map(str, definitions)) + + if order: + parameters["queryOrder"] = order.value + + request_url += urllib.parse.urlencode(parameters) + + url = request_url + + while True: + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + for item in decoded["value"]: + yield item + + continuation_token = response.headers.get( + "X-MS-ContinuationToken", response.headers.get("x-ms-continuationtoken") + ) + + if not continuation_token: + break + + url = request_url + f"&continuationToken={continuation_token}" + + async def list_artifacts(self, *, project_id: str, build_id: int) -> ADOResponse: + """List the artifacts for a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + :returns: The ADO response with the data in it + """ + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/builds/{build_id}/artifacts?api-version=7.2-preview.5" + ) + + self.log.debug(f"Fetching artifacts for build {build_id}...") + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_artifact_info( + self, *, project_id: str, build_id: int, artifact_name: str + ) -> ADOResponse: + """Fetch an artifacts details from a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + :param artifact_name: The name of the artifact to fetch + + :returns: The ADO response with the data in it + """ + + parameters = { + "artifactName": artifact_name, + "api-version": "7.1", + } + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/build/builds/{build_id}/artifacts?" + request_url += urllib.parse.urlencode(parameters) + + self.log.debug(f"Fetching artifact {artifact_name} from build {build_id}...") + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_build_timeline(self, *, project_id: str, build_id: int) -> ADOResponse: + """Get the build timeline. + + :param project_id: The ID of the project + :param build_id: The ID of the build + + :returns: The ADO response with the build timeline data in it + """ + self.log.debug(f"Fetching build timeline for build {build_id}...") + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/builds/{build_id}/timeline?api-version=7.1" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def download_artifact( + self, + *, + project_id: str, + build_id: int, + artifact_name: str, + output_path: str, + progress_callback: Callable[[int, int], None] | None = None, + ) -> None: + """Download an artifact from a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + :param artifact_name: The name of the artifact to fetch + :param output_path: The path to write the output to. + :param progress_callback: An optional callback to call with the number of bytes downloaded and total size + """ + + parameters = { + "artifactName": artifact_name, + "$format": "zip", + "api-version": "7.1", + } + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/build/builds/{build_id}/artifacts?" + request_url += urllib.parse.urlencode(parameters) + + self.log.debug(f"Fetching artifact {artifact_name} from build {build_id}...") + + # ADO redirects artifact downloads to a different domain. We follow redirects manually + # to enforce that only .visualstudio.com domains are accepted, preventing potential + # open-redirect attacks. + url = request_url + + while True: + async with self.http_client.stream_get( + url, follow_redirects=False, set_accept_json=False + ) as response: + if response.status_code < 300 or response.status_code >= 400: + await download_from_response_stream( + response=response, + output_path=output_path, + log=self.log, + callback=progress_callback, + ) + return + + location = response.headers.get("location") + + if not location: + # Read the body before raising so downstream code can inspect the response + await response.aread() + raise ADOHTTPException( + f"ADO returned a redirect status code without a location header, configuration={self}", + response, + ) + + parts = cast(SplitResult, urllib.parse.urlsplit(location)) + + if parts.hostname and not parts.hostname.endswith(".visualstudio.com"): + await response.aread() + raise ADOHTTPException( + "ADO returned a redirect status code with a location header that is not on visualstudio.com, " + + f"configuration={self}", + response, + ) + + url = location + + async def get_file_manifest( + self, + *, + project_id: str, + build_id: int, + artifact_name: str, + artifact_id: str, + ) -> ADOResponse: + """Download an artifact from a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + :param artifact_name: The name of the artifact to fetch + :param artifact_id: The ID of the artifact to download (artifact.resource.data) + """ + + parameters = { + "artifactName": artifact_name, + "fileId": artifact_id, + "fileName": "", + "api-version": "7.2-preview.5", + } + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/build/builds/{build_id}/artifacts?" + request_url += urllib.parse.urlencode(parameters) + + self.log.debug( + f"Fetching file manifest from artifact {artifact_name} from build {build_id}..." + ) + + response = await self.http_client.get(request_url) + + return self.http_client.decode_response(response) + + async def download_file( + self, + *, + project_id: str, + build_id: int, + artifact_name: str, + file_id: str, + file_name: str, + output_path: str, + progress_callback: Callable[[int, int], None] | None = None, + ) -> None: + """Download an artifact from a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + :param artifact_name: The name of the artifact to fetch + :param file_id: The ID of the file to download + :param file_name: The name of the file to download + :param output_path: The path to write the output to. + :param progress_callback: An optional callback to call with the number of bytes downloaded and total size + """ + + parameters = { + "artifactName": artifact_name, + "fileId": file_id, + "fileName": file_name, + "api-version": "7.2-preview.5", + } + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/build/builds/{build_id}/artifacts?" + request_url += urllib.parse.urlencode(parameters) + + self.log.debug( + f"Fetching file {file_name} from artifact {artifact_name} from build {build_id}..." + ) + + async with self.http_client.stream_get(request_url) as response: + await download_from_response_stream( + response=response, + output_path=output_path, + log=self.log, + callback=progress_callback, + ) + + async def get_leases(self, *, project_id: str, build_id: int) -> ADOResponse: + """Get the retention leases for a build. + + :param project_id: The ID of the project + :param build_id: The ID of the build + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/builds/{build_id}/leases?api-version=7.1-preview.1" + ) + + self.log.debug(f"Fetching leases for build {build_id}...") + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def delete_leases(self, *, project_id: str, lease_ids: int | list[int]) -> None: + """Delete leases. + + :param project_id: The ID of the project + :param lease_ids: The IDs of the leases to delete + """ + + if isinstance(lease_ids, int): + all_ids = [lease_ids] + else: + all_ids = lease_ids + + ids = ",".join(map(str, all_ids)) + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/retention/leases?api-version=7.1-preview.2&ids={ids}" + ) + + self.log.debug(f"Deleting leases '{ids}'...") + + response = await self.http_client.delete(request_url) + self.http_client.validate_response(response) + + async def get_definitions(self, *, project_id: str) -> ADOResponse: + """Get all definitions + + :param project_id: The ID of the project + + :returns: The project's definitions + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/build/definitions?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_definition(self, *, project_id: str, definition_id: int) -> ADOResponse: + """Get all definitions + + :param project_id: The ID of the project + :param definition_id: The identifier of the definition to get + + :returns: A definition + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/definitions/{definition_id}?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def delete_definition(self, *, project_id: str, definition_id: int) -> None: + """Delete a definition and all associated builds. + + :param project_id: The ID of the project + :param definition_id: The identifier of the definition to delete + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/definitions/{definition_id}?api-version=7.1-preview.7" + ) + + response = await self.http_client.delete(request_url) + self.http_client.validate_response(response) + + async def patch_stage( + self, + *, + project_id: str, + build_id: int, + stage_name: str, + force_retry_all_jobs: bool, + state: StageUpdateType, + ) -> None: + """Re-run the failed jobs on a build. + + :param project_id: The ID of the project + :param build_id: The identifier of the build to re-run the jobs on + :param stage_name: The name (identifier) of the stage to re-run. + :param force_retry_all_jobs: Whether to force retry all jobs in the stage, even if they succeeded + :param state: The type of update to perform on the stage + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/build/builds/{build_id}/stages/{stage_name}?api-version=7.1" + ) + + body: dict[str, Any] = {"forceRetryAllJobs": force_retry_all_jobs, "state": state.value} + + response = await self.http_client.patch(request_url, json_data=body) + self.http_client.validate_response(response) + + async def rerun_failed_jobs(self, *, project_id: str, build_id: int, stage_name: str) -> None: + """Re-run the failed jobs on a build. + + :param project_id: The ID of the project + :param build_id: The identifier of the build to re-run the jobs on + :param stage_name: The name (identifier) of the stage to re-run. + """ + + await self.patch_stage( + project_id=project_id, + build_id=build_id, + stage_name=stage_name, + force_retry_all_jobs=False, + state=StageUpdateType.RETRY, + ) diff --git a/simple_ado/_async/endpoints.py b/simple_ado/_async/endpoints.py new file mode 100644 index 0000000..bf332a4 --- /dev/null +++ b/simple_ado/_async/endpoints.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO service endpoints API wrapper (async).""" + +import logging +from typing import Any, AsyncIterator +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse + + +class ADOAsyncEndpointsClient(ADOAsyncBaseClient): + """Wrapper class around the ADO service endpoints APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("endpoints")) + + async def get_endpoints( + self, project_id: str, *, endpoint_type: str | None = None + ) -> ADOResponse: + """Gets the service endpoints. + + :param project_id: The identifier for the project + :param endpoint_type: The type to filter down to. + + :returns: The ADO response with the data in it + """ + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + "/serviceendpoint/endpoints?" + ) + + parameters = {"api-version": "7.1"} + + if endpoint_type: + parameters["type"] = endpoint_type + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_usage_history( + self, *, project_id: str, endpoint_id: str, top: int | None = None + ) -> AsyncIterator[dict[str, Any]]: + """Gets the usage history for an endpoint. + + :param project_id: The identifier for the project + :param endpoint_id: The endpoint to get the history for + :param top: If set, get this number of results + + :returns: The ADO response with the data in it + """ + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/serviceendpoint/{endpoint_id}/executionhistory?" + ) + + parameters: dict[str, Any] = {"api-version": "7.1"} + + if not top or top < 50: + parameters["top"] = top + else: + parameters["top"] = 50 + + request_url += urllib.parse.urlencode(parameters) + + url = request_url + + returned = 0 + + while True: + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + for use in decoded["value"]: + yield use + returned += 1 + + if top and returned >= top: + return + + if "X-MS-ContinuationToken" not in response.headers: + return + + continuation_token = response.headers["X-MS-ContinuationToken"] + url = request_url + f"&continuationToken={continuation_token}" diff --git a/simple_ado/_async/git.py b/simple_ado/_async/git.py new file mode 100644 index 0000000..0cfa306 --- /dev/null +++ b/simple_ado/_async/git.py @@ -0,0 +1,821 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO Git API wrapper (async).""" + +import enum +import logging +import os +from typing import Any, Callable, cast +import urllib.parse + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse +from simple_ado._async.utilities import download_from_response_stream +from simple_ado.exceptions import ADOException + + +class ADOGitStatusState(enum.Enum): + """Possible values of git status states.""" + + NOT_SET = "notSet" + NOT_APPLICABLE = "notApplicable" + PENDING = "pending" + SUCCEEDED = "succeeded" + FAILED = "failed" + ERROR = "error" + + +class ADOReferenceUpdate: + """Contains the relevant details about a reference update. + + :param name: The full name of the reference to update. e.g. refs/heads/my_branch + :param old_object_id: The ID that the reference previously pointed to + :param new_object_id: The ID that the reference should point to + """ + + name: str + old_object_id: str + new_object_id: str + + def __init__(self, name: str, old_object_id: str | None, new_object_id: str | None) -> None: + self.name = name + + if old_object_id: + self.old_object_id = old_object_id + else: + self.old_object_id = "0000000000000000000000000000000000000000" + + if new_object_id: + self.new_object_id = new_object_id + else: + self.new_object_id = "0000000000000000000000000000000000000000" + + def json_data(self) -> dict[str, str]: + """Return the JSON representation for sending to ADO. + + :returns: The JSON representation + """ + + return { + "name": self.name, + "oldObjectId": self.old_object_id, + "newObjectId": self.new_object_id, + } + + +class ADOAsyncGitClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Git APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("git")) + + async def all_repositories(self, project_id: str) -> ADOResponse: + """Get a list of repositories in the project. + + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + self.log.debug("Getting repositories") + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/?api-version=7.1" + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_repository(self, *, project_id: str, repository_id: str) -> ADOResponse: + """Get a repository from the project. + + :param project_id: The ID of the project + :param repository_id: The ID of the repository + + :returns: The ADO response with the data in it + """ + self.log.debug(f"Getting repository {repository_id}") + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/git/" + + f"repositories/{repository_id}?api-version=7.1" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_status(self, *, sha: str, project_id: str, repository_id: str) -> ADOResponse: + """Set a status on a PR. + + :param sha: The SHA of the commit to add the status to. + :param project_id: The ID of the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + + :raises ADOException: If the SHA is not the full version + """ + + self.log.debug(f"Getting status for sha: {sha}") + + if len(sha) != 40: + raise ADOException("The SHA for a commit must be the full 40 character version") + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/git/repositories/{repository_id}/commits/{sha}/statuses?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def set_status( + self, + *, + sha: str, + state: ADOGitStatusState, + identifier: str, + description: str, + project_id: str, + repository_id: str, + context: str, + target_url: str | None = None, + ) -> ADOResponse: + """Set a status on a PR. + + :param sha: The SHA of the commit to add the status to. + :param state: The state to set the status to. + :param identifier: A unique identifier for the status (so it can be changed later) + :param description: The text to show in the status + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param context: The context to use for build status notifications + :param target_url: An optional URL to set which is opened when the description is clicked. + + :returns: The ADO response with the data in it + + :raises ADOException: If the SHA is not the full version, or the state is set to NOT_SET + """ + + self.log.debug(f"Setting status ({state}) on sha ({sha}): {identifier} -> {description}") + + if len(sha) != 40: + raise ADOException("The SHA for a commit must be the full 40 character version") + + if state == ADOGitStatusState.NOT_SET: + raise ADOException("The NOT_SET state cannot be used for statuses on commits") + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/git/repositories/{repository_id}/commits/{sha}/" + ) + request_url += "statuses?api-version=7.1" + + body = { + "state": state.value, + "description": description, + "context": {"name": context, "genre": identifier}, + } + + if target_url is not None: + body["targetUrl"] = target_url + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def diff_between_commits( + self, + *, + base_commit: str, + target_commit: str, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Get the diff between two commits. + + :param base_commit: The full hash of the base commit to perform the diff against. + :param target_commit: The full hash of the commit to perform the diff of. + :param project_id: The ID of the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Fetching commit diff: {base_commit}..{target_commit}") + + base_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/git/repositories/{repository_id}/diffs/commits?" + ) + + changes: list[dict[str, Any]] = [] + skip = 0 + + while True: + parameters = { + "api-version": "7.1", + "baseVersionType": "commit", + "baseVersion": base_commit, + "targetVersionType": "commit", + "targetVersion": target_commit, + "$skip": skip, + "$top": 100, + } + + request_url = base_url + urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + data = cast(dict[str, Any], self.http_client.decode_response(response)) + + changes.extend(data["changes"]) + + if data.get("allChangesIncluded", False): + data["changes"] = changes + return data + + skip += 100 + + async def download_zip( + self, + *, + branch: str, + path: str = "/", + output_path: str, + project_id: str, + repository_id: str, + callback: Callable[[int, int], None] | None = None, + ) -> None: + """Download the zip of the branch specified. + + :param branch: The name of the branch to download. + :param path: The path in the repository to download. + :param output_path: The path to write the output to. + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param callback: The callback for download progress updates. First + parameter is bytes downloaded, second is total bytes. + The latter will be 0 if the content length is unknown. + + :raises ADOException: If the output path already exists + :raises ADOHTTPException: If we fail to fetch the zip for any reason + """ + + self.log.debug(f"Downloading branch: {branch}") + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}/Items?" + + parameters = { + "path": path, + "versionDescriptor[versionOptions]": "0", + "versionDescriptor[versionType]": "0", + "versionDescriptor[version]": branch, + "resolveLfs": "true", + "$format": "zip", + "api-version": "7.1", + } + + request_url += urllib.parse.urlencode(parameters) + + if os.path.exists(output_path): + raise ADOException("The output path already exists") + + async with self.http_client.stream_get(request_url) as response: + await download_from_response_stream( + response=response, + output_path=output_path, + log=self.log, + callback=callback, + ) + + # pylint: disable=too-many-locals + async def get_refs( + self, + *, + project_id: str, + repository_id: str, + filter_startswith: str | None = None, + filter_contains: str | None = None, + include_links: bool | None = None, + include_statuses: bool | None = None, + include_my_branches: bool | None = None, + latest_statuses_only: bool | None = None, + peel_tags: bool | None = None, + top: int | None = None, + continuation_token: str | None = None, + ) -> ADOResponse: + """Set a status on a PR. + + All non-specified options use the ADO default. + + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param filter_startswith: A filter to apply to the refs + (starts with) + :param filter_contains: A filter to apply to the refs + (contains) + :param include_links: Specifies if referenceLinks should + be included in the result + :param include_statuses: Includes up to the first 1000 + commit statuses for each ref + :param include_my_branches: Includes only branches that + the user owns, the branches + the user favorites, and the + default branch. Cannot be + combined with the filter + parameter. + :param latest_statuses_only: True to include only the tip + commit status for each ref. + This requires + `include_statuses` to be set + to `True`. + :param peel_tags: Annotated tags will populate the + `PeeledObjectId` property. + :param top: Maximum number of refs to return. It cannot be + bigger than 1000. If it is not provided, but + `continuation_token` is, top will default to + 100. + :param continuation_token: The continuation token used for + pagination + + :returns: The ADO response with the data in it + """ + + # pylint: disable=too-complex + + self.log.debug("Getting refs") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}/refs?" + + parameters: dict[str, Any] = {} + + if filter_startswith: + parameters["filter"] = filter_startswith + + if filter_contains: + parameters["filterContains"] = filter_contains + + if include_links: + parameters["includeLinks"] = "true" if include_links else "false" + + if include_statuses: + parameters["includeStatuses"] = "true" if include_statuses else "false" + + if include_my_branches: + parameters["includeMyBranches"] = "true" if include_my_branches else "false" + + if latest_statuses_only: + parameters["latestStatusesOnly"] = "true" if latest_statuses_only else "false" + + if peel_tags: + parameters["peelTags"] = "true" if peel_tags else "false" + + if top: + parameters["$top"] = top + + if continuation_token: + parameters["continuationToken"] = continuation_token + + request_url += urllib.parse.urlencode(parameters) + + if len(parameters) > 0: + request_url += "&" + + request_url += "api-version=7.1" + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + # pylint: enable=too-complex + + # pylint: enable=too-many-locals + + async def get_stats_for_branch( + self, *, project_id: str, repository_id: str, branch_name: str + ) -> ADOResponse: + """Get the stats for a branch. + + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param branch_name: The name of the branch to get the stats for + + :returns: The ADO response with the data in it + """ + + self.log.debug("Getting stats") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}" + request_url += f"/stats/branches?name={branch_name}&api-version=7.1" + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_commit( + self, + *, + commit_id: str, + project_id: str, + repository_id: str, + change_count: int | None = None, + ) -> ADOResponse: + """Set a status on a PR. + + All non-specified options use the ADO default. + + :param commit_id: The id of the commit + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param change_count: The number of changes to include in + the result + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Getting commit: {commit_id}") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}" + request_url += f"/commits/{commit_id}?api-version=7.1" + + if change_count: + request_url += f"&changeCount={change_count}" + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def update_refs( + self, + *, + updates: list[ADOReferenceUpdate], + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Update a list of references. + + :param updates: The list of updates to make + :param project_id: The ID of the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + self.log.debug("Updating references") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}" + request_url += "/refs?api-version=7.1" + + data = [update.json_data() for update in updates] + + response = await self.http_client.post(request_url, json_data=data) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def delete_branch( + self, branch_name: str, object_id: str, project_id: str, repository_id: str + ) -> ADOResponse: + """Delete a branch + + :param branch_name: The full name of the branch. e.g. refs/heads/my_branch + :param object_id: The ID of the object the branch currently points to + :param project_id: The ID of the project + :param repository_id: The ID for the repository + + :returns: The ADO response + """ + return await self.update_refs( + updates=[ADOReferenceUpdate(branch_name, object_id, None)], + project_id=project_id, + repository_id=repository_id, + ) + + class VersionControlRecursionType(enum.Enum): + """Specifies the level of recursion to use when getting an item.""" + + FULL = "full" + NONE = "none" + ONE_LEVEL = "oneLevel" + ONE_LEVEL_PLUS_NESTED_EMPTY_FOLDERS = "oneLevelPlusNestedEmptyFolders" + + class GitVersionOptions(enum.Enum): + """Version options.""" + + FIRST_PARENT = "firstParent" + NONE = "none" + PREVIOUS_CHANGE = "previousChange" + + class GitVersionType(enum.Enum): + """Version type. Determines how the ID of an item is interpreted.""" + + BRANCH = "branch" + COMMIT = "commit" + TAG = "tag" + + # pylint: disable=too-many-locals,too-complex + async def get_item( + self, + *, + project_id: str, + repository_id: str, + path: str | None = None, + scope_path: str | None = None, + recursion_level: VersionControlRecursionType | None = None, + include_content_metadata: bool | None = None, + latest_processed_changes: bool | None = None, + version_options: GitVersionOptions | None = None, + version: str | None = None, + version_type: GitVersionType | None = None, + include_content: bool | None = None, + resolve_lfs: bool | None = None, + ) -> ADOResponse: + """Get a git item. + + All non-specified options use the ADO default. + + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param path: The item path. Either this or scope_path must be set. + :param scope_path: The path scope. Either this or path must be set. + :param recursion_level: The recursion level + :param include_content_metadata: Set to include content metadata + :param latest_processed_changes: Set to include the latest changes + :param version_options: Specify additional modifiers to version + :param version: Version string identifier (name of tag/branch, SHA1 of commit) + :param version_type: Version type (branch, tag or commit). + :param include_content: Set to true to include item content when requesting JSON + :param resolve_lfs: Set to true to resolve LFS pointer files to resolve actual content + + :returns: The ADO response with the data in it + """ + + self.log.debug("Getting item") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}/items?" + + parameters: dict[str, Any] = {"api-version": "7.1", "$format": "json"} + + if not scope_path and not path: + raise ADOException("Either path or scope_path must be set") + + if scope_path and path: + raise ADOException("Either path or scope_path must be set, not both") + + if path is not None: + parameters["path"] = path + + if scope_path is not None: + parameters["scopePath"] = scope_path + + if recursion_level is not None: + parameters["recursionLevel"] = recursion_level.value + + if include_content_metadata is not None: + parameters["includeContentMetadata"] = "true" if include_content_metadata else "false" + + if latest_processed_changes is not None: + parameters["latestProcessedChange"] = "true" if latest_processed_changes else "false" + + if version_options is not None: + parameters["versionDescriptor.versionOptions"] = version_options.value + + if version is not None: + parameters["versionDescriptor.version"] = version + + if version_type is not None: + parameters["versionDescriptor.versionType"] = version_type.value + + if include_content is not None: + parameters["includeContent"] = "true" if include_content else "false" + + if resolve_lfs is not None: + parameters["resolveLfs"] = "true" if resolve_lfs else "false" + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + # pylint: enable=too-many-locals,too-complex + + # pylint: disable=too-many-locals,line-too-long,too-complex,too-many-branches + async def get_commits( + self, + *, + project_id: str, + repository_id: str, + skip: int | None = None, + top: int | None = None, + from_date: str | None = None, + to_date: str | None = None, + from_commit_id: str | None = None, + to_commit_id: str | None = None, + author: str | None = None, + user: str | None = None, + exclude_deletes: bool | None = None, + inlcude_links: bool | None = None, + include_push_data: bool | None = None, + include_user_image_url: bool | None = None, + include_work_items: bool | None = None, + item_path: str | None = None, + item_version: str | None = None, + item_version_options: GitVersionOptions | None = None, + item_version_type: GitVersionType | None = None, + ) -> ADOResponse: + """Retrieve git commits for a project + + All non-specified options use the ADO default. + + :param project_id: The ID of the project + :param repository_id: The ID for the repository + :param skip: Number of entries to skip + :param top: Maximum number of entries to retrieve + :param from_date: If provided, only include history entries created after this date + :param to_date: If provided, only include history entries created before this date + :param from_commit_id: If provided, a lower bound for filtering commits alphabetically + :param to_commit_id: If provided, an upper bound for filtering commits alphabetically + :param author: Alias or display name of the author + :param user: Alias or display name of the committer + :param exclude_deletes: If itemPath is specified, determines whether to exclude delete entries of the specified path. + :param inlcude_links: Whether to include the _links field on the shallow references + :param include_push_data: Whether to include the push information + :param include_user_image_url: Whether to include the image Url for committers and authors + :param include_work_items: Whether to include linked work items + :param item_path: Path of item to search under + :param item_version: Version string identifier (name of tag/branch, SHA1 of commit) + :param item_version_options: Version options - Specify additional modifiers to version (e.g Previous) + :param item_version_type: Version type (branch, tag, or commit). Determines how Id is interpreted + + :returns: The ADO response with the data in it + """ + self.log.debug("Getting commits") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/git/repositories/{repository_id}/commits?" + + parameters: dict[str, Any] = {"api-version": "7.2-preview.2"} + + if skip is not None: + parameters["$skip"] = skip + + if top is not None: + parameters["$top"] = top + + if from_date is not None: + parameters["fromDate"] = from_date + + if to_date is not None: + parameters["toDate"] = to_date + + if from_commit_id is not None: + parameters["fromCommitId"] = from_commit_id + + if to_commit_id is not None: + parameters["toCommitId"] = to_commit_id + + if author is not None: + parameters["author"] = author + + if user is not None: + parameters["user"] = user + + if exclude_deletes is not None: + parameters["excludeDeletes"] = exclude_deletes + + if inlcude_links is not None: + parameters["includeLinks"] = inlcude_links + + if include_push_data is not None: + parameters["includePushData"] = include_push_data + + if include_user_image_url is not None: + parameters["includeUserImageUrl"] = include_user_image_url + + if include_work_items is not None: + parameters["includeWorkItems"] = include_work_items + + if item_path is not None: + parameters["itemPath"] = item_path + + if item_version is not None: + parameters["itemVersion.version"] = item_version + + if item_version_options is not None: + parameters["itemVersion.versionOptions"] = item_version_options.value + + if item_version_type is not None: + parameters["itemVersion.versionType"] = item_version_type.value + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + # pylint: disable=too-many-locals,line-too-long,too-complex,too-many-branches + + class BlobFormat(enum.Enum): + """The type of format to get a blob in.""" + + JSON = "json" + ZIP = "zip" + TEXT = "text" + OCTETSTREAM = "octetstream" + + async def get_blob( + self, + *, + blob_id: str, + project_id: str, + repository_id: str, + blob_format: BlobFormat | None = None, + download: bool | None = None, + file_name: str | None = None, + resolve_lfs: bool | None = None, + ) -> Any: + """Get a git item. + + All non-specified options use the ADO default. + + :param blob_id: The SHA1 of the blob + :param project_id: The ID for the project + :param repository_id: The ID for the repository + :param blob_format: The format to get the blob in + :param download: Set to True to download rather than get a response + :param file_name: The file name to use for the download if download is set to True + :param resolve_lfs: Set to true to resolve LFS pointer files to resolve actual content + + :returns: The data returned and the return type depends on what you set blob_format to + """ + + self.log.debug("Getting blob") + + request_url = self.http_client.api_endpoint( + is_default_collection=False, project_id=project_id + ) + request_url += f"/git/repositories/{repository_id}/blobs/{blob_id}?" + + parameters: dict[str, Any] = { + "api-version": "7.1", + } + + if blob_format is not None: + parameters["$format"] = blob_format.value + + if download is not None: + parameters["download"] = "true" if download else "false" + + if file_name is not None: + parameters["fileName"] = file_name + + if resolve_lfs is not None: + parameters["resolveLfs"] = "true" if resolve_lfs else "false" + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + + if blob_format == ADOAsyncGitClient.BlobFormat.TEXT: + self.http_client.validate_response(response) + return response.text + + if blob_format == ADOAsyncGitClient.BlobFormat.JSON: + return self.http_client.decode_response(response) + + self.http_client.validate_response(response) + return response.content + + async def get_blobs( + self, + *, + blob_ids: list[str], + output_path: str, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Get a git item. + + All non-specified options use the ADO default. + + :param blob_ids: The SHA1s of the blobs + :param output_path: The location to write out the zip to + :param project_id: The ID for the project + :param repository_id: The ID for the repository + + :raises FileExistsError: If the output path already exists + """ + + self.log.debug("Getting blobs") + + request_url = self.http_client.api_endpoint( + is_default_collection=False, project_id=project_id + ) + request_url += f"/git/repositories/{repository_id}/blobs?api-version=7.1" + + if os.path.exists(output_path): + raise FileExistsError("The output path already exists") + + async with self.http_client.stream_post( + request_url, + additional_headers={"Accept": "application/zip"}, + json_data=blob_ids, + ) as response: + await download_from_response_stream( + response=response, output_path=output_path, log=self.log + ) diff --git a/simple_ado/_async/governance.py b/simple_ado/_async/governance.py new file mode 100644 index 0000000..9732abd --- /dev/null +++ b/simple_ado/_async/governance.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO governance API wrapper (async).""" + +import enum +import logging +from typing import Any, cast +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado.exceptions import ADOHTTPException +from simple_ado._async.http_client import ADOAsyncHTTPClient + + +class ADOAsyncGovernanceClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Governance APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + class AlertSeverity(enum.Enum): + """The potential alert severities.""" + + LOW = 0 + MEDIUM = 1 + HIGH = 2 + CRITICAL = 3 + + @classmethod + def _missing_(cls, value: Any) -> "ADOAsyncGovernanceClient.AlertSeverity": + if value == "low": + return cls(0) + + if value == "medium": + return cls(1) + + if value == "high": + return cls(2) + + if value == "critical": + return cls(3) + + return cast("ADOAsyncGovernanceClient.AlertSeverity", super()._missing_(value)) + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("governance")) + + async def get_governed_repositories(self, *, project_id: str) -> dict[str, Any]: + """Get all governed repositories for the project + + :param project_id: The ID of the project + + :returns: The governed repositories + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += "/ComponentGovernance/GovernedRepositories?api-version=6.1-preview.1" + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return cast(dict[str, Any], self.http_client.extract_value(response_data)) + + async def get_governed_repository( + self, *, governed_repository_id: str | int, project_id: str + ) -> dict[str, Any]: + """Get a particular governed repository. + + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: The governed repository details + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += f"/ComponentGovernance/GovernedRepositories/{governed_repository_id}?api-version=6.1-preview.1" + + response = await self.http_client.get(request_url) + return cast(dict[str, Any], self.http_client.decode_response(response)) + + async def delete_governed_repository( + self, *, governed_repository_id: str | int, project_id: str + ) -> None: + """Delete a governed repository. + + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += f"/ComponentGovernance/GovernedRepositories/{governed_repository_id}?api-version=6.1-preview.1" + + response = await self.http_client.delete(request_url) + self.http_client.validate_response(response) + + async def remove_policy( + self, + *, + policy_id: str, + governed_repository_id: str | int, + project_id: str, + ) -> None: + """Remove a policy from a repository. + + :param policy_id: The ID of the policy to remove + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :raises ADOHTTPException: If removing the policy failed + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += "/ComponentGovernance/GovernedRepositories" + request_url += f"/{governed_repository_id}/policyreferences" + request_url += f"/{policy_id}?api-version=5.1-preview.1" + + response = await self.http_client.delete(request_url) + + if not response.is_success: + raise ADOHTTPException( + f"Failed to remove policy {policy_id} from {governed_repository_id}", + response, + ) + + async def _set_alert_settings( + self, + *, + alert_settings: dict[str, Any], + governed_repository_id: str | int, + project_id: str, + ) -> None: + """Set alert settings for governance for a repository. + + :param alert_settings: The settings for the alert on the repo + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :raises ADOHTTPException: If setting the alert settings failed + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += "/ComponentGovernance/GovernedRepositories" + request_url += f"/{governed_repository_id}/AlertSettings" + request_url += "?api-version=5.0-preview.2" + + response = await self.http_client.put(request_url, alert_settings) + + if not response.is_success: + raise ADOHTTPException( + f"Failed to set alert settings on repo {governed_repository_id}", + response, + ) + + async def get_alert_settings( + self, *, governed_repository_id: str | int, project_id: str + ) -> dict[str, Any]: + """Get alert settings for governance for a repository. + + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: The settings for the alerts on the repo + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += "/ComponentGovernance/GovernedRepositories" + request_url += f"/{governed_repository_id}/AlertSettings" + request_url += "?api-version=5.0-preview.2" + + response = await self.http_client.get(request_url) + return cast(dict[str, Any], self.http_client.decode_response(response)) + + async def get_show_banner_in_repo_view( + self, *, governed_repository_id: str | int, project_id: str + ) -> bool: + """Get whether to show the banner in the repo view or not. + + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: True if the banner is shown in the repo view, False otherwise + """ + + current_settings = await self.get_alert_settings( + governed_repository_id=governed_repository_id, project_id=project_id + ) + + return cast(bool, current_settings["showRepositoryWarningBanner"]) + + async def set_show_banner_in_repo_view( + self, + *, + show_banner: bool, + governed_repository_id: str | int, + project_id: str, + ) -> None: + """Set whether to show the banner in the repo view or not. + + :param show_banner: Set to True to show the banner in the repo view, False to hide it + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + """ + + current_settings = await self.get_alert_settings( + governed_repository_id=governed_repository_id, project_id=project_id + ) + + current_settings["showRepositoryWarningBanner"] = show_banner + + await self._set_alert_settings( + alert_settings=current_settings, + governed_repository_id=governed_repository_id, + project_id=project_id, + ) + + async def get_minimum_alert_severity( + self, *, governed_repository_id: str | int, project_id: str + ) -> AlertSeverity: + """Get the minimum severity to alert for. + + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: The minimum alert severity + """ + + current_settings = await self.get_alert_settings( + governed_repository_id=governed_repository_id, project_id=project_id + ) + + return ADOAsyncGovernanceClient.AlertSeverity(current_settings["minimumAlertSeverity"]) + + async def set_minimum_alert_severity( + self, + *, + alert_severity: AlertSeverity, + governed_repository_id: str | int, + project_id: str, + ) -> None: + """Set the minimum severity to alert for. + + :param alert_severity: The minimum alert serverity to notify about + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + """ + + current_settings = await self.get_alert_settings( + governed_repository_id=governed_repository_id, project_id=project_id + ) + + current_settings["minimumAlertSeverity"] = alert_severity.value + + await self._set_alert_settings( + alert_settings=current_settings, + governed_repository_id=governed_repository_id, + project_id=project_id, + ) + + async def set_work_item_settings( + self, + *, + create_for_security_alerts: bool, + create_for_legal_alerts: bool, + area_path: str, + work_item_type: str, + extra_fields: list[tuple[str, str]] | None = None, + governed_repository_id: str | int, + project_id: str, + ) -> None: + """Set whether to show the banner in the repo view or not. + + :param create_for_security_alerts: Set to True to create work items for security alerts, False otherwise + :param create_for_legal_alerts: Set to True to create work items for legal alerts, False otherwise + :param area_path: The area path to open the tickets under + :param work_item_type: The type of work item to create (this must match one in your project) + :param extra_fields: An optional list of tuples of field IDs and values to set on the created work item + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + """ + + current_settings = await self.get_alert_settings( + governed_repository_id=governed_repository_id, project_id=project_id + ) + + current_settings["workItemSettings"]["areaPath"] = area_path + current_settings["workItemSettings"][ + "legalAlertWorkItemCreationEnabled" + ] = create_for_legal_alerts + current_settings["workItemSettings"][ + "securityAlertWorkItemCreationEnabled" + ] = create_for_security_alerts + current_settings["workItemSettings"]["workItemType"] = work_item_type + + if extra_fields: + current_settings["workItemSettings"]["workItemTemplateRows"] = [] + for field_id, value in extra_fields: + current_settings["workItemSettings"]["workItemTemplateRows"].append( + {"fieldId": field_id, "value": value} + ) + + await self._set_alert_settings( + alert_settings=current_settings, + governed_repository_id=governed_repository_id, + project_id=project_id, + ) + + async def get_branches( + self, + *, + tracked_only: bool = True, + governed_repository_id: str | int, + project_id: str, + ) -> dict[str, Any]: + """Get the branches for the goverened repository. + + Note: Due to lack of documentation, the pagination for this API is + unclear and therefore this call can only return the top 99,999 results. + If there are any more than this, the call will return the top 99,999 and + exit. + + :param tracked_only: Set to True if only tracked branches should be returned (default), False otherwise + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: The settings for the alerts on the repo + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += ( + f"/ComponentGovernance/GovernedRepositories/{governed_repository_id}/Branches?" + ) + + parameters: dict[str, Any] = {"top": 99999, "isTracked": tracked_only} + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return cast(dict[str, Any], self.http_client.extract_value(response_data)) + + async def get_alerts( + self, + *, + branch_name: str, + include_history: bool = False, + include_development_dependencies: bool = True, + governed_repository_id: str | int, + project_id: str, + ) -> dict[str, Any]: + """Get the alerts on a given branch. + + :param branch_name: The branch to get the alerts for + :param include_history: It isn't clear what this parameter does. Defaults to False + :param include_development_dependencies: Set to True to include alerts on development + dependencies, False otherwise (defaults to True) + :param governed_repository_id: The repository governance ID + :param project_id: The ID of the project + + :returns: The settings for the alerts on the repo + """ + + request_url = self.http_client.api_endpoint( + is_default_collection=False, subdomain="governance", project_id=project_id + ) + request_url += ( + f"/ComponentGovernance/GovernedRepositories/{governed_repository_id}" + + f"/Branches/{branch_name}/Alerts?" + ) + + parameters = { + "includeHistory": include_history, + "includeDevelopmentDependencies": include_development_dependencies, + } + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return cast(dict[str, Any], self.http_client.extract_value(response_data)) diff --git a/simple_ado/_async/graph.py b/simple_ado/_async/graph.py new file mode 100644 index 0000000..0090d0f --- /dev/null +++ b/simple_ado/_async/graph.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO graph API wrapper (async).""" + +import logging +from typing import Any, AsyncIterator, cast + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse + + +class ADOAsyncGraphClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Graph APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("graph")) + + async def get_scope_descriptors(self, storage_key: str) -> ADOResponse: + """Get the scope descriptors for a given subject. + + :param storage_key: Storage key (UUID) of the subject (user, group, scope, etc.) to resolve + + :returns: The scope descriptors + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/descriptors/{storage_key}" + request_url += "/?api-version=7.1-preview.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_storage_key(self, subject_descriptor: str) -> ADOResponse: + """Get the storage key for a given subject descriptor. + + :param subject_descriptor: Descriptor of the subject to resolve + + :returns: The storage key + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/storagekeys/{subject_descriptor}" + request_url += "/?api-version=7.1-preview.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def lookup_subjects(self, subject_descriptors: list[str]) -> list[Any]: + """Lookup the various subject descriptors. + + :param subject_descriptors: Descriptors of the subjects to resolve + + :returns: The lookup keys + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/subjectlookup" + request_url += "/?api-version=7.1-preview.1" + + data = {"lookupKeys": [{"descriptor": descriptor} for descriptor in subject_descriptors]} + + response = await self.http_client.post(request_url, json_data=data) + response_data = self.http_client.decode_response(response) + return cast(list[Any], self.http_client.extract_value(response_data)) + + async def list_groups(self, *, scope_descriptor: str | None = None) -> list[Any]: + """Get the groups in the organization. + + :param scope_descriptor: Specify a non-default scope (collection, project) to search for groups. + + :returns: The ADO response with the data in it + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/groups?api-version=7.1-preview.1" + + if scope_descriptor: + request_url += f"&scopeDescriptor={scope_descriptor}" + + groups: list[Any] = [] + continuation_token = None + + while True: + if continuation_token: + url = request_url + f"&continuationToken={continuation_token}" + else: + url = request_url + + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + groups += decoded["value"] + + if "X-MS-ContinuationToken" not in response.headers: + break + + continuation_token = response.headers["X-MS-ContinuationToken"] + + return groups + + async def get_group(self, descriptor: str) -> ADOResponse: + """Get the group + + :param descriptor: The descriptor for the group + + :returns: The ADO response with the data in it + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/groups" + request_url += f"/{descriptor}?api-version=7.1-preview.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_user(self, descriptor: str) -> ADOResponse: + """Get the user + + :param descriptor: The descriptor for the user + + :returns: The ADO response with the data in it + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/users" + request_url += f"/{descriptor}?api-version=7.1-preview.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_users(self) -> ADOResponse: + """Get the users + + :returns: The ADO response with the data in it + """ + + request_url = f"{self.http_client.graph_endpoint()}/graph/users" + request_url += "?api-version=7.1-preview.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def list_users_in_container(self, scope_descriptor: str) -> AsyncIterator[Any]: + """Get users in the container. + + :param scope_descriptor: Specify a non-default scope (collection, project) to search for users. + + :returns: The ADO response with the data in it + """ + + request_url = ( + f"{self.http_client.graph_endpoint()}/graph/Memberships/" + + f"{scope_descriptor}?api-version=7.1-preview.1&direction=down" + ) + + continuation_token = None + + while True: + if continuation_token: + url = request_url + f"&continuationToken={continuation_token}" + else: + url = request_url + + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + for item in decoded["value"]: + yield item + + if "X-MS-ContinuationToken" not in response.headers: + break + + continuation_token = response.headers["X-MS-ContinuationToken"] diff --git a/simple_ado/_async/http_client.py b/simple_ado/_async/http_client.py new file mode 100644 index 0000000..6384c16 --- /dev/null +++ b/simple_ado/_async/http_client.py @@ -0,0 +1,622 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO HTTP API wrapper (async).""" + +import asyncio +import contextlib +import datetime +import logging +import os +from typing import Any, AsyncIterator, TypeAlias + +import httpx +from tenacity import ( + retry, + retry_if_exception, + stop_after_attempt, + wait_random_exponential, +) + +from simple_ado._async.auth.ado_auth import ADOAsyncAuth +from simple_ado.exceptions import ADOException, ADOHTTPException +from simple_ado.models import PatchOperation + + +# pylint: disable=invalid-name +ADOThread = dict[str, Any] +ADOResponse: TypeAlias = Any +# pylint: enable=invalid-name + + +# Only retry status codes where a subsequent attempt may succeed: +# 400 Bad Request — ADO occasionally returns transient 400s under load +# 408 Request Timeout — server didn't receive the full request in time +# 429 Too Many Requests — rate-limited, back off and retry +# 500, 502, 503, 504 — server errors that are typically transient +# Previous versions retried all 4xx, but codes like 401/403/404 are +# deterministic failures that won't resolve on retry. +_RETRYABLE_STATUS_CODES = {400, 408, 429, 500, 502, 503, 504} + + +def _is_retryable_get_failure(exception: Exception) -> bool: + if not isinstance(exception, ADOHTTPException): + return False + + return exception.response.status_code in _RETRYABLE_STATUS_CODES + + +def _is_connection_failure(exception: Exception) -> bool: + if isinstance(exception, (httpx.ConnectError, httpx.TimeoutException)): + return True + + exception_checks = [ + "Operation timed out", + "Connection aborted.", + "bad handshake: ", + "Failed to establish a new connection", + ] + + for check in exception_checks: + if check in str(exception): + return True + + return False + + +class ADOAsyncHTTPClient: + """Base class that actually makes API calls to Azure DevOps (async). + + :param tenant: The name of the ADO tenant to connect to + :param extra_headers: Any extra headers which should be added to each request + :param user_agent: The user agent to set + :param auth: The authentication details + :param log: The logger to use for logging + """ + + log: logging.Logger + tenant: str + extra_headers: dict[str, str] + auth: ADOAsyncAuth + _not_before: datetime.datetime | None + _client: httpx.AsyncClient + + def __init__( + self, + *, + tenant: str, + auth: ADOAsyncAuth, + user_agent: str, + log: logging.Logger, + extra_headers: dict[str, str] | None = None, + ) -> None: + """Construct a new client object.""" + + self.log = log.getChild("http") + + self.tenant = tenant + self.auth = auth + self._not_before = None + + self._client = httpx.AsyncClient( + headers={"User-Agent": f"simple_ado/{user_agent}"}, + follow_redirects=True, + timeout=httpx.Timeout(300.0), + ) + + if extra_headers is None: + self.extra_headers = {} + else: + self.extra_headers = extra_headers + + async def close(self) -> None: + """Close the underlying HTTP client and auth resources.""" + await self._client.aclose() + await self.auth.close() + + async def __aenter__(self) -> "ADOAsyncHTTPClient": + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + def graph_endpoint(self) -> str: + """Generate the base url for all graph API calls (this varies depending on the API). + + :returns: The constructed graph URL + """ + return f"https://vssps.dev.azure.com/{self.tenant}/_apis" + + def audit_endpoint(self) -> str: + """Generate the base url for all audit API calls. + + :returns: The constructed graph URL + """ + return f"https://auditservice.dev.azure.com/{self.tenant}/_apis" + + def api_endpoint( + self, + *, + is_default_collection: bool = True, + is_internal: bool = False, + subdomain: str | None = None, + project_id: str | None = None, + ) -> str: + """Generate the base url for all API calls (this varies depending on the API). + + :param is_default_collection: Whether this URL should start with the path "/DefaultCollection" + :param is_internal: Whether this URL should use internal API endpoint "/_api" + :param subdomain: A subdomain that should be used (if any) + :param project_id: The project ID. This will be added if supplied + + :returns: The constructed base URL + """ + + url = f"https://{self.tenant}." + + if subdomain: + url += subdomain + "." + + url += "visualstudio.com" + + if is_default_collection: + url += "/DefaultCollection" + + if project_id: + url += f"/{project_id}" + + if is_internal: + url += "/_api" + else: + url += "/_apis" + + return url + + async def _wait(self) -> None: + """Wait as long as we need for rate limiting purposes.""" + if not self._not_before: + return + + remaining = self._not_before - datetime.datetime.now() + + if remaining.total_seconds() < 0: + self._not_before = None + return + + self.log.debug(f"Sleeping for {remaining} seconds before issuing next request") + await asyncio.sleep(remaining.total_seconds()) + + def _track_rate_limit(self, response: httpx.Response) -> None: + """Track the rate limit info from a request. + + :param response: The response to track the info from. + """ + + if "Retry-After" in response.headers: + # We get massive windows for retry after, so we wait 10 seconds or + # the duration, whichever is smaller. If we get a 429, we'll increase. + self._not_before = datetime.datetime.now() + datetime.timedelta( + seconds=min(15, int(response.headers["Retry-After"])) + ) + return + + # Slow down if needed + if int(response.headers.get("X-RateLimit-Remaining", 100)) < 10: + self._not_before = datetime.datetime.now() + datetime.timedelta(seconds=1) + return + + # No limit, so go at full speed + self._not_before = None + + @retry( + retry=( + retry_if_exception(_is_connection_failure) # type: ignore + | retry_if_exception(_is_retryable_get_failure) # type: ignore + ), + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def get( + self, + request_url: str, + *, + additional_headers: dict[str, str] | None = None, + stream: bool = False, + follow_redirects: bool = True, + set_accept_json: bool = True, + ) -> httpx.Response: + """Issue a GET request with the correct headers. + + When stream=True, the response body is not immediately loaded. The caller + must use the response as a context manager or call response.close() when done. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param stream: Set to True to stream the response back + :param follow_redirects: Set to False to disable redirects + :param set_accept_json: Set to False to disable setting the Accept header + + :returns: The raw response object from the API + """ + + await self._wait() + + headers = await self.construct_headers( + additional_headers=additional_headers, set_accept_json=set_accept_json + ) + + if stream: + request = self._client.build_request( + "GET", + request_url, + headers=headers, + ) + response = await self._client.send( + request, + stream=True, + follow_redirects=follow_redirects, + ) + else: + response = await self._client.get( + request_url, + headers=headers, + follow_redirects=follow_redirects, + ) + + self._track_rate_limit(response) + + return response + + @contextlib.asynccontextmanager + async def stream_get( + self, + request_url: str, + *, + additional_headers: dict[str, str] | None = None, + follow_redirects: bool = True, + set_accept_json: bool = True, + ) -> AsyncIterator[httpx.Response]: + """Issue a streaming GET request. Must be used as an async context manager. + + This is a convenience wrapper that handles cleanup automatically. + Prefer this over get(stream=True) when possible. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param follow_redirects: Set to False to disable redirects + :param set_accept_json: Set to False to disable setting the Accept header + + :yields: The raw response object from the API + """ + await self._wait() + + headers = await self.construct_headers( + additional_headers=additional_headers, set_accept_json=set_accept_json + ) + + async with self._client.stream( + "GET", + request_url, + headers=headers, + follow_redirects=follow_redirects, + ) as response: + self._track_rate_limit(response) + yield response + + @retry( + retry=retry_if_exception(_is_connection_failure), # type: ignore + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def post( + self, + request_url: str, + *, + operations: list[PatchOperation] | None = None, + additional_headers: dict[str, str] | None = None, + json_data: Any | None = None, + stream: bool = False, + ) -> httpx.Response: + """Issue a POST request with the correct headers. + + Note: If `json_data` and `operations` are not None, the latter will take + precedence. + + When stream=True, the response body is not immediately loaded. The caller + must use the response as a context manager or call response.close() when done. + + :param request_url: The URL to issue the request to + :param operations: The patch operations to send with the request + :param additional_headers: Any additional headers to add to the request + :param json_data: The JSON data to send with the request + :param stream: Set to True to stream the response back + + :returns: The raw response object from the API + """ + + await self._wait() + + if operations is not None: + json_data = [operation.serialize() for operation in operations] + if additional_headers is None: + additional_headers = {} + if "Content-Type" not in additional_headers: + additional_headers["Content-Type"] = "application/json-patch+json" + + headers = await self.construct_headers(additional_headers=additional_headers) + + if stream: + request = self._client.build_request( + "POST", + request_url, + headers=headers, + json=json_data, + ) + response = await self._client.send(request, stream=True) + else: + response = await self._client.post( + request_url, + headers=headers, + json=json_data, + ) + + self._track_rate_limit(response) + + return response + + @contextlib.asynccontextmanager + async def stream_post( + self, + request_url: str, + *, + additional_headers: dict[str, str] | None = None, + json_data: Any | None = None, + ) -> AsyncIterator[httpx.Response]: + """Issue a streaming POST request. Must be used as an async context manager. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param json_data: The JSON data to send with the request + + :yields: The raw response object from the API + """ + await self._wait() + + headers = await self.construct_headers(additional_headers=additional_headers) + + async with self._client.stream( + "POST", + request_url, + headers=headers, + json=json_data, + ) as response: + self._track_rate_limit(response) + yield response + + @retry( + retry=retry_if_exception(_is_connection_failure), # type: ignore + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def patch( + self, + request_url: str, + *, + operations: list[PatchOperation] | None = None, + json_data: Any | None = None, + additional_headers: dict[str, Any] | None = None, + ) -> httpx.Response: + """Issue a PATCH request with the correct headers. + + Note: If `json_data` and `operations` are not None, the latter will take + precedence. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param json_data: The JSON data to send with the request + :param operations: The patch operations to send with the request + + :returns: The raw response object from the API + """ + + await self._wait() + + if operations is not None: + json_data = [operation.serialize() for operation in operations] + if additional_headers is None: + additional_headers = {} + if "Content-Type" not in additional_headers: + additional_headers["Content-Type"] = "application/json-patch+json" + + headers = await self.construct_headers(additional_headers=additional_headers) + response = await self._client.patch(request_url, headers=headers, json=json_data) + self._track_rate_limit(response) + return response + + @retry( + retry=retry_if_exception(_is_connection_failure), # type: ignore + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def put( + self, + request_url: str, + json_data: Any | None = None, + *, + additional_headers: dict[str, Any] | None = None, + ) -> httpx.Response: + """Issue a PUT request with the correct headers. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param json_data: The JSON data to send with the request + + :returns: The raw response object from the API + """ + await self._wait() + + headers = await self.construct_headers(additional_headers=additional_headers) + response = await self._client.put(request_url, headers=headers, json=json_data) + self._track_rate_limit(response) + return response + + @retry( + retry=retry_if_exception(_is_connection_failure), # type: ignore + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def delete( + self, request_url: str, *, additional_headers: dict[str, Any] | None = None + ) -> httpx.Response: + """Issue a DELETE request with the correct headers. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + + :returns: The raw response object from the API + """ + await self._wait() + + headers = await self.construct_headers(additional_headers=additional_headers) + response = await self._client.delete(request_url, headers=headers) + self._track_rate_limit(response) + return response + + @retry( + retry=retry_if_exception(_is_connection_failure), # type: ignore + wait=wait_random_exponential(max=10), + stop=stop_after_attempt(5), + ) + async def post_file( + self, + request_url: str, + file_path: str, + *, + additional_headers: dict[str, Any] | None = None, + ) -> httpx.Response: + """POST a file to the URL with the given file name. + + :param request_url: The URL to issue the request to + :param file_path: The path to the file to be posted + :param additional_headers: Any additional headers to add to the request + + :returns: The raw response object from the API""" + + await self._wait() + + file_size = os.path.getsize(file_path) + + headers = await self.construct_headers(additional_headers=additional_headers) + headers["Content-Length"] = str(file_size) + headers["Content-Type"] = "application/json" + + content = await asyncio.to_thread(self._read_file, file_path) + + response = await self._client.post( + request_url, + headers=headers, + content=content, + ) + self._track_rate_limit(response) + return response + + @staticmethod + def _read_file(file_path: str) -> bytes: + """Read a file's contents as bytes. + + :param file_path: The path to the file to read + :returns: The file contents + """ + with open(file_path, "rb") as file_handle: + return file_handle.read() + + def validate_response(self, response: httpx.Response) -> None: + """Checking a response for errors. + + :param response: The response to check + + :raises ADOHTTPException: Raised if the request returned a non-200 status code + :raises ADOException: Raise if the response was not JSON + """ + + self.log.debug("Validating response from ADO") + + if not response.is_success: + raise ADOHTTPException( + f"ADO returned a non-200 status code, configuration={self}", + response, + ) + + def decode_response(self, response: httpx.Response) -> ADOResponse: + """Decode the response from ADO, checking for errors. + + :param response: The response to check and parse + + :returns: The JSON data from the ADO response + + :raises ADOHTTPException: Raised if the request returned a non-200 status code + :raises ADOException: Raise if the response was not JSON + """ + + self.validate_response(response) + + self.log.debug("Decoding response from ADO") + + try: + content: ADOResponse = response.json() + except Exception as ex: + raise ADOException("The response did not contain JSON") from ex + + return content + + def extract_value(self, response_data: ADOResponse) -> ADOResponse: + """Extract the "value" from the raw JSON data from an API response + + :param response_data: The raw JSON data from an API response + + :returns: The ADO response with the data in it + + :raises ADOException: If the response is invalid (does not support value extraction) + """ + + self.log.debug("Extracting value") + + try: + value: ADOResponse = response_data["value"] + return value + except Exception as ex: + raise ADOException("The response was invalid (did not contain a value).") from ex + + async def construct_headers( + self, + *, + additional_headers: dict[str, str] | None = None, + set_accept_json: bool = True, + ) -> dict[str, str]: + """Contruct the headers used for a request, adding anything additional. + + :param additional_headers: A dictionary of the additional headers to add. + :param set_accept_json: Set to False to disable setting the Accept header + + :returns: A dictionary of the headers for a request + """ + + headers: dict[str, str] = {} + + if set_accept_json: + headers["Accept"] = "application/json" + + headers["Authorization"] = await self.auth.get_authorization_header() + + for header_name, header_value in self.extra_headers.items(): + headers[header_name] = header_value + + if additional_headers is None: + return headers + + for header_name, header_value in additional_headers.items(): + headers[header_name] = header_value + + return headers diff --git a/simple_ado/_async/identities.py b/simple_ado/_async/identities.py new file mode 100644 index 0000000..dd5dfd3 --- /dev/null +++ b/simple_ado/_async/identities.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO identities API wrapper (async).""" + +import logging +from typing import Any, cast + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado.exceptions import ADOException +from simple_ado._async.http_client import ADOAsyncHTTPClient +from simple_ado.ado_types import TeamFoundationId + + +class ADOAsyncIdentitiesClient(ADOAsyncBaseClient): + """Wrapper class around the ADO identities APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("user")) + + async def search(self, identity: str) -> list[dict[str, Any]]: + """Fetch the unique Team Foundation GUID for a given identity. + + :param identity: The identity to fetch for (should be email for users and display name for groups) + + :returns: The found identities + + :raises ADOException: If we can't get the identity from the response + """ + + request_url = self.http_client.graph_endpoint() + request_url += f"/identities?searchFilter=General&filterValue={identity}" + request_url += "&queryMembership=None&api-version=7.1" + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return cast(list[dict[str, Any]], self.http_client.extract_value(response_data)) + + async def get_team_foundation_id(self, identity: str) -> TeamFoundationId: + """Fetch the unique Team Foundation GUID for a given identity. + + :param identity: The identity to fetch for (should be email for users and display name for groups) + + :returns: The team foundation ID + + :raises ADOException: If we can't get the identity from the response + """ + + results = await self.search(identity) + + if len(results) == 0: + raise ADOException("Could not resolve identity: " + identity) + + if len(results) > 1: + raise ADOException(f"Found multiple identities matching '{identity}'") + + result = results[0] + + return cast(TeamFoundationId, result["id"]) diff --git a/simple_ado/_async/pipelines.py b/simple_ado/_async/pipelines.py new file mode 100644 index 0000000..77b7486 --- /dev/null +++ b/simple_ado/_async/pipelines.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO pipeline API wrapper (async).""" + +import logging +from typing import Any, AsyncIterator, cast +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse + + +class ADOAsyncPipelineClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Pipeline APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("pipeline")) + + async def get_pipelines( + self, + *, + top: int | None = None, + order_by: str | None = None, + project_id: str, + ) -> AsyncIterator[dict[str, Any]]: + """Get all the pipelines in the project. + + Note: This hasn't been tested with continuation tokens. + + :param top: An optional integer to only get the top N pipelines + :param order_by: A sort expression to use (defaults to "name asc") + :param project_id: The ID of the project + + :returns: The pipelines in the project + """ + + parameters: dict[str, Any] = {"api-version": "7.1-preview.1"} + + if top: + parameters["$top"] = top + + if order_by: + parameters["orderBy"] = order_by + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/pipelines?" + request_url += urllib.parse.urlencode(parameters) + + url = request_url + + while True: + response = await self.http_client.get(url) + decoded = self.http_client.decode_response(response) + for item in decoded["value"]: + yield item + + if not decoded.get("hasMore"): + return + + continuation_token = decoded["continuationToken"] + url = request_url + f"&continuationToken={continuation_token}" + + async def get_pipeline( + self, + *, + project_id: str, + pipeline_id: int, + pipeline_version: int | None = None, + ) -> ADOResponse: + """Get the info for a pipeline. + + :param project_id: The ID of the project + :param pipeline_id: The identifier of the pipeline to get the info for + :param pipeline_version: The version of the pipeline to get the info for + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/pipelines/{pipeline_id}?api-version=7.1-preview.1" + ) + + if pipeline_version: + request_url += f"pipelineVersion={pipeline_version}" + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def preview( + self, + *, + project_id: str, + pipeline_id: int, + pipeline_version: int | None = None, + ) -> str | None: + """Queue a dry run of the pipeline to return the final yaml. + + :param project_id: The ID of the project + :param pipeline_id: The identifier of the pipeline to get the info for + :param pipeline_version: The version of the pipeline to get the info for + + :returns: The raw YAML generated after parsing the templates (None if it is not a YAML pipeline) + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/pipelines/{pipeline_id}/preview?api-version=7.1-preview.1" + ) + + if pipeline_version: + request_url += f"pipelineVersion={pipeline_version}" + + body = { + "previewRun": True, + } + + response = await self.http_client.post(request_url, json_data=body) + data = self.http_client.decode_response(response) + return cast(str | None, data.get("finalYaml")) + + async def get_top_ten_thousand_runs(self, *, project_id: str, pipeline_id: int) -> ADOResponse: + """Get the top 10,000 runs for a pipeline. + + :param project_id: The ID of the project + :param pipeline_id: The identifier of the pipeline to get the runs for + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/pipelines/{pipeline_id}/runs?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_run(self, *, project_id: str, pipeline_id: int, run_id: int) -> ADOResponse: + """Get a pipeline run. + + :param project_id: The ID of the project + :param pipeline_id: The identifier of the pipeline to get the run for + :param run_id: The identifier of the run to get + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/pipelines/{pipeline_id}/runs/{run_id}?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def run_pipeline( + self, + *, + project_id: str, + pipeline_id: int, + pipeline_version: int | None = None, + preview_run: bool = False, + resources: dict[str, Any] | None = None, + stages_to_skip: list[str] | None = None, + template_parameters: dict[str, Any] | None = None, + variables: dict[str, Any] | None = None, + ) -> ADOResponse: + """Run a pipeline. + + :param project_id: The ID of the project + :param pipeline_id: The identifier of the pipeline to run + :param pipeline_version: The version of the pipeline to run (leave unset to run latest) + :param preview_run: If True, this will run a dry run of the pipeline to return the final yaml + :param resources: The resources to use for the run. See + https://learn.microsoft.com/en-us/rest/api/azure/devops/pipelines/runs/run-pipeline?view=azure-devops-rest-6.1#runresourcesparameters + :param stages_to_skip: A list of stages to skip if any + :param template_parameters: The template parameters to use for the run. A map of keys to values. + :param variables: The variables to use for the run. A map of strings to `Variable`. See: + https://learn.microsoft.com/en-us/rest/api/azure/devops/pipelines/runs/run-pipeline?view=azure-devops-rest-6.1#variable + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/pipelines/{pipeline_id}/runs?api-version=7.1" + ) + + if pipeline_version: + request_url += f"pipelineVersion={pipeline_version}" + + body: dict[str, Any] = {"previewRun": preview_run} + + if resources: + body["resources"] = resources + + if stages_to_skip: + body["stagesToSkip"] = stages_to_skip + + if template_parameters: + body["templateParameters"] = template_parameters + + if variables: + body["variables"] = variables + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) diff --git a/simple_ado/_async/pools.py b/simple_ado/_async/pools.py new file mode 100644 index 0000000..334b928 --- /dev/null +++ b/simple_ado/_async/pools.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO pools API wrapper (async).""" + +import enum +import logging +from typing import Any +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse + + +class TaskAgentPoolActionFilter(enum.Enum): + """Represents an agent pool action filter.""" + + MANAGE = "manage" + NONE = "none" + USE = "use" + + +class ADOAsyncPoolsClient(ADOAsyncBaseClient): + """Wrapper class around the undocumented ADO pools APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("pools")) + + async def get_pools( + self, + *, + pool_name: str | None = None, + action_filter: TaskAgentPoolActionFilter | None = None, + ) -> ADOResponse: + """Gets the agent details. + + :param pool_name: The name of the pool to match on + :param action_filter: Set to filter on the type of pools + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(is_default_collection=False) + request_url += "/distributedtask/pools?" + + parameters = {"api-version": "7.1"} + + if pool_name: + parameters["poolName"] = pool_name + + if action_filter: + parameters["actionFilter"] = action_filter.value + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_agents( + self, + *, + pool_id: int, + agent_name: str | None = None, + include_capabilities: bool = True, + include_assigned_request: bool = True, + include_last_completed_request: bool = True, + ) -> ADOResponse: + """Gets the agents details. + + :param pool_id: The ID of the pool the agents are in + :param agent_name: The name of the agent to match on + :param include_capabilities: Set to False to not include capabilities + :param include_assigned_request: Set to False to not include the current assigned request + :param include_last_completed_request: Set to False to not include the last completed request + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(is_default_collection=False) + request_url += f"/distributedtask/pools/{pool_id}/agents?" + + parameters = { + "includeCapabilities": include_capabilities, + "includeAssignedRequest": include_assigned_request, + "includeLastCompletedRequest": include_last_completed_request, + "api-version": "7.1", + } + + if agent_name: + parameters["agentName"] = agent_name + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_agent(self, *, pool_id: int, agent_id: int) -> ADOResponse: + """Gets the agent details. + + :param pool_id: The ID of the pool the agent is in + :param agent_id: The ID of the agent to get + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(is_default_collection=False) + request_url += f"/distributedtask/pools/{pool_id}/agents/{agent_id}?api-version=7.1" + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def update_agent( + self, *, pool_id: int, agent_id: int, agent_data: dict[str, Any] + ) -> ADOResponse: + """Adds required reviewers when opening PRs against a given branch. + + :param pool_id: The ID of the pool the agent is in + :param agent_id: The ID of the agent to disable + :param agent_data: The data to set on the agent + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(is_default_collection=False) + request_url += f"/distributedtask/pools/{pool_id}/agents/{agent_id}?api-version=7.1" + + response = await self.http_client.patch(request_url, json_data=agent_data) + return self.http_client.decode_response(response) + + async def set_agent_state(self, *, pool_id: int, agent_id: int, enabled: bool) -> ADOResponse: + """Set the enabled/disabled state of an agent. + + :param pool_id: The ID of the pool the agent is in + :param agent_id: The ID of the agent to disable + :param enabled: Set to True to enable an agent, False to disable it + + :returns: The ADO response with the data in it + """ + + agent_details = await self.get_agent(pool_id=pool_id, agent_id=agent_id) + agent_details["enabled"] = enabled + return await self.update_agent(pool_id=pool_id, agent_id=agent_id, agent_data=agent_details) + + async def update_agent_capabilities( + self, *, pool_id: int, agent_id: int, capabilities: dict[str, str] + ) -> ADOResponse: + """Set the enabled/disabled state of an agent. + + :param pool_id: The ID of the pool the agent is in + :param agent_id: The ID of the agent to disable + :param capabilities: The new capabilities to set + + :returns: The ADO response with the data in it + """ + + request_url = self.http_client.api_endpoint(is_default_collection=False) + request_url += ( + f"/distributedtask/pools/{pool_id}/agents/{agent_id}/usercapabilities?api-version=7.1" + ) + + response = await self.http_client.put(request_url, json_data=capabilities) + return self.http_client.decode_response(response) diff --git a/simple_ado/_async/pull_requests.py b/simple_ado/_async/pull_requests.py new file mode 100644 index 0000000..095f81a --- /dev/null +++ b/simple_ado/_async/pull_requests.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO Pull Request API wrapper (async).""" + +import enum +import logging +from typing import Any + +import deserialize + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.git import ADOGitStatusState +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse, ADOThread +from simple_ado.comments import ( + ADOComment, + ADOCommentLocation, + ADOCommentProperty, + ADOCommentStatus, +) +from simple_ado.exceptions import ADOException + +from simple_ado.models import ( + PatchOperation, + AddOperation, + DeleteOperation, + PropertyValue, +) + + +class ADOPullRequestStatus(enum.Enum): + """Possible values of pull request states.""" + + ABANDONED = "abandoned" + ACTIVE = "active" + ALL = "all" + COMPLETED = "completed" + NOT_SET = "notSet" + + +class ADOAsyncPullRequestClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Pull Request APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + :param pull_request_id: The ID of the pull request + :param project_id: The ID of the project the PR is in + :param repository_id: The ID of the repository the PR is for + """ + + pull_request_id: int + project_id: str + repository_id: str + + def __init__( + self, + http_client: ADOAsyncHTTPClient, + log: logging.Logger, + pull_request_id: int, + project_id: str, + repository_id: str, + ) -> None: + self.pull_request_id = pull_request_id + self.repository_id = repository_id + self.project_id = project_id + super().__init__(http_client, log.getChild(f"pr.{pull_request_id}")) + + async def details(self) -> ADOResponse: + """Get the details for the PR from ADO. + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Getting PR: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}?api-version=7.1" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def workitems(self) -> ADOResponse: + """Get the workitems associated with the PR from ADO. + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Getting workitems: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/workitems?api-version=7.1" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def iterations(self) -> ADOResponse: + """Get the iterations of this PR. + + :returns: The ADO response with the iterations data in it + """ + self.log.debug(f"Getting iterations: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/iterations?api-version=7.1" + ) + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_threads(self, *, include_deleted: bool = False) -> list[ADOThread]: + """Get the comments on the PR from ADO. + + :param include_deleted: Set to True if deleted threads should be included. + + :returns: A list of ADOThreads that were found + """ + + self.log.debug(f"Getting threads: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/threads?api-version=7.1" + ) + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + comments: list[ADOThread] = self.http_client.extract_value(response_data) + + if include_deleted: + return comments + + return [comment for comment in comments if comment["isDeleted"] is False] + + async def create_comment_with_text( + self, + comment_text: str, + *, + comment_location: ADOCommentLocation | None = None, + status: ADOCommentStatus | None = None, + comment_identifier: str | None = None, + ) -> ADOResponse: + """Create a thread using a single root comment. + + :param comment_text: The text to set in the comment. + :param comment_location: The location to place the comment. + :param status: The status of the comment + :param comment_identifier: A unique identifier for the comment that can be used for identification + at a later date + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Creating comment: ({self.pull_request_id}) {comment_text}") + comment = ADOComment(comment_text, comment_location) + return await self.create_comment( + comment, + status=status, + comment_identifier=comment_identifier, + ) + + async def create_comment( + self, + comment: ADOComment, + *, + status: ADOCommentStatus | None = None, + comment_identifier: str | None = None, + ) -> ADOResponse: + """Create a thread using a single root comment. + + :param comment: The comment to add. + :param status: The status of the comment + :param comment_identifier: A unique identifier for the comment that can be used for identification + at a later date + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Creating comment: ({self.pull_request_id}) {comment}") + return await self.create_thread( + comments=[comment.generate_representation()], + thread_location=comment.location, + status=status, + comment_identifier=comment_identifier, + ) + + async def create_thread( + self, + *, + comments: list[dict[str, Any]], + thread_location: ADOCommentLocation | None = None, + status: ADOCommentStatus | None = None, + comment_identifier: str | None = None, + ) -> ADOResponse: + """Create a thread on a PR. + + :param comments: The comments to add to the thread. + :param thread_location: The location the thread should be added + :param status: The status of the comment + :param comment_identifier: A unique identifier for the comment that can be used for identification + at a later date + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Creating thread ({self.pull_request_id})") + + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/threads?api-version=7.1" + ) + + properties = { + ADOCommentProperty.SUPPORTS_MARKDOWN: ADOCommentProperty.create_bool(True), + } + + if comment_identifier: + properties[ADOCommentProperty.COMMENT_IDENTIFIER] = ADOCommentProperty.create_string( + comment_identifier + ) + + body: dict[str, Any] = { + "comments": comments, + "properties": properties, + "status": (status.value if status is not None else ADOCommentStatus.ACTIVE.value), + } + + if thread_location is not None: + body["threadContext"] = thread_location.generate_representation() + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def delete_thread(self, thread: ADOThread) -> None: + """Delete a comment thread from a pull request. + + :param thread: The thread to delete + """ + + thread_id = thread["id"] + + self.log.debug(f"Deleting thread: ({self.pull_request_id}) {thread_id}") + + for comment in thread["comments"]: + comment_id = comment["id"] + self.log.debug(f"Deleting comment: {comment_id}") + request_url = self.http_client.api_endpoint(project_id=self.project_id) + request_url += f"/git/repositories/{self.repository_id}" + request_url += f"/pullRequests/{self.pull_request_id}/threads/{thread_id}" + request_url += f"/comments/{comment_id}?api-version=7.1" + await self.http_client.delete(request_url) + + async def create_thread_list( + self, + *, + threads: list[ADOComment], + comment_identifier: str | None = None, + ) -> None: + """Create a list of threads + + :param threads: The threads to create + :param comment_identifier: A unique identifier for the comments that can be used for + identification at a later date + + :raises ADOException: If a thread is not an ADO comment + """ + + self.log.debug(f"Setting threads on PR: {self.pull_request_id}") + + # Check the type of the input + for thread in threads: + if not isinstance(thread, ADOComment): # pyright: ignore[reportUnnecessaryIsInstance] + raise ADOException("Thread was not an ADOComment: " + str(thread)) + + for thread in threads: + self.log.debug("Adding thread") + await self.create_comment(thread, comment_identifier=comment_identifier) + + async def get_statuses(self) -> ADOResponse: + """Get the statuses on a PR. + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Getting PR statuses on PR {self.pull_request_id}") + + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/statuses?api-version=7.1" + ) + + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def set_status( + self, + state: ADOGitStatusState, + identifier: str, + description: str, + context: str, + *, + iteration: int | None = None, + target_url: str | None = None, + ) -> ADOResponse: + """Set a status on a PR. + + :param state: The state to set the status to. + :param identifier: A unique identifier for the status (so it can be changed later) + :param description: The text to show in the status + :param context: The context for the build status + :param iteration: The iteration of the PR to set the status on + :param target_url: An optional URL to set which is opened when the description is clicked. + + :returns: The ADO response with the data in it + """ + + self.log.debug( + f"Setting PR status ({state}) on PR ({self.pull_request_id}): {identifier} -> {description}" + ) + + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/statuses?api-version=7.1" + ) + + body: dict[str, Any] = { + "state": state.value, + "description": description, + "context": {"name": context, "genre": identifier}, + } + + if iteration is not None: + body["iterationId"] = iteration + + if target_url is not None: + body["targetUrl"] = target_url + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + def _thread_matches_identifier(self, thread: ADOThread, identifier: str) -> bool: + """Check if the ADO thread matches the user and identifier + + :param thread: The thread to check + :param identifier: The identifier to check against + + :returns: True if the thread matches, False otherwise + + :raises ADOException: If we couldn't find the author or the properties + """ + + _ = self + + try: + # Deleted threads can stay around if they have other comments, so we + # check if it was deleted before we check anything else. + if thread["comments"][0]["isDeleted"]: + return False + except Exception: + # If it's not there, it's not set + pass + + try: + properties = thread["properties"] + except Exception as ex: + raise ADOException("Could not find properties in thread: " + str(thread)) from ex + + if properties is None: + return False + + comment_identifier = properties.get(ADOCommentProperty.COMMENT_IDENTIFIER) + + if comment_identifier is None: + return False + + value = comment_identifier.get("$value") + + if value == identifier: + return True + + return False + + async def threads_with_identifier(self, identifier: str) -> list[ADOThread]: + """Get the threads on a PR which begin with the prefix specified. + + :param identifier: The identifier to look for threads with + + :returns: The list of threads matching the identifier + + :raises ADOException: If the response is in an unexpected format + """ + + self.log.debug( + f'Fetching threads with identifier "{identifier}" on PR {self.pull_request_id}' + ) + + matching_threads: list[ADOThread] = [] + + for thread in await self.get_threads(): + self.log.debug("Handling thread...") + + if self._thread_matches_identifier(thread, identifier): + matching_threads.append(thread) + + return matching_threads + + async def delete_threads_with_identifier(self, identifier: str) -> None: + """Delete the threads on a PR which begin with the prefix specified. + + :param identifier: The identifier property value to look for threads matching + """ + + self.log.debug( + f'Deleting threads with identifier "{identifier}" on PR {self.pull_request_id}' + ) + + for thread in await self.threads_with_identifier(identifier): + self.log.debug(f"Deleting thread: {thread}") + await self.delete_thread(thread=thread) + + async def get_properties(self) -> dict[str, PropertyValue]: + """Get the properties on the PR from ADO. + + :returns: The properties that were found + """ + + self.log.debug(f"Getting properties: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/properties?api-version=7.1" + ) + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + raw_properties = self.http_client.extract_value(response_data) + + properties = deserialize.deserialize(dict[str, PropertyValue], raw_properties) + + return properties + + async def patch_properties(self, operations: list[PatchOperation]) -> dict[str, PropertyValue]: + """Patch the properties on the PR. + + Usually add_property(), delete_property() and update_property() are + going to be what you need instead of this base function. + + :param operations: The raw operations + + :returns: The new properties + """ + + self.log.debug(f"Patching properties: {self.pull_request_id}") + request_url = ( + self.http_client.api_endpoint(project_id=self.project_id) + + f"/git/repositories/{self.repository_id}" + + f"/pullRequests/{self.pull_request_id}/properties?api-version=7.1" + ) + + response = await self.http_client.patch(request_url, operations=operations) + + response_data = self.http_client.decode_response(response) + raw_properties = self.http_client.extract_value(response_data) + + properties = deserialize.deserialize(dict[str, PropertyValue], raw_properties) + + return properties + + async def add_property(self, name: str, value: str) -> dict[str, PropertyValue]: + """Add a property to the PR. + + :param name: The name of the property to add + :param value: The value of the property to add + + :returns: The new properties + """ + + operation = AddOperation("/" + name, value) + return await self.patch_properties([operation]) + + async def delete_property(self, name: str) -> dict[str, PropertyValue]: + """Delete a property from the PR. + + :param name: The name of the property to delete + + :returns: The new properties + """ + + operation = DeleteOperation("/" + name) + return await self.patch_properties([operation]) diff --git a/simple_ado/_async/security.py b/simple_ado/_async/security.py new file mode 100644 index 0000000..2356559 --- /dev/null +++ b/simple_ado/_async/security.py @@ -0,0 +1,649 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO security API wrapper (async).""" + +import enum +import json +import logging +from typing import Any, ClassVar, cast +import urllib.parse + + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado.exceptions import ADOException +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse +from simple_ado.ado_types import TeamFoundationId + + +class ADOBranchPermission(enum.IntEnum): + """Possible types of git branch permissions.""" + + ADMINISTER = 2**0 + READ = 2**1 + CONTRIBUTE = 2**2 + FORCE_PUSH = 2**3 + CREATE_BRANCH = 2**4 + CREATE_TAG = 2**5 + MANAGE_NOTES = 2**6 + BYPASS_PUSH_POLICIES = 2**7 + CREATE_REPOSITORY = 2**8 + DELETE_REPOSITORY = 2**9 + RENAME_REPOSITORY = 2**10 + EDIT_POLICIES = 2**11 + REMOVE_OTHERS_LOCKS = 2**12 + MANAGE_PERMISSIONS = 2**13 + CONTRIBUTE_TO_PULL_REQUESTS = 2**14 + BYPASS_PULL_REQUEST_POLICIES = 2**15 + + +class ADOBranchPermissionLevel(enum.IntEnum): + """Possible values of git branch permissions.""" + + NOT_SET = 0 + ALLOW = 1 + DENY = 2 + + +class ADOBranchPolicy(enum.Enum): + """Possible types of git branch protections.""" + + APPROVAL_COUNT = "fa4e907d-c16b-4a4c-9dfa-4906e5d171dd" + BUILD = "0609b952-1397-4640-95ec-e00a01b2c241" + CASE_ENFORCEMENT = "7ed39669-655c-494e-b4a0-a08b4da0fcce" + MAXIMUM_BLOB_SIZE = "2e26e725-8201-4edd-8bf5-978563c34a80" + MERGE_STRATEGY = "fa4e907d-c16b-4a4c-9dfa-4916e5d171ab" + REQUIRED_REVIEWERS = "fd2167ab-b0be-447a-8ec8-39368250530e" + STATUS_CHECK = "cbdc66da-9728-4af8-aada-9a5a32e4a226" + WORK_ITEM = "40e92b44-2fe1-4dd6-b3d8-74a9c21d0c6e" + + +class ADOPolicyApplicability(enum.Enum): + """Different types of policy applicability.""" + + APPLY_BY_DEFAULT = None + CONDITIONAL = 1 + + +class ADOAsyncSecurityClient(ADOAsyncBaseClient): + """Wrapper class around the undocumented ADO Security APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + GIT_PERMISSIONS_NAMESPACE: ClassVar[str] = "2e9eb7ed-3c0a-47d4-87c1-0ffdd275fd87" + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("security")) + + async def get_policies(self, project_id: str) -> ADOResponse: + """Gets the existing policies. + + :param project_id: The identifier for the project + + :returns: The ADO response with the data in it + """ + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def delete_policy(self, project_id: str, policy_id: int) -> None: + """Delete a policy. + + :param project_id: The identifier for the project + :param policy_id: The ID of the policy to delete + """ + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/policy/Configurations/{policy_id}?api-version=7.1" + ) + response = await self.http_client.delete(request_url) + self.http_client.validate_response(response) + + # pylint: disable=too-many-locals + async def add_branch_status_check_policy( + self, + *, + branch: str, + is_blocking: bool = True, + is_enabled: bool = True, + required_status_author_id: str | None = None, + default_display_name: str | None = None, + invalidate_on_source_update: bool = True, + filename_filter: list[str] | None = None, + applicability: ADOPolicyApplicability = ADOPolicyApplicability.APPLY_BY_DEFAULT, + status_name: str, + status_genre: str, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Adds a new status check policy for a given branch. + + :param branch: The git branch to set the policy for + :param is_blocking: Whether the status blocks PR completion or not. + :param is_enabled: Whether the status is enabled or not. + :param required_status_author_id: The ID of a required author (None if anyone) + :param default_display_name: The default display name for the policy + :param invalidate_on_source_update: Set to True to invalid the status when an update to + the PR happens, False otherwise + :param filename_filter: A list of file name filters this policy should + only apply to + :param applicability: Set to apply always or just if the status is posted + :param status_name: The name of the status + :param status_genre: The genre of the status + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + + settings: dict[str, Any] = { + "authorId": required_status_author_id, + "defaultDisplayName": default_display_name, + "invalidateOnSourceUpdate": invalidate_on_source_update, + "policyApplicability": applicability.value, + "statusName": status_name, + "statusGenre": status_genre, + "scope": [ + { + "repositoryId": repository_id, + "refName": f"refs/heads/{branch}", + "matchKind": "Exact", + } + ], + } + + if filename_filter: + settings["filenamePatterns"] = filename_filter + + body: dict[str, Any] = { + "type": {"id": ADOBranchPolicy.STATUS_CHECK.value}, + "revision": 1, + "isDeleted": False, + "isBlocking": is_blocking, + "isEnabled": is_enabled, + "settings": settings, + } + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + # pylint: enable=too-many-locals + + async def add_branch_build_policy( + self, + *, + branch: str, + build_definition_id: int, + build_expiration: int | None = None, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Adds a new build policy for a given branch. + + :param branch: The git branch to set the build policy for + :param build_definition_id: The build definition to use when creating the build policy + :param build_expiration: How long in minutes before the build expires. Set to None for + immediately on changes to source branch. + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + + body: dict[str, Any] = { + "type": {"id": ADOBranchPolicy.BUILD.value}, + "revision": 1, + "isDeleted": False, + "isBlocking": True, + "isEnabled": True, + "settings": { + "buildDefinitionId": build_definition_id, + "displayName": None, + "queueOnSourceUpdateOnly": build_expiration is not None, + "manualQueueOnly": False, + "validDuration": (build_expiration if build_expiration is not None else 0), + "scope": [ + { + "refName": f"refs/heads/{branch}", + "matchKind": "Exact", + "repositoryId": repository_id, + } + ], + }, + } + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def add_branch_required_reviewers_policy( + self, + *, + branch: str, + identities: list[str], + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Adds required reviewers when opening PRs against a given branch. + + :param branch: The git branch to set required reviewers for + :param identities: A list of identities to become required + reviewers (should be team foundation IDs) + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + + body: dict[str, Any] = { + "type": {"id": ADOBranchPolicy.REQUIRED_REVIEWERS.value}, + "revision": 1, + "isDeleted": False, + "isBlocking": True, + "isEnabled": True, + "settings": { + "requiredReviewerIds": identities, + "filenamePatterns": [], + "addedFilesOnly": False, + "ignoreIfSourceIsInScope": False, + "message": None, + "scope": [ + { + "refName": f"refs/heads/{branch}", + "matchKind": "Exact", + "repositoryId": repository_id, + } + ], + }, + } + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def set_branch_approval_count_policy( + self, + *, + branch: str, + minimum_approver_count: int, + creator_vote_counts: bool = False, + reset_on_source_push: bool = False, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Set minimum number of reviewers for a branch. + + :param branch: The git branch to set minimum number of reviewers on + :param minimum_approver_count: The minimum number of approvals required + :param creator_vote_counts: Allow users to approve their own changes + :param reset_on_source_push: Reset reviewer votes when there are new changes + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + + body: dict[str, Any] = { + "type": {"id": ADOBranchPolicy.APPROVAL_COUNT.value}, + "revision": 2, + "isDeleted": False, + "isBlocking": True, + "isEnabled": True, + "settings": { + "minimumApproverCount": minimum_approver_count, + "creatorVoteCounts": creator_vote_counts, + "resetOnSourcePush": reset_on_source_push, + "scope": [ + { + "refName": f"refs/heads/{branch}", + "matchKind": "exact", + "repositoryId": repository_id, + } + ], + }, + } + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def set_branch_work_item_policy( + self, + *, + branch: str, + required: bool = True, + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Set the work item policy for a branch. + + :param branch: The git branch to set the work item policy on + :param required: Whether or not linked work items should be mandatory + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + "/policy/Configurations?api-version=7.1" + ) + + body: dict[str, Any] = { + "type": {"id": ADOBranchPolicy.WORK_ITEM.value}, + "revision": 2, + "isDeleted": False, + "isBlocking": required, + "isEnabled": True, + "settings": { + "scope": [ + { + "refName": f"refs/heads/{branch}", + "matchKind": "Exact", + "repositoryId": repository_id, + } + ] + }, + } + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def set_branch_permissions( + self, + *, + branch: str, + identity: TeamFoundationId, + permissions: dict[ADOBranchPermission, ADOBranchPermissionLevel], + project_id: str, + repository_id: str, + ) -> ADOResponse: + """Set permissions for an identity on a branch. + + :param branch: The git branch to set permissions on + :param identity: The identity to set permissions for (should be team foundation ID) + :param permissions: A dictionary of permissions to set + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The ADO response with the data in it + """ + + descriptor_info = await self._get_descriptor_info( + branch=branch, + team_foundation_id=identity, + project_id=project_id, + repository_id=repository_id, + ) + + request_url = self.http_client.api_endpoint(is_internal=True, project_id=project_id) + request_url += "/_security/ManagePermissions?__v=5" + + updates: list[dict[str, Any]] = [] + for permission, level in permissions.items(): + updates.append( + { + "PermissionId": level, + "PermissionBit": permission, + "NamespaceId": ADOAsyncSecurityClient.GIT_PERMISSIONS_NAMESPACE, + "Token": self.generate_updates_token( + branch_name=branch, + project_id=project_id, + repository_id=repository_id, + ), + } + ) + + package: dict[str, Any] = { + "IsRemovingIdentity": False, + "TeamFoundationId": identity, + "DescriptorIdentityType": descriptor_info["type"], + "DescriptorIdentifier": descriptor_info["id"], + "PermissionSetId": ADOAsyncSecurityClient.GIT_PERMISSIONS_NAMESPACE, + "PermissionSetToken": self._generate_permission_set_token( + branch=branch, project_id=project_id, repository_id=repository_id + ), + "RefreshIdentities": False, + "Updates": updates, + "TokenDisplayName": None, + } + + body = {"updatePackage": json.dumps(package)} + + response = await self.http_client.post(request_url, json_data=body) + return self.http_client.decode_response(response) + + async def _get_descriptor_info( + self, + *, + branch: str, + team_foundation_id: TeamFoundationId, + project_id: str, + repository_id: str, + ) -> dict[str, str]: + """Fetch the descriptor identity information for a given identity. + + :param branch: The git branch of interest + :param team_foundation_id: the unique Team Foundation GUID for the identity + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The raw descriptor info + + :raises ADOException: If we can't determine the descriptor info from the response + """ + + request_url = self.http_client.api_endpoint(is_internal=True, project_id=project_id) + request_url += "/_security/DisplayPermissions?" + + parameters: dict[str, Any] = { + "tfid": team_foundation_id, + "permissionSetId": ADOAsyncSecurityClient.GIT_PERMISSIONS_NAMESPACE, + "permissionSetToken": self._generate_permission_set_token( + branch=branch, project_id=project_id, repository_id=repository_id + ), + "__v": "5", + } + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + + try: + descriptor_info = { + "type": response_data["descriptorIdentityType"], + "id": response_data["descriptorIdentifier"], + } + except Exception as ex: + raise ADOException( + "Could not determine descriptor info for team_foundation_id: " + + str(team_foundation_id) + ) from ex + + return descriptor_info + + def _generate_permission_set_token( + self, + branch: str, + project_id: str, + repository_id: str, + ) -> str: + """Generate the token required for reading identity details and writing permissions. + + :param branch: The git branch of interest + :param project_id: The ID for the project + :param repository_id: The ID for the repository + + :returns: The permission token + """ + _ = self + encoded_branch = branch.replace("/", "^") + return f"repoV2/{project_id}/{repository_id}/refs^heads^{encoded_branch}/" + + def generate_updates_token( + self, + *, + project_id: str, + repository_id: str | None = None, + branch_name: str | None = None, + ) -> str: + """Generate the token required for updating permissions. + + A project ID must always be set. Repository ID and branch name are + optional, but if a branch name is set, then a repository ID must also be + set. + + :param project_id: The ID for the project + :param repository_id: The ID for the repository + :param branch_name: The git branch of interest + + :returns: The update token + """ + + _ = self + + token = f"repoV2/{project_id}/" + + if not repository_id: + return token + + token += f"{repository_id}/" + + if not branch_name: + return token + + # Encode each node in the branch to hex + encoded_branch_nodes = [node.encode("utf-16le").hex() for node in branch_name.split("/")] + + encoded_branch = "/".join(encoded_branch_nodes) + + return token + f"refs/heads/{encoded_branch}/" + + async def query_namespaces( + self, *, namespace_id: str, local_only: bool | None = None + ) -> ADOResponse: + """Query a namespace + + :param namespace_id: The identifier for the namespace + :param local_only: Specify whether to check local namespaces only or not + + :returns: The ADO response with the data in it + """ + request_url = ( + self.http_client.api_endpoint() + + f"/securitynamespaces/{namespace_id}?api-version=7.1-preview.1" + ) + + if local_only is not None: + request_url += f"&localOnly={local_only}".lower() + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def query_access_control_lists( + self, + *, + namespace_id: str, + descriptors: list[str] | None = None, + token: str | None = None, + ) -> ADOResponse: + """Query a namespace + + :param namespace_id: The identifier for the namespace + :param descriptors: An optional of list of descriptors to filter down to those. + :param token: An optional token to filter down to + + :returns: The ADO response with the data in it + """ + + if descriptors is None: + descriptors = [] + + descriptors = [ + ( + "Microsoft.TeamFoundation.Identity;" + descriptor + if not descriptor.startswith("Microsoft.TeamFoundation.Identity;") + else descriptor + ) + for descriptor in descriptors + ] + + request_url = ( + self.http_client.api_endpoint() + + f"/accesscontrollists/{namespace_id}?api-version=7.1-preview.1" + ) + + if len(descriptors) > 0: + request_url += "&descriptors=" + ",".join(descriptors) + + if token: + request_url += f"&token={token}" + + response = await self.http_client.get(request_url) + response_data = self.http_client.decode_response(response) + return self.http_client.extract_value(response_data) + + async def get_permissions( + self, + *, + branch: str, + team_foundation_id: TeamFoundationId, + project_id: str, + repository_id: str, + ) -> dict[str, Any]: + """Get the permissions for a branch + + :param branch: The name of the branch to get the permissions for + :param team_foundation_id: the unique Team Foundation GUID for the identity + :param project_id: The identifier for the project + :param repository_id: The ID for the repository + + :returns: The raw descriptor info + + :raises ADOException: If we can't determine the descriptor info from the response + """ + + request_url = self.http_client.api_endpoint(is_internal=True, project_id=project_id) + request_url += "/_security/DisplayPermissions?" + + parameters: dict[str, Any] = { + "tfid": team_foundation_id, + "permissionSetId": ADOAsyncSecurityClient.GIT_PERMISSIONS_NAMESPACE, + "permissionSetToken": self._generate_permission_set_token( + branch=branch, project_id=project_id, repository_id=repository_id + ), + "__v": "5", + } + + request_url += urllib.parse.urlencode(parameters) + + response = await self.http_client.get(request_url) + return cast(dict[str, Any], self.http_client.decode_response(response)) diff --git a/simple_ado/_async/user.py b/simple_ado/_async/user.py new file mode 100644 index 0000000..8643522 --- /dev/null +++ b/simple_ado/_async/user.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO user API wrapper (async).""" + +import logging + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient + + +class ADOAsyncUserClient(ADOAsyncBaseClient): + """Wrapper class around the ADO user APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("user")) diff --git a/simple_ado/_async/utilities.py b/simple_ado/_async/utilities.py new file mode 100644 index 0000000..2ae2998 --- /dev/null +++ b/simple_ado/_async/utilities.py @@ -0,0 +1,59 @@ +"""Utilities for dealing with the ADO REST API (async).""" + +import logging +from typing import Callable + +import httpx + +from simple_ado.exceptions import ADOHTTPException + + +def boolstr(value: bool) -> str: + """Return a boolean formatted as string for ADO calls + + :param value: The value to format + + :returns: A string representation of the boolean value + """ + return str(value).lower() + + +async def download_from_response_stream( + *, + response: httpx.Response, + output_path: str, + log: logging.Logger, + callback: Callable[[int, int], None] | None = None, +) -> None: + """Downloads a file from an already open response stream. + + :param response: The response to download from + :param output_path: The path to write the file out to + :param log: The log to use for progress updates + :param callback: If supplied, this will be called on every new chunk to update progress to the caller + + :raises ADOHTTPException: If we fail to fetch the file for any reason + """ + + # A sensible modern value + chunk_size = 1024 * 16 + + if response.status_code < 200 or response.status_code >= 300: + raise ADOHTTPException("Failed to fetch file", response) + + with open(output_path, "wb") as output_file: + content_length_string = response.headers.get("content-length", "0") + + total_size = int(content_length_string) + total_downloaded = 0 + + async for data in response.aiter_bytes(chunk_size=chunk_size): + total_downloaded += len(data) + output_file.write(data) + + if callback is not None: + callback(total_downloaded, total_size) + + if total_size != 0: + progress = int((total_downloaded * 100.0) / total_size) + log.info(f"Download progress: {progress}%") diff --git a/simple_ado/_async/wiki.py b/simple_ado/_async/wiki.py new file mode 100644 index 0000000..234b1e6 --- /dev/null +++ b/simple_ado/_async/wiki.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO Wiki API wrapper (async).""" + +import logging + +from simple_ado.exceptions import ADOException +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse + + +class ADOAsyncWikiClient(ADOAsyncBaseClient): + """Wrapper class around the ADO Wiki APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("wiki")) + + async def get_page_version(self, page_id: str, wiki_id: str, project_id: str) -> ADOResponse: + """Get's the current version of a wiki page. This returns a required parameter for updating a wiki page. + + https://docs.microsoft.com/en-us/rest/api/azure/devops/wiki/pages/get%20page%20by%20id?view=azure-devops-rest-6.1 + + :param page_id: Wiki page ID + :param wiki_id: Wiki ID or Wiki name + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + + :raises ADOException: If we fail to fetch the page version. + """ + + self.log.debug(f"Get wiki page: {page_id}") + request_url = ( + self.http_client.api_endpoint(is_default_collection=False, project_id=project_id) + + f"/wiki/wikis/{wiki_id}/pages/{page_id}?api-version=7.1" + ) + response = await self.http_client.get(request_url) + self.http_client.validate_response(response) + etag = response.headers.get("ETag") + if not etag: + raise ADOException("No ETag returned for wiki page.") + return etag + + async def update_page( + self, + page_id: str, + wiki_id: str, + project_id: str, + content: str, + current_version_etag: str, + ) -> ADOResponse: + """Update a the contents of a wiki page. + + https://docs.microsoft.com/en-us/rest/api/azure/devops/wiki/pages/update?view=azure-devops-rest-6.1 + + :param page_id: Wiki page ID + :param wiki_id: Wiki ID or Wiki name + :param project_id: The ID of the project + :param content: Content of the wiki page. + :param current_version_etag: The ETag of the current wiki page to verify the update. + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Updating wiki page: {page_id}") + request_url = ( + self.http_client.api_endpoint(is_default_collection=False, project_id=project_id) + + f"/wiki/wikis/{wiki_id}/pages/{page_id}?api-version=7.1" + ) + response = await self.http_client.patch( + request_url, + json_data={ + "content": content, + }, + additional_headers={ + "Content-Type": "application/json", + "If-Match": current_version_etag, + }, + ) + return self.http_client.decode_response(response) diff --git a/simple_ado/_async/work_item.py b/simple_ado/_async/work_item.py new file mode 100644 index 0000000..6b63e4c --- /dev/null +++ b/simple_ado/_async/work_item.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO work item wrapper with lazy loading and patching capabilities (async).""" + +import logging +from typing import Any, TYPE_CHECKING + +from simple_ado.exceptions import ADOException +from simple_ado.models import ADOWorkItemBuiltInFields, ReplaceOperation + +if TYPE_CHECKING: + from simple_ado._async.workitems import ADOAsyncWorkItemsClient + + +class ADOAsyncWorkItem: + """Wrapper class for work item data with lazy loading and patching capabilities. + + :param data: The work item data from the API response + :param client: The work items client for API operations + :param project_id: The ID of the project this work item belongs to + :param log: The logger to use + """ + + _data: dict[str, Any] + _client: "ADOAsyncWorkItemsClient" + _project_id: str + _log: logging.Logger + + def __init__( + self, + data: dict[str, Any], + client: "ADOAsyncWorkItemsClient", + project_id: str, + log: logging.Logger, + ) -> None: + """Initialize the work item wrapper. + + :param data: The work item data from the API response + :param client: The work items client for API operations + :param project_id: The ID of the project this work item belongs to + :param log: The logger to use + """ + self._data = data + self._client = client + self._project_id = project_id + work_item_id = data.get("id", "unknown") + self._log = log.getChild(f"workitem.{work_item_id}") + + @property + def id(self) -> int | None: + """Get the work item ID. + + :returns: The work item ID, or None if not present + """ + return self._data.get("id") + + @property + def data(self) -> dict[str, Any]: + """Get the raw work item data. + + :returns: The complete work item data dictionary + """ + return self._data + + def __getitem__(self, key: str | ADOWorkItemBuiltInFields) -> Any: + """Get a field value from the work item. + + Supports both string field names and ADOWorkItemBuiltInFields enum values. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + + :returns: The field value + + :raises KeyError: If the field is not found + """ + # Convert enum to string value if needed + field_name = key.value if isinstance(key, ADOWorkItemBuiltInFields) else key + + # Try to get from fields dict + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") + + async def get_field( + self, + key: str | ADOWorkItemBuiltInFields, + *, + auto_refresh: bool = True, + ) -> Any: + """Get a field value, optionally refreshing from the server if not found. + + Unlike ``__getitem__``, this method can auto-refresh the work item data + from the server when a field is missing, which is useful for fields that + weren't included in the initial response. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + :param auto_refresh: If True (default), refresh from server when field is missing + + :returns: The field value + + :raises KeyError: If the field is not found (even after refresh) + """ + field_name = key.value if isinstance(key, ADOWorkItemBuiltInFields) else key + fields = self._data.get("fields", {}) + + if field_name in fields: + return fields[field_name] + + if not auto_refresh: + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") + + await self.refresh() + + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") + + async def __setitem_async__(self, key: str | ADOWorkItemBuiltInFields, value: Any) -> None: + """Set a field value and patch it on the server. + + Equivalent to ``await work_item.patch(key, value)``. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + :param value: The new value for the field + """ + await self.patch(key, value) + + async def set(self, key: str | ADOWorkItemBuiltInFields, value: Any) -> None: + """Set a field value and patch it on the server. + + Convenience alias for ``await work_item.patch(key, value)``. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + :param value: The new value for the field + """ + await self.patch(key, value) + + async def refresh(self) -> None: + """Refresh the work item data from the server. + + This reloads all fields from the API, which is useful for populating + missing fields or getting the latest values. + + :raises ADOException: If the work item ID is not available + """ + work_item_id = self.id + if work_item_id is None: + raise ADOException("Cannot refresh work item without an ID") + + self._log.debug(f"Refreshing work item {work_item_id}") + self._data = await self._client.get(str(work_item_id), self._project_id) + + async def patch( + self, + field: str | ADOWorkItemBuiltInFields, + value: Any, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> None: + """Patch a field on the work item. + + This updates the field on the server and refreshes the local data. + + :param field: The field name or ADOWorkItemBuiltInFields enum value + :param value: The new value for the field + :param bypass_rules: Set to True if we should bypass validation rules + :param supress_notifications: Set to True if notifications should be suppressed + + :raises ADOException: If the work item ID is not available + """ + work_item_id = self.id + if work_item_id is None: + raise ADOException("Cannot patch work item without an ID") + + # We need the prefix to patch. If it's an enum, it gets handled further on + if isinstance(field, str): + field = f"/fields/{field}" + + self._log.debug(f"Patching field '{field}' on work item {work_item_id}") + + # Create replace operation + operation = ReplaceOperation(field, value) + + # Call the client's update method + response = await self._client.update( + identifier=str(work_item_id), + operations=[operation], + project_id=self._project_id, + bypass_rules=bypass_rules, + supress_notifications=supress_notifications, + ) + + # Update local data with response + self._data = response + + def __repr__(self) -> str: + """Get a string representation of the work item. + + :returns: A string representation showing the work item ID and type + """ + work_item_id = self.id + work_item_type = self._data.get("fields", {}).get("System.WorkItemType", "Unknown") + return f"" diff --git a/simple_ado/_async/workitems.py b/simple_ado/_async/workitems.py new file mode 100644 index 0000000..d35760d --- /dev/null +++ b/simple_ado/_async/workitems.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""ADO work items API wrapper (async).""" + +import logging +import os +import typing +from typing import Any, AsyncIterator, Iterator, List, TypeVar, cast + +from simple_ado._async.base_client import ADOAsyncBaseClient +from simple_ado._async.http_client import ADOAsyncHTTPClient, ADOResponse +from simple_ado._async.utilities import boolstr +from simple_ado._async.work_item import ADOAsyncWorkItem +from simple_ado.exceptions import ADOException, ADOHTTPException + +from simple_ado.models import ( + PatchOperation, + AddOperation, + WorkItemRelationType, + ADOWorkItemBuiltInFields, +) + +T = TypeVar("T") + + +# batched is only available in Python 3.12+ +def _batched(sequence: List[T], n: int) -> Iterator[List[T]]: + """Batch data into lists of length n. + + :param sequence: The iterable to batch + :param n: The batch size + + :yields: Lists of size n (or smaller for the last batch) + """ + for i in range(0, len(sequence), n): + yield sequence[i : i + n] + + +class BatchRequest: + """The base type for a batch request. + + :param method: The HTTP method to use for the batch request + :param uri: The URI for the batch request + :param headers: The headers to be sent with the batch request + """ + + method: str + uri: str + headers: dict[str, str] + + def __init__(self, method: str, uri: str, headers: dict[str, str]) -> None: + self.method = method + self.uri = uri + self.headers = headers + + def body(self) -> dict[str, Any]: + """Generate the body of the request to be used in the API call. + + :returns: A dictionary with the raw API data for the request + """ + return {"method": self.method, "uri": self.uri, "headers": self.headers} + + +class DeleteBatchRequest(BatchRequest): + """A deletion batch request. + + :param uri: The URI for the batch request + :param headers: The headers to be sent with the batch request + """ + + def __init__(self, uri: str, headers: dict[str, str] | None = None) -> None: + if headers is None: + headers = {} + + if headers.get("Content-Type") is None: + headers["Content-Type"] = "application/json-patch+json" + + super().__init__("DELETE", uri, headers) + + +class ADOAsyncWorkItemsClient(ADOAsyncBaseClient): + """Wrapper class around the ADO work items APIs. + + :param http_client: The HTTP client to use for the client + :param log: The logger to use + """ + + def __init__(self, http_client: ADOAsyncHTTPClient, log: logging.Logger) -> None: + super().__init__(http_client, log.getChild("workitems")) + + # TODO: Switch this to the default on next major version bump + async def get(self, identifier: str, project_id: str) -> ADOResponse: + """Get the data about a work item. + + :param identifier: The identifier of the work item + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Getting work item: {identifier}") + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/workitems/{identifier}?api-version=7.1&$expand=all" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def get_work_item(self, identifier: str, project_id: str) -> ADOAsyncWorkItem: + """Get a work item as an ADOAsyncWorkItem object. + + :param identifier: The identifier of the work item + :param project_id: The ID of the project + + :returns: An ADOAsyncWorkItem object wrapping the work item data + """ + + self.log.debug(f"Getting work item: {identifier}") + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/workitems/{identifier}?api-version=7.1&$expand=all" + ) + response = await self.http_client.get(request_url) + data = self.http_client.decode_response(response) + return ADOAsyncWorkItem(data, self, project_id, self.log) + + # TODO: Switch this to the default on next major version bump + async def list(self, identifiers: List[int], project_id: str) -> ADOResponse: + """Get a list of work items. + + :param identifiers: The list of requested work item ids + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + + ids = ",".join(map(str, identifiers)) + + self.log.debug(f"Getting work items: {ids}") + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/workitems?api-version=7.1&ids={ids}&$expand=all" + ) + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + # TODO: Switch this to the default on next major version bump + async def ilist(self, identifiers: List[int], project_id: str) -> AsyncIterator[dict[str, Any]]: + """Get a list of work items. + + :param identifiers: The list of requested work item ids + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + + for id_chunk in _batched(identifiers, 200): + + ids = ",".join(map(str, id_chunk)) + + self.log.debug(f"Getting work items: {ids}") + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/workitems?api-version=7.1&ids={ids}&$expand=all" + ) + response = await self.http_client.get(request_url) + data = self.http_client.decode_response(response) + + for item in data.get("value", []): + yield item + + async def list_work_items( + self, identifiers: List[int], project_id: str + ) -> AsyncIterator[ADOAsyncWorkItem]: + """Get a list of work items as ADOAsyncWorkItem objects with automatic chunking. + + :param identifiers: The list of requested work item ids + :param project_id: The ID of the project + + :returns: An iterator of ADOAsyncWorkItem objects + """ + + for id_chunk in _batched(identifiers, 200): + + ids = ",".join(map(str, id_chunk)) + + self.log.debug(f"Getting work items: {ids}") + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/workitems?api-version=7.1&ids={ids}&$expand=all" + ) + response = await self.http_client.get(request_url) + data = self.http_client.decode_response(response) + + for item_data in data.get("value", []): + yield ADOAsyncWorkItem(item_data, self, project_id, self.log) + + async def get_work_item_types(self, project_id: str) -> ADOResponse: + """Get the types of work items supported by the project. + + :returns: The ADO response with the data in it + """ + self.log.debug("Getting work item types") + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitemtypes?api-version=7.1" + response = await self.http_client.get(request_url) + return self.http_client.decode_response(response) + + async def add_property( + self, + *, + identifier: str, + field: str | ADOWorkItemBuiltInFields, + value: str, + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Add a property value to a work item. + + :param identifier: The identifier of the work item + :param field: The field to add (either a string or ADOWorkItemBuiltInFields enum) + :param value: The value to set the field to + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + + # Convert enum to string value if needed + field_str = field.value if isinstance(field, ADOWorkItemBuiltInFields) else field + + # Add /fields/ prefix if not already present (and not a special path like /relations/) + if not field_str.startswith("/"): + field_str = f"/fields/{field_str}" + + self.log.debug(f"Add field '{field_str}' to ticket {identifier}") + + operation = AddOperation(field_str, value) + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/{identifier}" + ) + request_url += f"?bypassRules={boolstr(bypass_rules)}" + request_url += f"&suppressNotifications={boolstr(supress_notifications)}" + request_url += "&api-version=7.1" + + response = await self.http_client.patch( + request_url, + operations=[operation], + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + return self.http_client.decode_response(response) + + async def add_attachment( + self, + *, + identifier: str, + path_to_attachment: str, + project_id: str, + filename: str | None = None, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Add an attachment to a work item. + + :param identifier: The identifier of the work item + :param path_to_attachment: The path to the attachment on disk + :param project_id: The ID of the project + :param filename: The new file name of the attachment + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + + :raises ADOException: If we can't get the url from the response + """ + + self.log.debug(f"Adding attachment to {identifier}: {path_to_attachment}") + + if filename is None: + filename = os.path.basename(path_to_attachment) + + filename = filename.replace("#", "_") + + # Upload the file + request_url = ( + self.http_client.api_endpoint(project_id=project_id) + + f"/wit/attachments?fileName={filename}&api-version=7.1" + ) + + response = await self.http_client.post_file(request_url, path_to_attachment) + + response_data = self.http_client.decode_response(response) + + url = response_data.get("url") + + if url is None: + raise ADOException(f"Failed to get url from response: {response_data}") + + # Attach it to the ticket + operation = AddOperation( + "/relations/-", + {"rel": "AttachedFile", "url": url, "attributes": {"comment": ""}}, + ) + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/{identifier}" + ) + request_url += f"?bypassRules={boolstr(bypass_rules)}" + request_url += f"&suppressNotifications={boolstr(supress_notifications)}" + request_url += "&api-version=7.1" + + response = await self.http_client.patch( + request_url, + operations=[operation], + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + return self.http_client.decode_response(response) + + async def _add_link( + self, + *, + parent_identifier: str, + child_url: str, + relation_type: WorkItemRelationType, + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Add a link between a parent work item and another resource. + + :param parent_identifier: The identifier of the parent work item + :param child_url: The URL of the child item to link to + :param relation_type: The relationship type between + the parent and the child + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Adding link {parent_identifier} -> {child_url} ({relation_type})") + + operation = AddOperation( + "/relations/-", + { + "rel": relation_type.value, + "url": child_url, + "attributes": {"comment": ""}, + }, + ) + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/{parent_identifier}" + request_url += f"?bypassRules={boolstr(bypass_rules)}" + request_url += f"&suppressNotifications={boolstr(supress_notifications)}" + request_url += "&api-version=7.1" + + response = await self.http_client.patch( + request_url, + operations=[operation], + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + return self.http_client.decode_response(response) + + async def link_tickets( + self, + *, + parent_identifier: str, + child_identifier: str, + relationship: WorkItemRelationType, + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Add a link between a parent and child work item. + + :param parent_identifier: The identifier of the parent work item + :param child_identifier: The identifier of the child work item + :param relationship: The relationship type between + the two work items + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + child_url = f"{self.http_client.api_endpoint()}/wit/workitems/{child_identifier}" + return await self._add_link( + parent_identifier=parent_identifier, + child_url=child_url, + relation_type=relationship, + project_id=project_id, + bypass_rules=bypass_rules, + supress_notifications=supress_notifications, + ) + + async def add_hyperlink( + self, + *, + identifier: str, + hyperlink: str, + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Add a hyperlink link to a work item. + + :param identifier: The identifier of the work item + :param hyperlink: The hyperlink to add to the work item + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + return await self._add_link( + parent_identifier=identifier, + child_url=hyperlink, + relation_type=WorkItemRelationType.HYPERLINK, + project_id=project_id, + bypass_rules=bypass_rules, + supress_notifications=supress_notifications, + ) + + async def create( + self, + *, + item_type: str, + operations: List[AddOperation], + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Create a new work item. + + :param item_type: The type of work item to create + :param operations: The list of add operations to use to create the + ticket + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Creating a new {item_type}") + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/${item_type}" + ) + request_url += f"?bypassRules={boolstr(bypass_rules)}" + request_url += f"&suppressNotifications={boolstr(supress_notifications)}" + request_url += "&api-version=7.1" + + response = await self.http_client.post( + request_url, + operations=cast(List[PatchOperation], operations), + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + return self.http_client.decode_response(response) + + async def update( + self, + *, + identifier: str, + operations: List[PatchOperation], + project_id: str, + bypass_rules: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Update a work item. + + :param identifier: The identifier of the work item + :param operations: The list of operations to use to update the ticket + :param project_id: The ID of the project + :param bypass_rules: Set to True if we should bypass validation + rules, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Updating {identifier}") + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/{identifier}" + ) + request_url += f"?bypassRules={boolstr(bypass_rules)}" + request_url += f"&suppressNotifications={boolstr(supress_notifications)}" + request_url += "&api-version=7.1" + + response = await self.http_client.patch( + request_url, + operations=operations, + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + return self.http_client.decode_response(response) + + async def execute_query(self, query_string: str, project_id: str) -> ADOResponse: + """Execute a WIQL query. + + :param query_string: The WIQL query string to execute + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Executing query: {query_string}") + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/wiql?api-version=7.1" + ) + + response = await self.http_client.post(request_url, json_data={"query": query_string}) + + return self.http_client.decode_response(response) + + async def execute_query_by_id(self, query_id: str, project_id: str) -> ADOResponse: + """Gets the results of the query given the query ID. + + :param query_id: The query id to execute + :param project_id: The ID of the project + + :returns: The ADO response with the data in it + """ + + self.log.debug(f"Executing query with id: {query_id}") + + request_url = f"{self.http_client.api_endpoint(project_id=project_id)}/wit/wiql/{query_id}?api-version=7.1" + + response = await self.http_client.get(request_url) + + return self.http_client.decode_response(response) + + async def delete( + self, + *, + identifier: str, + project_id: str, + permanent: bool = False, + supress_notifications: bool = False, + ) -> ADOResponse: + """Delete a work item. + + :param identifier: The identifier of the work item + :param project_id: The ID of the project + :param permanent: Set to True if we should permanently delete the + work item, False otherwise + :param supress_notifications: Set to True if notifications for this + change should be supressed, False + otherwise + + :returns: The ADO response with the data in it + + :raises ADOHTTPException: Raised if the response code is not 204 (No Content) + """ + + self.log.debug(f"Deleting {identifier}") + + request_url = ( + f"{self.http_client.api_endpoint(project_id=project_id)}/wit/workitems/{identifier}" + ) + request_url += f"?suppressNotifications={boolstr(supress_notifications)}" + request_url += f"&destroy={boolstr(permanent)}" + request_url += "&api-version=7.1" + + response = await self.http_client.delete( + request_url, + additional_headers={"Content-Type": "application/json-patch+json"}, + ) + + if response.status_code != 204: + raise ADOHTTPException(f"Failed to delete '{identifier}'", response) + + return self.http_client.decode_response(response) + + async def batch(self, operations: typing.List[BatchRequest]) -> ADOResponse: + """Run a batch operation. + + :param operations: The list of batch operations to run + + :returns: The ADO response with the data in it + + :raises ADOException: Raised if we try and run more than 200 batch operations at once + """ + + if len(operations) >= 200: + raise ADOException("Cannot perform more than 200 batch operations at once") + + self.log.debug("Running batch operation") + + full_body: list[dict[str, Any]] = [] + for operation in operations: + full_body.append(operation.body()) + + request_url = f"{self.http_client.api_endpoint()}/wit/$batch" + + response = await self.http_client.post(request_url, json_data=full_body) + + return self.http_client.decode_response(response) diff --git a/simple_ado/ado_types.py b/simple_ado/ado_types.py new file mode 100644 index 0000000..b71f5db --- /dev/null +++ b/simple_ado/ado_types.py @@ -0,0 +1,10 @@ +"""Custom types for the library.""" + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import NewType + +# pylint: disable=invalid-name + +TeamFoundationId = NewType("TeamFoundationId", str) diff --git a/simple_ado/audit.py b/simple_ado/audit.py index aa633a1..90212a8 100755 --- a/simple_ado/audit.py +++ b/simple_ado/audit.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/audit.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/auth/__init__.py b/simple_ado/auth/__init__.py index 96278a3..d2dc78a 100644 --- a/simple_ado/auth/__init__.py +++ b/simple_ado/auth/__init__.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/auth/__init__.py. DO NOT EDIT. + """Umbrella module for all authentication classes.""" from .ado_auth import ADOAuth diff --git a/simple_ado/auth/ado_auth.py b/simple_ado/auth/ado_auth.py index 80e4a93..9c38e4e 100644 --- a/simple_ado/auth/ado_auth.py +++ b/simple_ado/auth/ado_auth.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/auth/ado_auth.py. DO NOT EDIT. + """Base auth class.""" import abc @@ -12,3 +14,10 @@ def get_authorization_header(self) -> str: :return: The header value.""" raise NotImplementedError() + + def close(self) -> None: + """Close any resources held by this auth instance. + + Subclasses that hold closeable resources (e.g. credential objects) + should override this method. The default implementation is a no-op. + """ diff --git a/simple_ado/auth/ado_azid_auth.py b/simple_ado/auth/ado_azid_auth.py index cc57bc1..c712df2 100644 --- a/simple_ado/auth/ado_azid_auth.py +++ b/simple_ado/auth/ado_azid_auth.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/auth/ado_azid_auth.py. DO NOT EDIT. + """Azure Identity authentication auth class.""" import time @@ -11,9 +13,11 @@ class ADOAzIDAuth(ADOAuth): """Azure Identity auth.""" access_token: AccessToken | None + _credential: DefaultAzureCredential def __init__(self) -> None: self.access_token = None + self._credential = DefaultAzureCredential() def get_authorization_header(self) -> str: """Get the header value. @@ -23,8 +27,12 @@ def get_authorization_header(self) -> str: # The get_token parameter specifies the Azure DevOps resource and requests a token with # default permissions for API access. if self.access_token is None or self.access_token.expires_on <= time.time() + 60: - self.access_token = DefaultAzureCredential().get_token( + self.access_token = self._credential.get_token( "499b84ac-1321-427f-aa17-267ca6975798/.default" ) return "Bearer " + self.access_token.token + + def close(self) -> None: + """Close the underlying credential.""" + self._credential.close() diff --git a/simple_ado/auth/ado_basic_auth.py b/simple_ado/auth/ado_basic_auth.py index 85c4d91..7d92663 100644 --- a/simple_ado/auth/ado_basic_auth.py +++ b/simple_ado/auth/ado_basic_auth.py @@ -1,7 +1,8 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/auth/ado_basic_auth.py. DO NOT EDIT. + """Basic authentication auth class.""" import base64 -import functools from simple_ado.auth.ado_auth import ADOAuth @@ -10,16 +11,22 @@ class ADOBasicAuth(ADOAuth): username: str password: str + _cached_header: str | None def __init__(self, username: str, password: str) -> None: self.username = username self.password = password + self._cached_header = None - @functools.lru_cache(maxsize=1) - def get_authorization_header(self) -> str: # pyright: ignore[reportIncompatibleMethodOverride] + def get_authorization_header(self) -> str: """Get the header value. :return: The header value.""" - username_password_bytes = (self.username + ":" + self.password).encode("utf-8") - return "Basic " + base64.b64encode(username_password_bytes).decode("utf-8") + if self._cached_header is None: + username_password_bytes = (self.username + ":" + self.password).encode("utf-8") + self._cached_header = "Basic " + base64.b64encode(username_password_bytes).decode( + "utf-8" + ) + + return self._cached_header diff --git a/simple_ado/auth/ado_token_auth.py b/simple_ado/auth/ado_token_auth.py index 45c4e85..0f43bd3 100644 --- a/simple_ado/auth/ado_token_auth.py +++ b/simple_ado/auth/ado_token_auth.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/auth/ado_token_auth.py. DO NOT EDIT. + """Token authentication auth class.""" from simple_ado.auth.ado_auth import ADOAuth diff --git a/simple_ado/base_client.py b/simple_ado/base_client.py index 4448c85..d76156f 100755 --- a/simple_ado/base_client.py +++ b/simple_ado/base_client.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/base_client.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/builds.py b/simple_ado/builds.py index c3a244e..2454499 100755 --- a/simple_ado/builds.py +++ b/simple_ado/builds.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/builds.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -8,15 +10,16 @@ import enum import json import logging -from typing import Any, Callable, Iterator +from typing import Any, Iterator, Callable, cast +from urllib.parse import SplitResult import urllib.parse from simple_ado.base_client import ADOBaseClient -from simple_ado.exceptions import ADOHTTPException from simple_ado.http_client import ADOHTTPClient, ADOResponse -from simple_ado.types import TeamFoundationId from simple_ado.utilities import download_from_response_stream +from simple_ado.exceptions import ADOHTTPException +from simple_ado.ado_types import TeamFoundationId class BuildQueryOrder(enum.Enum): @@ -237,52 +240,45 @@ def download_artifact( self.log.debug(f"Fetching artifact {artifact_name} from build {build_id}...") - # This now redirects to a totally different domain. Since the domain is changing, requests will not keep the - # authentication headers. We need to handle the redirect ourselves to avoid this. - response = self.http_client.get( - request_url, stream=True, allow_redirects=False, set_accept_json=False - ) + # ADO redirects artifact downloads to a different domain. We follow redirects manually + # to enforce that only .visualstudio.com domains are accepted, preventing potential + # open-redirect attacks. + url = request_url - try: - while True: - if not response.is_redirect: - break + while True: + with self.http_client.stream_get( + url, follow_redirects=False, set_accept_json=False + ) as response: + if response.status_code < 300 or response.status_code >= 400: + download_from_response_stream( + response=response, + output_path=output_path, + log=self.log, + callback=progress_callback, + ) + return location = response.headers.get("location") if not location: + # Read the body before raising so downstream code can inspect the response + response.read() raise ADOHTTPException( f"ADO returned a redirect status code without a location header, configuration={self}", response, ) - location_components = urllib.parse.urlsplit(location) + parts = cast(SplitResult, urllib.parse.urlsplit(location)) - if location_components.hostname and not location_components.hostname.endswith( - ".visualstudio.com" - ): + if parts.hostname and not parts.hostname.endswith(".visualstudio.com"): + response.read() raise ADOHTTPException( "ADO returned a redirect status code with a location header that is not on visualstudio.com, " + f"configuration={self}", response, ) - response = self.http_client.get( - location, stream=True, allow_redirects=False, set_accept_json=False - ) - - download_from_response_stream( - response=response, output_path=output_path, log=self.log, callback=progress_callback - ) - - except Exception as ex: - try: - if response: - response.close() - except Exception: - pass - finally: - raise ex + url = location def get_file_manifest( self, @@ -354,20 +350,13 @@ def download_file( f"Fetching file {file_name} from artifact {artifact_name} from build {build_id}..." ) - response = self.http_client.get(request_url, stream=True) - - try: + with self.http_client.stream_get(request_url) as response: download_from_response_stream( - response=response, output_path=output_path, log=self.log, callback=progress_callback + response=response, + output_path=output_path, + log=self.log, + callback=progress_callback, ) - except Exception as ex: - try: - if response: - response.close() - except Exception: - pass - finally: - raise ex def get_leases(self, *, project_id: str, build_id: int) -> ADOResponse: """Get the retention leases for a build. diff --git a/simple_ado/endpoints.py b/simple_ado/endpoints.py index cc5e140..20b609c 100644 --- a/simple_ado/endpoints.py +++ b/simple_ado/endpoints.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/endpoints.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/exceptions.py b/simple_ado/exceptions.py index 401e02d..64ec495 100755 --- a/simple_ado/exceptions.py +++ b/simple_ado/exceptions.py @@ -5,13 +5,41 @@ """ADO exceptions.""" -import requests +import httpx class ADOException(Exception): """All ADO exceptions inherit from this or instantiate it.""" +class _CompatResponse: + """Thin wrapper around httpx.Response that adds backward-compatible .ok property. + + The requests library provided response.ok (status_code < 400). httpx uses + response.is_success (200 <= status_code < 300) instead. This wrapper adds .ok + so that code written against the old requests-based API still works. + + All other attribute access is delegated to the underlying httpx.Response. + """ + + _response: httpx.Response + + def __init__(self, response: httpx.Response) -> None: + # Use object.__setattr__ to avoid triggering __setattr__ if overridden + object.__setattr__(self, "_response", response) + + @property + def ok(self) -> bool: + """Backward-compatible alias: True when status_code < 400.""" + return self._response.status_code < 400 + + def __getattr__(self, name: str) -> object: + return getattr(self._response, name) + + def __repr__(self) -> str: + return repr(self._response) + + class ADOHTTPException(ADOException): """All ADO HTTP exceptions use this class. @@ -20,12 +48,12 @@ class ADOHTTPException(ADOException): """ message: str - response: requests.Response + response: _CompatResponse - def __init__(self, message: str, response: requests.Response) -> None: + def __init__(self, message: str, response: httpx.Response) -> None: super().__init__() self.message = message - self.response = response + self.response = _CompatResponse(response) def __str__(self) -> str: """Generate and return the string representation of the object. diff --git a/simple_ado/git.py b/simple_ado/git.py index 04cb2b4..ff0a569 100755 --- a/simple_ado/git.py +++ b/simple_ado/git.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/git.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -12,9 +14,9 @@ import urllib.parse from simple_ado.base_client import ADOBaseClient -from simple_ado.exceptions import ADOException from simple_ado.http_client import ADOHTTPClient, ADOResponse from simple_ado.utilities import download_from_response_stream +from simple_ado.exceptions import ADOException class ADOGitStatusState(enum.Enum): @@ -280,7 +282,7 @@ def download_zip( if os.path.exists(output_path): raise ADOException("The output path already exists") - with self.http_client.get(request_url, stream=True) as response: + with self.http_client.stream_get(request_url) as response: download_from_response_stream( response=response, output_path=output_path, @@ -811,10 +813,9 @@ def get_blobs( if os.path.exists(output_path): raise FileExistsError("The output path already exists") - with self.http_client.post( + with self.http_client.stream_post( request_url, additional_headers={"Accept": "application/zip"}, - stream=True, json_data=blob_ids, ) as response: download_from_response_stream(response=response, output_path=output_path, log=self.log) diff --git a/simple_ado/governance.py b/simple_ado/governance.py index f2a3eb5..167c031 100644 --- a/simple_ado/governance.py +++ b/simple_ado/governance.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/governance.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -128,7 +130,7 @@ def remove_policy( response = self.http_client.delete(request_url) - if not response.ok: + if not response.is_success: raise ADOHTTPException( f"Failed to remove policy {policy_id} from {governed_repository_id}", response, @@ -159,7 +161,7 @@ def _set_alert_settings( response = self.http_client.put(request_url, alert_settings) - if not response.ok: + if not response.is_success: raise ADOHTTPException( f"Failed to set alert settings on repo {governed_repository_id}", response, diff --git a/simple_ado/graph.py b/simple_ado/graph.py index 58488f8..f979a15 100644 --- a/simple_ado/graph.py +++ b/simple_ado/graph.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/graph.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/http_client.py b/simple_ado/http_client.py index a30a463..f8d94ce 100755 --- a/simple_ado/http_client.py +++ b/simple_ado/http_client.py @@ -1,17 +1,20 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/http_client.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. """ADO HTTP API wrapper.""" +import contextlib +import time import datetime import logging import os -import time -from typing import Any, TypeAlias +from typing import Any, Iterator, TypeAlias -import requests +import httpx from tenacity import ( retry, retry_if_exception, @@ -30,14 +33,27 @@ # pylint: enable=invalid-name +# Only retry status codes where a subsequent attempt may succeed: +# 400 Bad Request — ADO occasionally returns transient 400s under load +# 408 Request Timeout — server didn't receive the full request in time +# 429 Too Many Requests — rate-limited, back off and retry +# 500, 502, 503, 504 — server errors that are typically transient +# Previous versions retried all 4xx, but codes like 401/403/404 are +# deterministic failures that won't resolve on retry. +_RETRYABLE_STATUS_CODES = {400, 408, 429, 500, 502, 503, 504} + + def _is_retryable_get_failure(exception: Exception) -> bool: if not isinstance(exception, ADOHTTPException): return False - return exception.response.status_code in range(400, 500) + return exception.response.status_code in _RETRYABLE_STATUS_CODES def _is_connection_failure(exception: Exception) -> bool: + if isinstance(exception, (httpx.ConnectError, httpx.TimeoutException)): + return True + exception_checks = [ "Operation timed out", "Connection aborted.", @@ -67,7 +83,7 @@ class ADOHTTPClient: extra_headers: dict[str, str] auth: ADOAuth _not_before: datetime.datetime | None - _session: requests.Session + _client: httpx.Client def __init__( self, @@ -86,14 +102,34 @@ def __init__( self.auth = auth self._not_before = None - self._session = requests.Session() - self._session.headers.update({"User-Agent": f"simple_ado/{user_agent}"}) + self._client = httpx.Client( + headers={"User-Agent": f"simple_ado/{user_agent}"}, + follow_redirects=True, + timeout=httpx.Timeout(300.0), + ) if extra_headers is None: self.extra_headers = {} else: self.extra_headers = extra_headers + def close(self) -> None: + """Close the underlying HTTP client and auth resources.""" + self._client.close() + self.auth.close() + + def __enter__(self) -> "ADOHTTPClient": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def __del__(self) -> None: + try: + self._client.close() + except Exception: + pass + def graph_endpoint(self) -> str: """Generate the base url for all graph API calls (this varies depending on the API). @@ -160,7 +196,7 @@ def _wait(self) -> None: self.log.debug(f"Sleeping for {remaining} seconds before issuing next request") time.sleep(remaining.total_seconds()) - def _track_rate_limit(self, response: requests.Response) -> None: + def _track_rate_limit(self, response: httpx.Response) -> None: """Track the rate limit info from a request. :param response: The response to track the info from. @@ -196,36 +232,87 @@ def get( *, additional_headers: dict[str, str] | None = None, stream: bool = False, - allow_redirects: bool = True, + follow_redirects: bool = True, set_accept_json: bool = True, - ) -> requests.Response: + ) -> httpx.Response: """Issue a GET request with the correct headers. + When stream=True, the response body is not immediately loaded. The caller + must use the response as a context manager or call response.close() when done. + :param request_url: The URL to issue the request to :param additional_headers: Any additional headers to add to the request :param stream: Set to True to stream the response back - :param allow_redirects: Set to False to disable redirects + :param follow_redirects: Set to False to disable redirects :param set_accept_json: Set to False to disable setting the Accept header :returns: The raw response object from the API """ + self._wait() headers = self.construct_headers( additional_headers=additional_headers, set_accept_json=set_accept_json ) - response = self._session.get( - request_url, - headers=headers, - stream=stream, - allow_redirects=allow_redirects, - ) + if stream: + request = self._client.build_request( + "GET", + request_url, + headers=headers, + ) + response = self._client.send( + request, + stream=True, + follow_redirects=follow_redirects, + ) + else: + response = self._client.get( + request_url, + headers=headers, + follow_redirects=follow_redirects, + ) self._track_rate_limit(response) return response + @contextlib.contextmanager + def stream_get( + self, + request_url: str, + *, + additional_headers: dict[str, str] | None = None, + follow_redirects: bool = True, + set_accept_json: bool = True, + ) -> Iterator[httpx.Response]: + """Issue a streaming GET request. Must be used as a context manager. + + This is a convenience wrapper that handles cleanup automatically. + Prefer this over get(stream=True) when possible. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param follow_redirects: Set to False to disable redirects + :param set_accept_json: Set to False to disable setting the Accept header + + :yields: The raw response object from the API + """ + self._wait() + + headers = self.construct_headers( + additional_headers=additional_headers, set_accept_json=set_accept_json + ) + + with self._client.stream( + "GET", + request_url, + headers=headers, + follow_redirects=follow_redirects, + ) as response: + self._track_rate_limit(response) + yield response + @retry( retry=retry_if_exception(_is_connection_failure), # type: ignore wait=wait_random_exponential(max=10), @@ -239,12 +326,15 @@ def post( additional_headers: dict[str, str] | None = None, json_data: Any | None = None, stream: bool = False, - ) -> requests.Response: - """Issue a POST request with the correct headers. + ) -> httpx.Response: + """Issue a POST request with the correct headers. Note: If `json_data` and `operations` are not None, the latter will take precedence. + When stream=True, the response body is not immediately loaded. The caller + must use the response as a context manager or call response.close() when done. + :param request_url: The URL to issue the request to :param operations: The patch operations to send with the request :param additional_headers: Any additional headers to add to the request @@ -254,6 +344,8 @@ def post( :returns: The raw response object from the API """ + self._wait() + if operations is not None: json_data = [operation.serialize() for operation in operations] if additional_headers is None: @@ -262,12 +354,54 @@ def post( additional_headers["Content-Type"] = "application/json-patch+json" headers = self.construct_headers(additional_headers=additional_headers) - return self._session.post( + + if stream: + request = self._client.build_request( + "POST", + request_url, + headers=headers, + json=json_data, + ) + response = self._client.send(request, stream=True) + else: + response = self._client.post( + request_url, + headers=headers, + json=json_data, + ) + + self._track_rate_limit(response) + + return response + + @contextlib.contextmanager + def stream_post( + self, + request_url: str, + *, + additional_headers: dict[str, str] | None = None, + json_data: Any | None = None, + ) -> Iterator[httpx.Response]: + """Issue a streaming POST request. Must be used as a context manager. + + :param request_url: The URL to issue the request to + :param additional_headers: Any additional headers to add to the request + :param json_data: The JSON data to send with the request + + :yields: The raw response object from the API + """ + self._wait() + + headers = self.construct_headers(additional_headers=additional_headers) + + with self._client.stream( + "POST", request_url, headers=headers, json=json_data, - stream=stream, - ) + ) as response: + self._track_rate_limit(response) + yield response @retry( retry=retry_if_exception(_is_connection_failure), # type: ignore @@ -281,7 +415,7 @@ def patch( operations: list[PatchOperation] | None = None, json_data: Any | None = None, additional_headers: dict[str, Any] | None = None, - ) -> requests.Response: + ) -> httpx.Response: """Issue a PATCH request with the correct headers. Note: If `json_data` and `operations` are not None, the latter will take @@ -295,6 +429,8 @@ def patch( :returns: The raw response object from the API """ + self._wait() + if operations is not None: json_data = [operation.serialize() for operation in operations] if additional_headers is None: @@ -303,7 +439,9 @@ def patch( additional_headers["Content-Type"] = "application/json-patch+json" headers = self.construct_headers(additional_headers=additional_headers) - return self._session.patch(request_url, headers=headers, json=json_data) + response = self._client.patch(request_url, headers=headers, json=json_data) + self._track_rate_limit(response) + return response @retry( retry=retry_if_exception(_is_connection_failure), # type: ignore @@ -316,7 +454,7 @@ def put( json_data: Any | None = None, *, additional_headers: dict[str, Any] | None = None, - ) -> requests.Response: + ) -> httpx.Response: """Issue a PUT request with the correct headers. :param request_url: The URL to issue the request to @@ -325,8 +463,12 @@ def put( :returns: The raw response object from the API """ + self._wait() + headers = self.construct_headers(additional_headers=additional_headers) - return self._session.put(request_url, headers=headers, json=json_data) + response = self._client.put(request_url, headers=headers, json=json_data) + self._track_rate_limit(response) + return response @retry( retry=retry_if_exception(_is_connection_failure), # type: ignore @@ -335,7 +477,7 @@ def put( ) def delete( self, request_url: str, *, additional_headers: dict[str, Any] | None = None - ) -> requests.Response: + ) -> httpx.Response: """Issue a DELETE request with the correct headers. :param request_url: The URL to issue the request to @@ -343,8 +485,12 @@ def delete( :returns: The raw response object from the API """ + self._wait() + headers = self.construct_headers(additional_headers=additional_headers) - return self._session.delete(request_url, headers=headers) + response = self._client.delete(request_url, headers=headers) + self._track_rate_limit(response) + return response @retry( retry=retry_if_exception(_is_connection_failure), # type: ignore @@ -357,7 +503,7 @@ def post_file( file_path: str, *, additional_headers: dict[str, Any] | None = None, - ) -> requests.Response: + ) -> httpx.Response: """POST a file to the URL with the given file name. :param request_url: The URL to issue the request to @@ -366,23 +512,35 @@ def post_file( :returns: The raw response object from the API""" + self._wait() + file_size = os.path.getsize(file_path) headers = self.construct_headers(additional_headers=additional_headers) headers["Content-Length"] = str(file_size) headers["Content-Type"] = "application/json" - request = requests.Request("POST", request_url, headers=headers) - prepped = request.prepare() - - # Send the raw content, not with "Content-Disposition", etc. - with open(file_path, "rb") as file_handle: - prepped.body = file_handle.read(file_size) + content = self._read_file(file_path) - response: requests.Response = self._session.send(prepped) + response = self._client.post( + request_url, + headers=headers, + content=content, + ) + self._track_rate_limit(response) return response - def validate_response(self, response: requests.models.Response) -> None: + @staticmethod + def _read_file(file_path: str) -> bytes: + """Read a file's contents as bytes. + + :param file_path: The path to the file to read + :returns: The file contents + """ + with open(file_path, "rb") as file_handle: + return file_handle.read() + + def validate_response(self, response: httpx.Response) -> None: """Checking a response for errors. :param response: The response to check @@ -393,13 +551,13 @@ def validate_response(self, response: requests.models.Response) -> None: self.log.debug("Validating response from ADO") - if response.status_code < 200 or response.status_code >= 300: + if not response.is_success: raise ADOHTTPException( f"ADO returned a non-200 status code, configuration={self}", response, ) - def decode_response(self, response: requests.models.Response) -> ADOResponse: + def decode_response(self, response: httpx.Response) -> ADOResponse: """Decode the response from ADO, checking for errors. :param response: The response to check and parse diff --git a/simple_ado/identities.py b/simple_ado/identities.py index 007996d..4325d6b 100644 --- a/simple_ado/identities.py +++ b/simple_ado/identities.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/identities.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -11,7 +13,7 @@ from simple_ado.base_client import ADOBaseClient from simple_ado.exceptions import ADOException from simple_ado.http_client import ADOHTTPClient -from simple_ado.types import TeamFoundationId +from simple_ado.ado_types import TeamFoundationId class ADOIdentitiesClient(ADOBaseClient): diff --git a/simple_ado/pipelines.py b/simple_ado/pipelines.py index 08b84c7..518e93b 100755 --- a/simple_ado/pipelines.py +++ b/simple_ado/pipelines.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/pipelines.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -6,7 +8,7 @@ """ADO pipeline API wrapper.""" import logging -from typing import Any, cast, Iterator +from typing import Any, Iterator, cast import urllib.parse diff --git a/simple_ado/pools.py b/simple_ado/pools.py index 612c1bc..2404a6c 100644 --- a/simple_ado/pools.py +++ b/simple_ado/pools.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/pools.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/pull_requests.py b/simple_ado/pull_requests.py index b17b59d..1a47a8d 100755 --- a/simple_ado/pull_requests.py +++ b/simple_ado/pull_requests.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/pull_requests.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -10,9 +12,10 @@ from typing import Any import deserialize -import requests from simple_ado.base_client import ADOBaseClient +from simple_ado.git import ADOGitStatusState +from simple_ado.http_client import ADOHTTPClient, ADOResponse, ADOThread from simple_ado.comments import ( ADOComment, ADOCommentLocation, @@ -20,8 +23,6 @@ ADOCommentStatus, ) from simple_ado.exceptions import ADOException -from simple_ado.git import ADOGitStatusState -from simple_ado.http_client import ADOHTTPClient, ADOResponse, ADOThread from simple_ado.models import ( PatchOperation, @@ -253,10 +254,7 @@ def delete_thread(self, thread: ADOThread) -> None: request_url += f"/git/repositories/{self.repository_id}" request_url += f"/pullRequests/{self.pull_request_id}/threads/{thread_id}" request_url += f"/comments/{comment_id}?api-version=7.1" - requests.delete( - request_url, - headers=self.http_client.construct_headers(), - ) + self.http_client.delete(request_url) def create_thread_list( self, diff --git a/simple_ado/security.py b/simple_ado/security.py index 7f5bf84..750c388 100644 --- a/simple_ado/security.py +++ b/simple_ado/security.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/security.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -15,7 +17,7 @@ from simple_ado.base_client import ADOBaseClient from simple_ado.exceptions import ADOException from simple_ado.http_client import ADOHTTPClient, ADOResponse -from simple_ado.types import TeamFoundationId +from simple_ado.ado_types import TeamFoundationId class ADOBranchPermission(enum.IntEnum): diff --git a/simple_ado/types.py b/simple_ado/types.py index b71f5db..04b4646 100644 --- a/simple_ado/types.py +++ b/simple_ado/types.py @@ -1,10 +1,16 @@ -"""Custom types for the library.""" +"""Custom types for the library. + +.. deprecated:: + This module is deprecated. Use :mod:`simple_ado.ado_types` instead. + This module shadows the stdlib ``types`` module and will be removed in a + future major version. +""" # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import NewType - -# pylint: disable=invalid-name +# pylint: disable=wildcard-import,unused-wildcard-import,useless-import-alias -TeamFoundationId = NewType("TeamFoundationId", str) +# Re-export everything from ado_types for backward compatibility. +from simple_ado.ado_types import * # noqa: F401,F403 +from simple_ado.ado_types import TeamFoundationId as TeamFoundationId # noqa: F811 diff --git a/simple_ado/user.py b/simple_ado/user.py index fe26d3a..ce63681 100644 --- a/simple_ado/user.py +++ b/simple_ado/user.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/user.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/utilities.py b/simple_ado/utilities.py index f2c7b84..6c21532 100644 --- a/simple_ado/utilities.py +++ b/simple_ado/utilities.py @@ -1,9 +1,11 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/utilities.py. DO NOT EDIT. + """Utilities for dealing with the ADO REST API.""" import logging from typing import Callable -import requests +import httpx from simple_ado.exceptions import ADOHTTPException @@ -20,7 +22,7 @@ def boolstr(value: bool) -> str: def download_from_response_stream( *, - response: requests.Response, + response: httpx.Response, output_path: str, log: logging.Logger, callback: Callable[[int, int], None] | None = None, @@ -47,7 +49,7 @@ def download_from_response_stream( total_size = int(content_length_string) total_downloaded = 0 - for data in response.iter_content(chunk_size=chunk_size): + for data in response.iter_bytes(chunk_size=chunk_size): total_downloaded += len(data) output_file.write(data) diff --git a/simple_ado/wiki.py b/simple_ado/wiki.py index 137ca6c..a38cb65 100644 --- a/simple_ado/wiki.py +++ b/simple_ado/wiki.py @@ -1,3 +1,5 @@ +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/wiki.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. diff --git a/simple_ado/work_item.py b/simple_ado/work_item.py index 4645b69..21fe323 100644 --- a/simple_ado/work_item.py +++ b/simple_ado/work_item.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/work_item.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -86,7 +88,7 @@ def __getitem__(self, key: str | ADOWorkItemBuiltInFields) -> Any: if field_name in fields: return fields[field_name] - # Field not found, try refreshing from server + # Field not found — refresh from server (sync only; async must use get_field()) self._log.debug(f"Field '{field_name}' not found, refreshing work item") self.refresh() @@ -95,13 +97,58 @@ def __getitem__(self, key: str | ADOWorkItemBuiltInFields) -> Any: if field_name in fields: return fields[field_name] - # Still not found, raise KeyError + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") + + def get_field( + self, + key: str | ADOWorkItemBuiltInFields, + *, + auto_refresh: bool = True, + ) -> Any: + """Get a field value, optionally refreshing from the server if not found. + + Unlike ``__getitem__``, this method can auto-refresh the work item data + from the server when a field is missing, which is useful for fields that + weren't included in the initial response. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + :param auto_refresh: If True (default), refresh from server when field is missing + + :returns: The field value + + :raises KeyError: If the field is not found (even after refresh) + """ + field_name = key.value if isinstance(key, ADOWorkItemBuiltInFields) else key + fields = self._data.get("fields", {}) + + if field_name in fields: + return fields[field_name] + + if not auto_refresh: + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") + + self.refresh() + + fields = self._data.get("fields", {}) + if field_name in fields: + return fields[field_name] + raise KeyError(f"Field '{field_name}' not found in work item {self.id}") def __setitem__(self, key: str | ADOWorkItemBuiltInFields, value: Any) -> None: """Set a field value and patch it on the server. - This is a convenience method that calls patch() internally. + Equivalent to ``work_item.patch(key, value)``. + + :param key: The field name or ADOWorkItemBuiltInFields enum value + :param value: The new value for the field + """ + self.patch(key, value) + + def set(self, key: str | ADOWorkItemBuiltInFields, value: Any) -> None: + """Set a field value and patch it on the server. + + Convenience alias for ``work_item.patch(key, value)``. :param key: The field name or ADOWorkItemBuiltInFields enum value :param value: The new value for the field diff --git a/simple_ado/workitems.py b/simple_ado/workitems.py index 9798b6c..ce2c3c8 100755 --- a/simple_ado/workitems.py +++ b/simple_ado/workitems.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# THIS FILE IS AUTO-GENERATED FROM simple_ado/_async/workitems.py. DO NOT EDIT. + # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. @@ -8,12 +10,13 @@ import logging import os import typing -from typing import Any, cast, Iterator, List, TypeVar +from typing import Any, Iterator, List, TypeVar, cast from simple_ado.base_client import ADOBaseClient -from simple_ado.exceptions import ADOException, ADOHTTPException from simple_ado.http_client import ADOHTTPClient, ADOResponse from simple_ado.utilities import boolstr +from simple_ado.work_item import ADOWorkItem +from simple_ado.exceptions import ADOException, ADOHTTPException from simple_ado.models import ( PatchOperation, @@ -21,7 +24,21 @@ WorkItemRelationType, ADOWorkItemBuiltInFields, ) -from simple_ado.work_item import ADOWorkItem + +T = TypeVar("T") + + +# batched is only available in Python 3.12+ +def _batched(sequence: List[T], n: int) -> Iterator[List[T]]: + """Batch data into lists of length n. + + :param sequence: The iterable to batch + :param n: The batch size + + :yields: Lists of size n (or smaller for the last batch) + """ + for i in range(0, len(sequence), n): + yield sequence[i : i + n] class BatchRequest: @@ -142,21 +159,7 @@ def ilist(self, identifiers: List[int], project_id: str) -> Iterator[dict[str, A :returns: The ADO response with the data in it """ - T = TypeVar("T") - - # batched is only available in Python 3.12+ - def batched(sequence: List[T], n: int) -> Iterator[List[T]]: - """Batch data into lists of length n. - - :param sequence: The iterable to batch - :param n: The batch size - - :returns: An iterator of lists of size n - """ - for i in range(0, len(sequence), n): - yield sequence[i : i + n] - - for id_chunk in batched(identifiers, 200): + for id_chunk in _batched(identifiers, 200): ids = ",".join(map(str, id_chunk)) @@ -179,21 +182,7 @@ def list_work_items(self, identifiers: List[int], project_id: str) -> Iterator[A :returns: An iterator of ADOWorkItem objects """ - T = TypeVar("T") - - # batched is only available in Python 3.12+ - def batched(sequence: List[T], n: int) -> Iterator[List[T]]: - """Batch data into lists of length n. - - :param sequence: The iterable to batch - :param n: The batch size - - :returns: An iterator of lists of size n - """ - for i in range(0, len(sequence), n): - yield sequence[i : i + n] - - for id_chunk in batched(identifiers, 200): + for id_chunk in _batched(identifiers, 200): ids = ",".join(map(str, id_chunk)) diff --git a/tests/conftest.py b/tests/conftest.py index 280a7ea..bf273aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,12 +3,16 @@ import json import os from pathlib import Path +from collections.abc import AsyncGenerator, Generator from typing import Any, Callable, cast import pytest +import pytest_asyncio from pytest import Config from simple_ado import ADOClient from simple_ado.auth import ADOTokenAuth +from simple_ado._async import ADOAsyncClient +from simple_ado._async.auth import ADOAsyncTokenAuth # Test data directory @@ -46,9 +50,27 @@ def fixture_mock_auth() -> ADOTokenAuth: @pytest.fixture(name="mock_client") -def fixture_mock_client(mock_tenant: str, mock_auth: ADOTokenAuth) -> ADOClient: +def fixture_mock_client(mock_tenant: str, mock_auth: ADOTokenAuth) -> Generator[ADOClient]: """Return a mock ADO client.""" - return ADOClient(tenant=mock_tenant, auth=mock_auth) + client = ADOClient(tenant=mock_tenant, auth=mock_auth) + yield client + client.close() + + +@pytest.fixture(name="mock_async_auth") +def fixture_mock_async_auth() -> ADOAsyncTokenAuth: + """Return a mock async authentication object.""" + return ADOAsyncTokenAuth("mock-token-12345") + + +@pytest_asyncio.fixture(name="mock_async_client") +async def fixture_mock_async_client( + mock_tenant: str, mock_async_auth: ADOAsyncTokenAuth +) -> AsyncGenerator[ADOAsyncClient]: + """Return a mock async ADO client.""" + client = ADOAsyncClient(tenant=mock_tenant, auth=mock_async_auth) + yield client + await client.close() @pytest.fixture(name="load_fixture") diff --git a/tests/unit/_async/__init__.py b/tests/unit/_async/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/_async/test_builds.py b/tests/unit/_async/test_builds.py new file mode 100644 index 0000000..21b2d7f --- /dev/null +++ b/tests/unit/_async/test_builds.py @@ -0,0 +1,224 @@ +"""Unit tests for the async Builds client.""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnusedImport=false + +from typing import Any, Callable + +import httpx +import pytest +import respx +from simple_ado._async import ADOAsyncClient +from simple_ado._async.builds import BuildQueryOrder + +# pylint: disable=line-too-long + + +@pytest.mark.asyncio +@respx.mock +async def test_get_builds( + mock_async_client: ADOAsyncClient, + mock_project_id: str, + load_fixture: Callable[[str], dict[str, Any]], +) -> None: + """Test getting builds.""" + builds_data = load_fixture("builds_list.json") + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" + + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=builds_data) + ) + + builds = [] + async for build in mock_async_client.builds.get_builds(project_id=mock_project_id): + builds.append(build) + + assert len(builds) == 2 + assert builds[0]["id"] == 12345 + assert builds[0]["status"] == "completed" + assert route.called + assert "api-version=" in str(route.calls[0].request.url) + + +@pytest.mark.asyncio +@respx.mock +async def test_get_builds_with_definition_filter( + mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """Test getting builds filtered by definition.""" + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" + + route = respx.get(url__startswith=base_url, params__contains={"definitions": "100,101"}).mock( + return_value=httpx.Response(200, json={"value": []}) + ) + + builds = [] + async for build in mock_async_client.builds.get_builds( + project_id=mock_project_id, definitions=[100, 101] + ): + builds.append(build) + + assert not builds + assert route.called + + +@pytest.mark.asyncio +@respx.mock +async def test_get_builds_with_order( + mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """Test getting builds with specific order.""" + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" + + route = respx.get( + url__startswith=base_url, params__contains={"queryOrder": "finishTimeDescending"} + ).mock(return_value=httpx.Response(200, json={"value": []})) + + builds = [] + async for build in mock_async_client.builds.get_builds( + project_id=mock_project_id, order=BuildQueryOrder.FINISH_TIME_DESCENDING + ): + builds.append(build) + + assert not builds + assert route.called + + +@pytest.mark.asyncio +@respx.mock +async def test_build_info(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test getting build info.""" + build_id = 12345 + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}" + build_data: dict[str, Any] = { + "id": build_id, + "buildNumber": "20231001.1", + "status": "completed", + "result": "succeeded", + } + + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=build_data) + ) + + result = await mock_async_client.builds.build_info( + project_id=mock_project_id, build_id=build_id + ) + + assert result["id"] == build_id + assert result["result"] == "succeeded" + assert route.called + request_url = str(route.calls[0].request.url) + assert f"/builds/{build_id}" in request_url + assert "api-version=" in request_url + + +@pytest.mark.asyncio +@respx.mock +async def test_queue_build(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test queueing a new build.""" + definition_id = 100 + source_branch = "refs/heads/main" + variables = {"myVar": "myValue"} + + queued_build: dict[str, Any] = {"id": 99999, "buildNumber": "queued", "status": "notStarted"} + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds" + + route = respx.post(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=queued_build) + ) + + result = await mock_async_client.builds.queue_build( + project_id=mock_project_id, + definition_id=definition_id, + source_branch=source_branch, + variables=variables, + ) + + assert result["id"] == 99999 + assert result["status"] == "notStarted" + assert route.called + assert "api-version=" in str(route.calls[0].request.url) + + +@pytest.mark.asyncio +@respx.mock +async def test_list_artifacts(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test listing build artifacts.""" + build_id = 12345 + artifacts_data: dict[str, Any] = { + "value": [ + {"id": 1, "name": "drop", "resource": {"type": "Container"}}, + {"id": 2, "name": "logs", "resource": {"type": "Container"}}, + ] + } + + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}/artifacts", + ).mock(return_value=httpx.Response(200, json=artifacts_data)) + + result = await mock_async_client.builds.list_artifacts( + project_id=mock_project_id, build_id=build_id + ) + + assert len(result) == 2 + assert result[0]["name"] == "drop" + + +@pytest.mark.asyncio +@respx.mock +async def test_get_definitions(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test getting build definitions.""" + definitions_data: dict[str, Any] = { + "value": [{"id": 100, "name": "CI Pipeline"}, {"id": 101, "name": "Release Pipeline"}] + } + + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions", + ).mock(return_value=httpx.Response(200, json=definitions_data)) + + result = await mock_async_client.builds.get_definitions(project_id=mock_project_id) + + assert len(result) == 2 + assert result[0]["name"] == "CI Pipeline" + + +@pytest.mark.asyncio +@respx.mock +async def test_get_definition(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test getting a specific build definition.""" + definition_id = 100 + definition_data: dict[str, Any] = { + "id": definition_id, + "name": "CI Pipeline", + "type": "build", + "quality": "definition", + } + + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", + ).mock(return_value=httpx.Response(200, json=definition_data)) + + result = await mock_async_client.builds.get_definition( + project_id=mock_project_id, definition_id=definition_id + ) + + assert result["id"] == definition_id + assert result["name"] == "CI Pipeline" + + +@pytest.mark.asyncio +@respx.mock +async def test_delete_definition(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test deleting a build definition.""" + definition_id = 100 + + route = respx.delete( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", + ).mock(return_value=httpx.Response(204)) + + # Should not raise any exception + await mock_async_client.builds.delete_definition( + project_id=mock_project_id, definition_id=definition_id + ) + + assert route.called diff --git a/tests/unit/_async/test_client.py b/tests/unit/_async/test_client.py new file mode 100644 index 0000000..f853f70 --- /dev/null +++ b/tests/unit/_async/test_client.py @@ -0,0 +1,105 @@ +"""Unit tests for the async ADOAsyncClient class.""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnusedImport=false + +import httpx +import pytest +import respx +from simple_ado._async import ADOAsyncClient +from simple_ado._async.auth import ADOAsyncAuth, ADOAsyncTokenAuth + +# pylint: disable=line-too-long + + +@pytest.mark.asyncio +async def test_client_initialization(mock_tenant: str, mock_async_auth: ADOAsyncAuth) -> None: + """Test that async client initializes correctly.""" + async with ADOAsyncClient(tenant=mock_tenant, auth=mock_async_auth) as client: + assert client.http_client.tenant == mock_tenant + assert hasattr(client, "builds") + assert hasattr(client, "git") + assert hasattr(client, "pipelines") + assert hasattr(client, "workitems") + + +@pytest.mark.asyncio +async def test_client_has_all_sub_clients(mock_async_client: ADOAsyncClient) -> None: + """Test that async client has all expected sub-clients.""" + expected_clients = [ + "audit", + "builds", + "endpoints", + "git", + "governance", + "graph", + "identities", + "pipelines", + "pools", + "security", + "user", + "wiki", + "workitems", + ] + + for client_name in expected_clients: + assert hasattr(mock_async_client, client_name), f"Missing {client_name} client" + + +@pytest.mark.asyncio +@respx.mock +async def test_verify_access_success(mock_async_client: ADOAsyncClient) -> None: + """Test verify_access with successful response.""" + respx.get( + f"https://{mock_async_client.http_client.tenant}.visualstudio.com/_apis/projects", + ).mock(return_value=httpx.Response(200, json={"value": [], "count": 0})) + + result = await mock_async_client.verify_access() + assert result is True + + +@pytest.mark.asyncio +@respx.mock +async def test_verify_access_failure(mock_async_client: ADOAsyncClient) -> None: + """Test verify_access with failed response.""" + respx.get( + f"https://{mock_async_client.http_client.tenant}.visualstudio.com/_apis/projects", + ).mock(return_value=httpx.Response(401)) + + result = await mock_async_client.verify_access() + assert result is False + + +@pytest.mark.asyncio +async def test_auth_types() -> None: + """Test different authentication types.""" + token_auth = ADOAsyncTokenAuth("test-token") + async with ADOAsyncClient(tenant="test", auth=token_auth) as client: + assert client.http_client.auth == token_auth + + +@pytest.mark.asyncio +@respx.mock +async def test_custom_get(mock_async_client: ADOAsyncClient, mock_project_id: str) -> None: + """Test custom_get method.""" + base_url = f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/test/endpoint" + + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json={"data": "test"}) + ) + + response = await mock_async_client.custom_get( + url_fragment="test/endpoint", + parameters={"api-version": "6.0"}, + project_id=mock_project_id, + ) + + assert response.status_code == 200 + assert route.called + assert "api-version=6.0" in str(route.calls[0].request.url) + + +@pytest.mark.asyncio +async def test_context_manager(mock_tenant: str, mock_async_auth: ADOAsyncAuth) -> None: + """Test that the async client works as a context manager.""" + async with ADOAsyncClient(tenant=mock_tenant, auth=mock_async_auth) as client: + assert client.http_client.tenant == mock_tenant diff --git a/tests/unit/_async/test_http_client.py b/tests/unit/_async/test_http_client.py new file mode 100644 index 0000000..a3ddd05 --- /dev/null +++ b/tests/unit/_async/test_http_client.py @@ -0,0 +1,271 @@ +"""Unit tests for the async HTTP client error paths and rate limiting.""" + +# pylint: disable=protected-access +# pyright: reportPrivateUsage=false + +import datetime +import logging +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import patch + +import httpx +import pytest +import pytest_asyncio +import respx +from simple_ado._async.auth import ADOAsyncTokenAuth +from simple_ado._async.http_client import ADOAsyncHTTPClient, _is_retryable_get_failure +from simple_ado.exceptions import ADOException, ADOHTTPException + + +@pytest_asyncio.fixture(name="async_http_client") +async def fixture_async_http_client() -> AsyncGenerator[ADOAsyncHTTPClient]: + """Return a mock async HTTP client.""" + auth = ADOAsyncTokenAuth("mock-token") + client = ADOAsyncHTTPClient( + tenant="test-tenant", + auth=auth, + user_agent="test", + log=logging.getLogger("test"), + ) + yield client + await client.close() + + +class TestValidateResponse: + """Tests for validate_response.""" + + @pytest.mark.asyncio + async def test_success(self, async_http_client: ADOAsyncHTTPClient) -> None: + """200 response should not raise.""" + response = httpx.Response(200, json={"ok": True}) + async_http_client.validate_response(response) + + @pytest.mark.asyncio + async def test_non_200_raises(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Non-200 response should raise ADOHTTPException.""" + response = httpx.Response(404, json={"error": "not found"}) + with pytest.raises(ADOHTTPException) as exc_info: + async_http_client.validate_response(response) + assert exc_info.value.response.status_code == 404 + + @pytest.mark.asyncio + async def test_server_error_raises(self, async_http_client: ADOAsyncHTTPClient) -> None: + """500 response should raise ADOHTTPException.""" + response = httpx.Response(500, text="Internal Server Error") + with pytest.raises(ADOHTTPException): + async_http_client.validate_response(response) + + +class TestDecodeResponse: + """Tests for decode_response.""" + + @pytest.mark.asyncio + async def test_valid_json(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Valid JSON response should be decoded.""" + data = {"id": 1, "name": "test"} + response = httpx.Response(200, json=data) + result = async_http_client.decode_response(response) + assert result == data + + @pytest.mark.asyncio + async def test_non_200_raises_before_decode( + self, async_http_client: ADOAsyncHTTPClient + ) -> None: + """Non-200 response should raise before attempting decode.""" + response = httpx.Response(401, json={"error": "unauthorized"}) + with pytest.raises(ADOHTTPException): + async_http_client.decode_response(response) + + @pytest.mark.asyncio + async def test_invalid_json_raises(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Non-JSON response should raise ADOException.""" + response = httpx.Response(200, text="not json") + with pytest.raises(ADOException, match="did not contain JSON"): + async_http_client.decode_response(response) + + +class TestExtractValue: + """Tests for extract_value.""" + + @pytest.mark.asyncio + async def test_valid_value(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Response with 'value' key should extract it.""" + data: dict[str, Any] = {"value": [{"id": 1}], "count": 1} + result = async_http_client.extract_value(data) + assert result == [{"id": 1}] + + @pytest.mark.asyncio + async def test_missing_value_raises(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Response without 'value' key should raise ADOException.""" + with pytest.raises(ADOException, match="did not contain a value"): + async_http_client.extract_value({"count": 0}) + + +class TestRateLimiting: + """Tests for rate limiting logic.""" + + @pytest.mark.asyncio + async def test_track_retry_after(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Retry-After header should set _not_before.""" + response = httpx.Response(200, headers={"Retry-After": "5"}) + async_http_client._track_rate_limit(response) + assert async_http_client._not_before is not None + delta = async_http_client._not_before - datetime.datetime.now() + assert 3 < delta.total_seconds() <= 5 + + @pytest.mark.asyncio + async def test_track_retry_after_capped(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Retry-After > 15 should be capped at 15.""" + response = httpx.Response(200, headers={"Retry-After": "120"}) + async_http_client._track_rate_limit(response) + assert async_http_client._not_before is not None + delta = async_http_client._not_before - datetime.datetime.now() + assert delta.total_seconds() <= 16 + + @pytest.mark.asyncio + async def test_track_low_remaining(self, async_http_client: ADOAsyncHTTPClient) -> None: + """Low X-RateLimit-Remaining should set a 1-second delay.""" + response = httpx.Response(200, headers={"X-RateLimit-Remaining": "5"}) + async_http_client._track_rate_limit(response) + assert async_http_client._not_before is not None + + @pytest.mark.asyncio + async def test_track_high_remaining_clears(self, async_http_client: ADOAsyncHTTPClient) -> None: + """High X-RateLimit-Remaining should clear the delay.""" + async_http_client._not_before = datetime.datetime.now() + datetime.timedelta(seconds=10) + response = httpx.Response(200, headers={"X-RateLimit-Remaining": "100"}) + async_http_client._track_rate_limit(response) + assert async_http_client._not_before is None + + @pytest.mark.asyncio + async def test_wait_skips_when_no_limit(self, async_http_client: ADOAsyncHTTPClient) -> None: + """_wait should return immediately when _not_before is None.""" + async_http_client._not_before = None + await async_http_client._wait() + + @pytest.mark.asyncio + async def test_wait_clears_expired(self, async_http_client: ADOAsyncHTTPClient) -> None: + """_wait should clear _not_before when it's in the past.""" + async_http_client._not_before = datetime.datetime.now() - datetime.timedelta(seconds=1) + await async_http_client._wait() + assert async_http_client._not_before is None + + +class TestHTTPMethods: + """Tests for HTTP methods honoring rate limiting.""" + + @pytest.mark.asyncio + @respx.mock + async def test_get_non_200_raises(self, async_http_client: ADOAsyncHTTPClient) -> None: + """GET returning non-200 should raise on validate_response.""" + respx.get("https://example.com/api").mock( + return_value=httpx.Response(403, json={"error": "forbidden"}) + ) + response = await async_http_client.get("https://example.com/api") + with pytest.raises(ADOHTTPException): + async_http_client.validate_response(response) + + @pytest.mark.asyncio + @respx.mock + async def test_post_calls_wait_and_track(self, async_http_client: ADOAsyncHTTPClient) -> None: + """POST should call _wait and _track_rate_limit.""" + respx.post("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(async_http_client, "_wait") as mock_wait, patch.object( + async_http_client, "_track_rate_limit" + ) as mock_track: + await async_http_client.post("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @pytest.mark.asyncio + @respx.mock + async def test_patch_calls_wait_and_track(self, async_http_client: ADOAsyncHTTPClient) -> None: + """PATCH should call _wait and _track_rate_limit.""" + respx.patch("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(async_http_client, "_wait") as mock_wait, patch.object( + async_http_client, "_track_rate_limit" + ) as mock_track: + await async_http_client.patch("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @pytest.mark.asyncio + @respx.mock + async def test_put_calls_wait_and_track(self, async_http_client: ADOAsyncHTTPClient) -> None: + """PUT should call _wait and _track_rate_limit.""" + respx.put("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(async_http_client, "_wait") as mock_wait, patch.object( + async_http_client, "_track_rate_limit" + ) as mock_track: + await async_http_client.put("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @pytest.mark.asyncio + @respx.mock + async def test_delete_calls_wait_and_track(self, async_http_client: ADOAsyncHTTPClient) -> None: + """DELETE should call _wait and _track_rate_limit.""" + respx.delete("https://example.com/api").mock(return_value=httpx.Response(204)) + with patch.object(async_http_client, "_wait") as mock_wait, patch.object( + async_http_client, "_track_rate_limit" + ) as mock_track: + await async_http_client.delete("https://example.com/api") + mock_wait.assert_called_once() + mock_track.assert_called_once() + + +class TestRetryableStatusCodes: + """Tests for retryable status code classification.""" + + def test_408_is_retryable(self) -> None: + """408 Request Timeout should be retryable.""" + response = httpx.Response(408) + exc = ADOHTTPException("timeout", response) + assert _is_retryable_get_failure(exc) is True + + def test_429_is_retryable(self) -> None: + """429 Too Many Requests should be retryable.""" + response = httpx.Response(429) + exc = ADOHTTPException("rate limited", response) + assert _is_retryable_get_failure(exc) is True + + def test_400_is_retryable(self) -> None: + """400 Bad Request should be retryable (transient ADO errors).""" + response = httpx.Response(400) + exc = ADOHTTPException("bad request", response) + assert _is_retryable_get_failure(exc) is True + + def test_500_is_retryable(self) -> None: + """500 Internal Server Error should be retryable.""" + response = httpx.Response(500) + exc = ADOHTTPException("server error", response) + assert _is_retryable_get_failure(exc) is True + + def test_401_not_retryable(self) -> None: + """401 Unauthorized should not be retryable.""" + response = httpx.Response(401) + exc = ADOHTTPException("unauthorized", response) + assert _is_retryable_get_failure(exc) is False + + def test_403_not_retryable(self) -> None: + """403 Forbidden should not be retryable.""" + response = httpx.Response(403) + exc = ADOHTTPException("forbidden", response) + assert _is_retryable_get_failure(exc) is False + + def test_404_not_retryable(self) -> None: + """404 Not Found should not be retryable.""" + response = httpx.Response(404) + exc = ADOHTTPException("not found", response) + assert _is_retryable_get_failure(exc) is False + + def test_non_http_exception_not_retryable(self) -> None: + """Non-ADOHTTPException should not be retryable.""" + assert _is_retryable_get_failure(ValueError("boom")) is False diff --git a/tests/unit/_async/test_work_item.py b/tests/unit/_async/test_work_item.py new file mode 100644 index 0000000..c1bfadc --- /dev/null +++ b/tests/unit/_async/test_work_item.py @@ -0,0 +1,249 @@ +"""Unit tests for the async ADOAsyncWorkItem class.""" + +import copy +import logging +from typing import Any + +import httpx +import pytest +import respx +from simple_ado._async import ADOAsyncClient +from simple_ado._async.work_item import ADOAsyncWorkItem +from simple_ado._async.workitems import ADOAsyncWorkItemsClient +from simple_ado.models import ADOWorkItemBuiltInFields +from simple_ado.exceptions import ADOException + + +@pytest.fixture(name="mock_async_work_item_data") +def fixture_mock_async_work_item_data() -> dict[str, Any]: + """Return mock work item data.""" + return { + "id": 12345, + "rev": 1, + "fields": { + "System.Title": "Test Work Item", + "System.State": "Active", + "System.AssignedTo": {"displayName": "Test User"}, + "System.WorkItemType": "Bug", + }, + "url": "https://test.visualstudio.com/_apis/wit/workitems/12345", + } + + +@pytest.fixture(name="mock_async_workitems_client") +def fixture_mock_async_workitems_client( + mock_async_client: ADOAsyncClient, +) -> ADOAsyncWorkItemsClient: + """Return a mock async work items client.""" + return mock_async_client.workitems + + +@pytest.mark.asyncio +async def test_work_item_initialization( + mock_async_work_item_data: dict[str, Any], + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test that ADOAsyncWorkItem initializes correctly.""" + work_item = ADOAsyncWorkItem( + data=mock_async_work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + assert work_item.id == 12345 + assert work_item.data == mock_async_work_item_data + + +@pytest.mark.asyncio +async def test_work_item_getitem_string_key( + mock_async_work_item_data: dict[str, Any], + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test accessing fields using string keys.""" + work_item = ADOAsyncWorkItem( + data=mock_async_work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + assert work_item["System.Title"] == "Test Work Item" + assert work_item["System.State"] == "Active" + assert work_item["System.WorkItemType"] == "Bug" + + +@pytest.mark.asyncio +async def test_work_item_getitem_enum_key( + mock_async_work_item_data: dict[str, Any], + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test accessing fields using ADOWorkItemBuiltInFields enum.""" + work_item = ADOAsyncWorkItem( + data=mock_async_work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + assert work_item[ADOWorkItemBuiltInFields.TITLE] == "Test Work Item" + assert work_item[ADOWorkItemBuiltInFields.STATE] == "Active" + assert work_item[ADOWorkItemBuiltInFields.WORK_ITEM_TYPE] == "Bug" + + +@pytest.mark.asyncio +async def test_work_item_getitem_missing_field_raises( + mock_async_work_item_data: dict[str, Any], + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test that accessing a non-existent field raises KeyError.""" + work_item = ADOAsyncWorkItem( + data=mock_async_work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + _ = work_item["NonExistent.Field"] + + +@pytest.mark.asyncio +@respx.mock +async def test_work_item_refresh( + mock_async_work_item_data: dict[str, Any], + mock_async_client: ADOAsyncClient, + mock_project_id: str, +) -> None: + """Test refreshing work item data.""" + updated_data = copy.deepcopy(mock_async_work_item_data) + updated_data["fields"]["System.State"] = "Resolved" + + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=updated_data)) + + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(mock_async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + assert work_item["System.State"] == "Active" + + await work_item.refresh() + + assert work_item["System.State"] == "Resolved" + + +@pytest.mark.asyncio +@respx.mock +async def test_work_item_patch( + mock_async_work_item_data: dict[str, Any], + mock_async_client: ADOAsyncClient, + mock_project_id: str, +) -> None: + """Test patching a work item field.""" + updated_data = copy.deepcopy(mock_async_work_item_data) + updated_data["fields"]["System.State"] = "Resolved" + + respx.patch( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=updated_data)) + + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(mock_async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + assert work_item["System.State"] == "Active" + + await work_item.patch("System.State", "Resolved") + + assert work_item["System.State"] == "Resolved" + + +@pytest.mark.asyncio +@respx.mock +async def test_work_item_set( + mock_async_work_item_data: dict[str, Any], + mock_async_client: ADOAsyncClient, + mock_project_id: str, +) -> None: + """Test setting a field using the async set method.""" + updated_data = copy.deepcopy(mock_async_work_item_data) + updated_data["fields"]["System.Title"] = "New Title" + + respx.patch( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=updated_data)) + + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(mock_async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + await work_item.set("System.Title", "New Title") + + assert work_item["System.Title"] == "New Title" + + +@pytest.mark.asyncio +async def test_work_item_repr( + mock_async_work_item_data: dict[str, Any], + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test work item string representation.""" + work_item = ADOAsyncWorkItem( + data=mock_async_work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + repr_str = repr(work_item) + assert "ADOAsyncWorkItem" in repr_str + assert "12345" in repr_str + assert "Bug" in repr_str + + +@pytest.mark.asyncio +async def test_work_item_no_id_patch_raises( + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test that patching without an ID raises an exception.""" + work_item_data: dict[str, Any] = {"fields": {"System.Title": "Test"}} + work_item = ADOAsyncWorkItem( + data=work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + with pytest.raises(ADOException): + await work_item.patch("System.State", "Active") + + +@pytest.mark.asyncio +async def test_work_item_no_id_refresh_raises( + mock_async_workitems_client: ADOAsyncWorkItemsClient, +) -> None: + """Test that refreshing without an ID raises an exception.""" + work_item_data: dict[str, Any] = {"fields": {"System.Title": "Test"}} + work_item = ADOAsyncWorkItem( + data=work_item_data, + client=mock_async_workitems_client, + project_id="test-project", + log=logging.getLogger("test"), + ) + + with pytest.raises(ADOException): + await work_item.refresh() diff --git a/tests/unit/_async/test_work_item_get_field.py b/tests/unit/_async/test_work_item_get_field.py new file mode 100644 index 0000000..58153f0 --- /dev/null +++ b/tests/unit/_async/test_work_item_get_field.py @@ -0,0 +1,106 @@ +"""Unit tests for ADOAsyncWorkItem.get_field (auto-refresh).""" + +import copy +import logging +from typing import Any + +import httpx +import pytest +import respx +from simple_ado._async import ADOAsyncClient +from simple_ado._async.work_item import ADOAsyncWorkItem + + +@pytest.fixture(name="async_work_item_data") +def fixture_async_work_item_data() -> dict[str, Any]: + """Return mock work item data.""" + return { + "id": 12345, + "rev": 1, + "fields": { + "System.Title": "Test Work Item", + "System.State": "Active", + "System.WorkItemType": "Bug", + }, + "url": "https://test.visualstudio.com/_apis/wit/workitems/12345", + } + + +@pytest.mark.asyncio +@respx.mock +async def test_get_field_existing_field( + async_work_item_data: dict[str, Any], mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """get_field should return field value without refresh when field exists.""" + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + result = await work_item.get_field("System.Title") + assert result == "Test Work Item" + + +@pytest.mark.asyncio +@respx.mock +async def test_get_field_missing_triggers_refresh( + async_work_item_data: dict[str, Any], mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """get_field should refresh from server when field is missing.""" + refreshed_data = copy.deepcopy(async_work_item_data) + refreshed_data["fields"]["Custom.Field"] = "found after refresh" + + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=refreshed_data)) + + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + result = await work_item.get_field("Custom.Field") + assert result == "found after refresh" + + +@pytest.mark.asyncio +@respx.mock +async def test_get_field_missing_still_raises_after_refresh( + async_work_item_data: dict[str, Any], mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """get_field should raise KeyError if field is still missing after refresh.""" + respx.get( + url__startswith=f"https://{mock_async_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=async_work_item_data)) + + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + await work_item.get_field("NonExistent.Field") + + +@pytest.mark.asyncio +async def test_get_field_no_auto_refresh( + async_work_item_data: dict[str, Any], mock_async_client: ADOAsyncClient, mock_project_id: str +) -> None: + """get_field with auto_refresh=False should raise immediately.""" + work_item = ADOAsyncWorkItem( + data=copy.deepcopy(async_work_item_data), + client=mock_async_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + await work_item.get_field("NonExistent.Field", auto_refresh=False) diff --git a/tests/unit/test_builds.py b/tests/unit/test_builds.py index c9233da..11b5015 100644 --- a/tests/unit/test_builds.py +++ b/tests/unit/test_builds.py @@ -1,81 +1,84 @@ -"""Unit tests for the Builds client.""" +# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/test_builds.py. DO NOT EDIT. + +"""Unit tests for the async Builds client.""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnusedImport=false from typing import Any, Callable -import responses +import httpx +import respx from simple_ado import ADOClient from simple_ado.builds import BuildQueryOrder -# pylint: disable=line-too-long - -@responses.activate # pyright: ignore[reportUnknownArgumentType] +# pylint: disable=line-too-long +@respx.mock def test_get_builds( - mock_client: ADOClient, mock_project_id: str, load_fixture: Callable[[str], dict[str, Any]] + mock_client: ADOClient, + mock_project_id: str, + load_fixture: Callable[[str], dict[str, Any]], ) -> None: """Test getting builds.""" builds_data = load_fixture("builds_list.json") + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/", - json=builds_data, - status=200, + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=builds_data) ) - builds = list(mock_client.builds.get_builds(project_id=mock_project_id)) + builds = [] + for build in mock_client.builds.get_builds(project_id=mock_project_id): + builds.append(build) assert len(builds) == 2 assert builds[0]["id"] == 12345 assert builds[0]["status"] == "completed" + assert route.called + assert "api-version=" in str(route.calls[0].request.url) -@responses.activate +@respx.mock def test_get_builds_with_definition_filter(mock_client: ADOClient, mock_project_id: str) -> None: """Test getting builds filtered by definition.""" - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/", - json={"value": []}, - status=200, - match=[ - responses.matchers.query_param_matcher({"api-version": "7.1", "definitions": "100,101"}) - ], + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" + + route = respx.get(url__startswith=base_url, params__contains={"definitions": "100,101"}).mock( + return_value=httpx.Response(200, json={"value": []}) ) - builds = list(mock_client.builds.get_builds(project_id=mock_project_id, definitions=[100, 101])) + builds = [] + for build in mock_client.builds.get_builds(project_id=mock_project_id, definitions=[100, 101]): + builds.append(build) assert not builds + assert route.called -@responses.activate +@respx.mock def test_get_builds_with_order(mock_client: ADOClient, mock_project_id: str) -> None: """Test getting builds with specific order.""" - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/", - json={"value": []}, - status=200, - match=[ - responses.matchers.query_param_matcher( - {"api-version": "7.1", "queryOrder": "finishTimeDescending"} - ) - ], - ) + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/" - builds = list( - mock_client.builds.get_builds( - project_id=mock_project_id, order=BuildQueryOrder.FINISH_TIME_DESCENDING - ) - ) + route = respx.get( + url__startswith=base_url, params__contains={"queryOrder": "finishTimeDescending"} + ).mock(return_value=httpx.Response(200, json={"value": []})) + + builds = [] + for build in mock_client.builds.get_builds( + project_id=mock_project_id, order=BuildQueryOrder.FINISH_TIME_DESCENDING + ): + builds.append(build) assert not builds + assert route.called -@responses.activate +@respx.mock def test_build_info(mock_client: ADOClient, mock_project_id: str) -> None: """Test getting build info.""" build_id = 12345 + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}" build_data: dict[str, Any] = { "id": build_id, "buildNumber": "20231001.1", @@ -83,20 +86,21 @@ def test_build_info(mock_client: ADOClient, mock_project_id: str) -> None: "result": "succeeded", } - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}", - json=build_data, - status=200, + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=build_data) ) result = mock_client.builds.build_info(project_id=mock_project_id, build_id=build_id) assert result["id"] == build_id assert result["result"] == "succeeded" + assert route.called + request_url = str(route.calls[0].request.url) + assert f"/builds/{build_id}" in request_url + assert "api-version=" in request_url -@responses.activate +@respx.mock def test_queue_build(mock_client: ADOClient, mock_project_id: str) -> None: """Test queueing a new build.""" definition_id = 100 @@ -104,12 +108,10 @@ def test_queue_build(mock_client: ADOClient, mock_project_id: str) -> None: variables = {"myVar": "myValue"} queued_build: dict[str, Any] = {"id": 99999, "buildNumber": "queued", "status": "notStarted"} + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds" - responses.add( - responses.POST, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds", - json=queued_build, - status=200, + route = respx.post(url__startswith=base_url).mock( + return_value=httpx.Response(200, json=queued_build) ) result = mock_client.builds.queue_build( @@ -121,9 +123,11 @@ def test_queue_build(mock_client: ADOClient, mock_project_id: str) -> None: assert result["id"] == 99999 assert result["status"] == "notStarted" + assert route.called + assert "api-version=" in str(route.calls[0].request.url) -@responses.activate +@respx.mock def test_list_artifacts(mock_client: ADOClient, mock_project_id: str) -> None: """Test listing build artifacts.""" build_id = 12345 @@ -134,12 +138,9 @@ def test_list_artifacts(mock_client: ADOClient, mock_project_id: str) -> None: ] } - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}/artifacts", - json=artifacts_data, - status=200, - ) + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/builds/{build_id}/artifacts", + ).mock(return_value=httpx.Response(200, json=artifacts_data)) result = mock_client.builds.list_artifacts(project_id=mock_project_id, build_id=build_id) @@ -147,19 +148,16 @@ def test_list_artifacts(mock_client: ADOClient, mock_project_id: str) -> None: assert result[0]["name"] == "drop" -@responses.activate +@respx.mock def test_get_definitions(mock_client: ADOClient, mock_project_id: str) -> None: """Test getting build definitions.""" definitions_data: dict[str, Any] = { "value": [{"id": 100, "name": "CI Pipeline"}, {"id": 101, "name": "Release Pipeline"}] } - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions", - json=definitions_data, - status=200, - ) + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions", + ).mock(return_value=httpx.Response(200, json=definitions_data)) result = mock_client.builds.get_definitions(project_id=mock_project_id) @@ -167,7 +165,7 @@ def test_get_definitions(mock_client: ADOClient, mock_project_id: str) -> None: assert result[0]["name"] == "CI Pipeline" -@responses.activate +@respx.mock def test_get_definition(mock_client: ADOClient, mock_project_id: str) -> None: """Test getting a specific build definition.""" definition_id = 100 @@ -178,12 +176,9 @@ def test_get_definition(mock_client: ADOClient, mock_project_id: str) -> None: "quality": "definition", } - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", - json=definition_data, - status=200, - ) + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", + ).mock(return_value=httpx.Response(200, json=definition_data)) result = mock_client.builds.get_definition( project_id=mock_project_id, definition_id=definition_id @@ -193,19 +188,16 @@ def test_get_definition(mock_client: ADOClient, mock_project_id: str) -> None: assert result["name"] == "CI Pipeline" -@responses.activate +@respx.mock def test_delete_definition(mock_client: ADOClient, mock_project_id: str) -> None: """Test deleting a build definition.""" definition_id = 100 - responses.add( - responses.DELETE, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", - status=204, - ) + route = respx.delete( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/build/definitions/{definition_id}", + ).mock(return_value=httpx.Response(204)) # Should not raise any exception mock_client.builds.delete_definition(project_id=mock_project_id, definition_id=definition_id) - assert len(responses.calls) == 1 - assert responses.calls[0].request.method == "DELETE" # pyright: ignore[reportUnknownMemberType] + assert route.called diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ebb731f..aa3d6ea 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,25 +1,28 @@ -"""Unit tests for the main ADOClient class.""" +# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/test_client.py. DO NOT EDIT. -import responses +"""Unit tests for the async ADOClient class.""" + +# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnusedImport=false + +import httpx +import respx from simple_ado import ADOClient from simple_ado.auth import ADOAuth, ADOTokenAuth -# pylint: disable=line-too-long - +# pylint: disable=line-too-long def test_client_initialization(mock_tenant: str, mock_auth: ADOAuth) -> None: - """Test that client initializes correctly.""" - client = ADOClient(tenant=mock_tenant, auth=mock_auth) - - assert client.http_client.tenant == mock_tenant - assert hasattr(client, "builds") - assert hasattr(client, "git") - assert hasattr(client, "pipelines") - assert hasattr(client, "workitems") + """Test that async client initializes correctly.""" + with ADOClient(tenant=mock_tenant, auth=mock_auth) as client: + assert client.http_client.tenant == mock_tenant + assert hasattr(client, "builds") + assert hasattr(client, "git") + assert hasattr(client, "pipelines") + assert hasattr(client, "workitems") def test_client_has_all_sub_clients(mock_client: ADOClient) -> None: - """Test that client has all expected sub-clients.""" + """Test that async client has all expected sub-clients.""" expected_clients = [ "audit", "builds", @@ -40,28 +43,23 @@ def test_client_has_all_sub_clients(mock_client: ADOClient) -> None: assert hasattr(mock_client, client_name), f"Missing {client_name} client" -@responses.activate +@respx.mock def test_verify_access_success(mock_client: ADOClient) -> None: """Test verify_access with successful response.""" - responses.add( - responses.GET, + respx.get( f"https://{mock_client.http_client.tenant}.visualstudio.com/_apis/projects", - json={"value": [], "count": 0}, - status=200, - ) + ).mock(return_value=httpx.Response(200, json={"value": [], "count": 0})) result = mock_client.verify_access() assert result is True -@responses.activate +@respx.mock def test_verify_access_failure(mock_client: ADOClient) -> None: """Test verify_access with failed response.""" - responses.add( - responses.GET, + respx.get( f"https://{mock_client.http_client.tenant}.visualstudio.com/_apis/projects", - status=401, - ) + ).mock(return_value=httpx.Response(401)) result = mock_client.verify_access() assert result is False @@ -69,26 +67,32 @@ def test_verify_access_failure(mock_client: ADOClient) -> None: def test_auth_types() -> None: """Test different authentication types.""" - # Token auth token_auth = ADOTokenAuth("test-token") - client1 = ADOClient(tenant="test", auth=token_auth) - assert client1.http_client.auth == token_auth - - # Can add more auth types when needed + with ADOClient(tenant="test", auth=token_auth) as client: + assert client.http_client.auth == token_auth -@responses.activate +@respx.mock def test_custom_get(mock_client: ADOClient, mock_project_id: str) -> None: """Test custom_get method.""" - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/test/endpoint", - json={"data": "test"}, - status=200, + base_url = f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/{mock_project_id}/_apis/test/endpoint" + + route = respx.get(url__startswith=base_url).mock( + return_value=httpx.Response(200, json={"data": "test"}) ) response = mock_client.custom_get( - url_fragment="test/endpoint", parameters={"api-version": "6.0"}, project_id=mock_project_id + url_fragment="test/endpoint", + parameters={"api-version": "6.0"}, + project_id=mock_project_id, ) assert response.status_code == 200 + assert route.called + assert "api-version=6.0" in str(route.calls[0].request.url) + + +def test_context_manager(mock_tenant: str, mock_auth: ADOAuth) -> None: + """Test that the async client works as a context manager.""" + with ADOClient(tenant=mock_tenant, auth=mock_auth) as client: + assert client.http_client.tenant == mock_tenant diff --git a/tests/unit/test_http_client.py b/tests/unit/test_http_client.py new file mode 100644 index 0000000..a1b0cee --- /dev/null +++ b/tests/unit/test_http_client.py @@ -0,0 +1,251 @@ +# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/test_http_client.py. DO NOT EDIT. + +"""Unit tests for the async HTTP client error paths and rate limiting.""" + +# pylint: disable=protected-access +# pyright: reportPrivateUsage=false + +import datetime +import logging +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import httpx +import pytest +import respx +from simple_ado.auth import ADOTokenAuth +from simple_ado.http_client import ADOHTTPClient, _is_retryable_get_failure +from simple_ado.exceptions import ADOException, ADOHTTPException + + +@pytest.fixture(name="http_client") +def fixture_http_client() -> Generator[ADOHTTPClient]: + """Return a mock async HTTP client.""" + auth = ADOTokenAuth("mock-token") + client = ADOHTTPClient( + tenant="test-tenant", + auth=auth, + user_agent="test", + log=logging.getLogger("test"), + ) + yield client + client.close() + + +class TestValidateResponse: + """Tests for validate_response.""" + + def test_success(self, http_client: ADOHTTPClient) -> None: + """200 response should not raise.""" + response = httpx.Response(200, json={"ok": True}) + http_client.validate_response(response) + + def test_non_200_raises(self, http_client: ADOHTTPClient) -> None: + """Non-200 response should raise ADOHTTPException.""" + response = httpx.Response(404, json={"error": "not found"}) + with pytest.raises(ADOHTTPException) as exc_info: + http_client.validate_response(response) + assert exc_info.value.response.status_code == 404 + + def test_server_error_raises(self, http_client: ADOHTTPClient) -> None: + """500 response should raise ADOHTTPException.""" + response = httpx.Response(500, text="Internal Server Error") + with pytest.raises(ADOHTTPException): + http_client.validate_response(response) + + +class TestDecodeResponse: + """Tests for decode_response.""" + + def test_valid_json(self, http_client: ADOHTTPClient) -> None: + """Valid JSON response should be decoded.""" + data = {"id": 1, "name": "test"} + response = httpx.Response(200, json=data) + result = http_client.decode_response(response) + assert result == data + + def test_non_200_raises_before_decode(self, http_client: ADOHTTPClient) -> None: + """Non-200 response should raise before attempting decode.""" + response = httpx.Response(401, json={"error": "unauthorized"}) + with pytest.raises(ADOHTTPException): + http_client.decode_response(response) + + def test_invalid_json_raises(self, http_client: ADOHTTPClient) -> None: + """Non-JSON response should raise ADOException.""" + response = httpx.Response(200, text="not json") + with pytest.raises(ADOException, match="did not contain JSON"): + http_client.decode_response(response) + + +class TestExtractValue: + """Tests for extract_value.""" + + def test_valid_value(self, http_client: ADOHTTPClient) -> None: + """Response with 'value' key should extract it.""" + data: dict[str, Any] = {"value": [{"id": 1}], "count": 1} + result = http_client.extract_value(data) + assert result == [{"id": 1}] + + def test_missing_value_raises(self, http_client: ADOHTTPClient) -> None: + """Response without 'value' key should raise ADOException.""" + with pytest.raises(ADOException, match="did not contain a value"): + http_client.extract_value({"count": 0}) + + +class TestRateLimiting: + """Tests for rate limiting logic.""" + + def test_track_retry_after(self, http_client: ADOHTTPClient) -> None: + """Retry-After header should set _not_before.""" + response = httpx.Response(200, headers={"Retry-After": "5"}) + http_client._track_rate_limit(response) + assert http_client._not_before is not None + delta = http_client._not_before - datetime.datetime.now() + assert 3 < delta.total_seconds() <= 5 + + def test_track_retry_after_capped(self, http_client: ADOHTTPClient) -> None: + """Retry-After > 15 should be capped at 15.""" + response = httpx.Response(200, headers={"Retry-After": "120"}) + http_client._track_rate_limit(response) + assert http_client._not_before is not None + delta = http_client._not_before - datetime.datetime.now() + assert delta.total_seconds() <= 16 + + def test_track_low_remaining(self, http_client: ADOHTTPClient) -> None: + """Low X-RateLimit-Remaining should set a 1-second delay.""" + response = httpx.Response(200, headers={"X-RateLimit-Remaining": "5"}) + http_client._track_rate_limit(response) + assert http_client._not_before is not None + + def test_track_high_remaining_clears(self, http_client: ADOHTTPClient) -> None: + """High X-RateLimit-Remaining should clear the delay.""" + http_client._not_before = datetime.datetime.now() + datetime.timedelta(seconds=10) + response = httpx.Response(200, headers={"X-RateLimit-Remaining": "100"}) + http_client._track_rate_limit(response) + assert http_client._not_before is None + + def test_wait_skips_when_no_limit(self, http_client: ADOHTTPClient) -> None: + """_wait should return immediately when _not_before is None.""" + http_client._not_before = None + http_client._wait() + + def test_wait_clears_expired(self, http_client: ADOHTTPClient) -> None: + """_wait should clear _not_before when it's in the past.""" + http_client._not_before = datetime.datetime.now() - datetime.timedelta(seconds=1) + http_client._wait() + assert http_client._not_before is None + + +class TestHTTPMethods: + """Tests for HTTP methods honoring rate limiting.""" + + @respx.mock + def test_get_non_200_raises(self, http_client: ADOHTTPClient) -> None: + """GET returning non-200 should raise on validate_response.""" + respx.get("https://example.com/api").mock( + return_value=httpx.Response(403, json={"error": "forbidden"}) + ) + response = http_client.get("https://example.com/api") + with pytest.raises(ADOHTTPException): + http_client.validate_response(response) + + @respx.mock + def test_post_calls_wait_and_track(self, http_client: ADOHTTPClient) -> None: + """POST should call _wait and _track_rate_limit.""" + respx.post("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(http_client, "_wait") as mock_wait, patch.object( + http_client, "_track_rate_limit" + ) as mock_track: + http_client.post("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @respx.mock + def test_patch_calls_wait_and_track(self, http_client: ADOHTTPClient) -> None: + """PATCH should call _wait and _track_rate_limit.""" + respx.patch("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(http_client, "_wait") as mock_wait, patch.object( + http_client, "_track_rate_limit" + ) as mock_track: + http_client.patch("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @respx.mock + def test_put_calls_wait_and_track(self, http_client: ADOHTTPClient) -> None: + """PUT should call _wait and _track_rate_limit.""" + respx.put("https://example.com/api").mock( + return_value=httpx.Response(200, json={"ok": True}) + ) + with patch.object(http_client, "_wait") as mock_wait, patch.object( + http_client, "_track_rate_limit" + ) as mock_track: + http_client.put("https://example.com/api", json_data={"key": "value"}) + mock_wait.assert_called_once() + mock_track.assert_called_once() + + @respx.mock + def test_delete_calls_wait_and_track(self, http_client: ADOHTTPClient) -> None: + """DELETE should call _wait and _track_rate_limit.""" + respx.delete("https://example.com/api").mock(return_value=httpx.Response(204)) + with patch.object(http_client, "_wait") as mock_wait, patch.object( + http_client, "_track_rate_limit" + ) as mock_track: + http_client.delete("https://example.com/api") + mock_wait.assert_called_once() + mock_track.assert_called_once() + + +class TestRetryableStatusCodes: + """Tests for retryable status code classification.""" + + def test_408_is_retryable(self) -> None: + """408 Request Timeout should be retryable.""" + response = httpx.Response(408) + exc = ADOHTTPException("timeout", response) + assert _is_retryable_get_failure(exc) is True + + def test_429_is_retryable(self) -> None: + """429 Too Many Requests should be retryable.""" + response = httpx.Response(429) + exc = ADOHTTPException("rate limited", response) + assert _is_retryable_get_failure(exc) is True + + def test_400_is_retryable(self) -> None: + """400 Bad Request should be retryable (transient ADO errors).""" + response = httpx.Response(400) + exc = ADOHTTPException("bad request", response) + assert _is_retryable_get_failure(exc) is True + + def test_500_is_retryable(self) -> None: + """500 Internal Server Error should be retryable.""" + response = httpx.Response(500) + exc = ADOHTTPException("server error", response) + assert _is_retryable_get_failure(exc) is True + + def test_401_not_retryable(self) -> None: + """401 Unauthorized should not be retryable.""" + response = httpx.Response(401) + exc = ADOHTTPException("unauthorized", response) + assert _is_retryable_get_failure(exc) is False + + def test_403_not_retryable(self) -> None: + """403 Forbidden should not be retryable.""" + response = httpx.Response(403) + exc = ADOHTTPException("forbidden", response) + assert _is_retryable_get_failure(exc) is False + + def test_404_not_retryable(self) -> None: + """404 Not Found should not be retryable.""" + response = httpx.Response(404) + exc = ADOHTTPException("not found", response) + assert _is_retryable_get_failure(exc) is False + + def test_non_http_exception_not_retryable(self) -> None: + """Non-ADOHTTPException should not be retryable.""" + assert _is_retryable_get_failure(ValueError("boom")) is False diff --git a/tests/unit/test_work_item.py b/tests/unit/test_work_item.py index f39ed2e..8b38001 100644 --- a/tests/unit/test_work_item.py +++ b/tests/unit/test_work_item.py @@ -1,12 +1,14 @@ -"""Unit tests for ADOWorkItem class.""" +# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/test_work_item.py. DO NOT EDIT. + +"""Unit tests for the async ADOWorkItem class.""" import copy import logging from typing import Any -from unittest.mock import MagicMock +import httpx import pytest -import responses +import respx from simple_ado import ADOClient from simple_ado.work_item import ADOWorkItem from simple_ado.workitems import ADOWorkItemsClient @@ -31,13 +33,16 @@ def fixture_mock_work_item_data() -> dict[str, Any]: @pytest.fixture(name="mock_workitems_client") -def fixture_mock_workitems_client(mock_client: ADOClient) -> ADOWorkItemsClient: - """Return a mock work items client.""" +def fixture_mock_workitems_client( + mock_client: ADOClient, +) -> ADOWorkItemsClient: + """Return a mock async work items client.""" return mock_client.workitems def test_work_item_initialization( - mock_work_item_data: dict[str, Any], mock_workitems_client: ADOWorkItemsClient + mock_work_item_data: dict[str, Any], + mock_workitems_client: ADOWorkItemsClient, ) -> None: """Test that ADOWorkItem initializes correctly.""" work_item = ADOWorkItem( @@ -52,7 +57,8 @@ def test_work_item_initialization( def test_work_item_getitem_string_key( - mock_work_item_data: dict[str, Any], mock_workitems_client: ADOWorkItemsClient + mock_work_item_data: dict[str, Any], + mock_workitems_client: ADOWorkItemsClient, ) -> None: """Test accessing fields using string keys.""" work_item = ADOWorkItem( @@ -68,7 +74,8 @@ def test_work_item_getitem_string_key( def test_work_item_getitem_enum_key( - mock_work_item_data: dict[str, Any], mock_workitems_client: ADOWorkItemsClient + mock_work_item_data: dict[str, Any], + mock_workitems_client: ADOWorkItemsClient, ) -> None: """Test accessing fields using ADOWorkItemBuiltInFields enum.""" work_item = ADOWorkItem( @@ -83,43 +90,64 @@ def test_work_item_getitem_enum_key( assert work_item[ADOWorkItemBuiltInFields.WORK_ITEM_TYPE] == "Bug" -def test_work_item_getitem_missing_field_raises( - mock_work_item_data: dict[str, Any], mock_workitems_client: ADOWorkItemsClient +@respx.mock +def test_work_item_getitem_missing_field_refreshes( + mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str ) -> None: - """Test that accessing a non-existent field raises KeyError after refresh.""" + """Test that accessing a missing field auto-refreshes and returns the value.""" + refreshed_data = copy.deepcopy(mock_work_item_data) + refreshed_data["fields"]["System.Reason"] = "Fixed" + + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=refreshed_data)) + work_item = ADOWorkItem( - data=mock_work_item_data, - client=mock_workitems_client, - project_id="test-project", + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, log=logging.getLogger("test"), ) - # Mock the client's get method to return the same data (field still missing) - mock_workitems_client.get = MagicMock(return_value=mock_work_item_data) # type: ignore + assert work_item["System.Reason"] == "Fixed" + + +@respx.mock +def test_work_item_getitem_missing_field_raises_after_refresh( + mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """Test that accessing a non-existent field raises KeyError after refresh.""" + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=mock_work_item_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(mock_work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) with pytest.raises(KeyError): _ = work_item["NonExistent.Field"] - # Verify refresh was attempted - mock_workitems_client.get.assert_called_once_with("12345", "test-project") - -@responses.activate +@respx.mock def test_work_item_refresh( - mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str + mock_work_item_data: dict[str, Any], + mock_client: ADOClient, + mock_project_id: str, ) -> None: """Test refreshing work item data.""" - # Mock the API response updated_data = copy.deepcopy(mock_work_item_data) updated_data["fields"]["System.State"] = "Resolved" - responses.add( - responses.GET, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + f"{mock_project_id}/_apis/wit/workitems/12345", - json=updated_data, - status=200, - ) + ).mock(return_value=httpx.Response(200, json=updated_data)) work_item = ADOWorkItem( data=copy.deepcopy(mock_work_item_data), @@ -128,32 +156,27 @@ def test_work_item_refresh( log=logging.getLogger("test"), ) - # Initially Active assert work_item["System.State"] == "Active" - # Refresh work_item.refresh() - # Now should be Resolved assert work_item["System.State"] == "Resolved" -@responses.activate +@respx.mock def test_work_item_patch( - mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str + mock_work_item_data: dict[str, Any], + mock_client: ADOClient, + mock_project_id: str, ) -> None: """Test patching a work item field.""" - # Mock the API response updated_data = copy.deepcopy(mock_work_item_data) updated_data["fields"]["System.State"] = "Resolved" - responses.add( - responses.PATCH, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + respx.patch( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + f"{mock_project_id}/_apis/wit/workitems/12345", - json=updated_data, - status=200, - ) + ).mock(return_value=httpx.Response(200, json=updated_data)) work_item = ADOWorkItem( data=copy.deepcopy(mock_work_item_data), @@ -162,32 +185,27 @@ def test_work_item_patch( log=logging.getLogger("test"), ) - # Initially Active assert work_item["System.State"] == "Active" - # Patch the state work_item.patch("System.State", "Resolved") - # Should be updated assert work_item["System.State"] == "Resolved" -@responses.activate +@respx.mock def test_work_item_setitem( - mock_work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str + mock_work_item_data: dict[str, Any], + mock_client: ADOClient, + mock_project_id: str, ) -> None: """Test setting a field using setitem.""" - # Mock the API response updated_data = copy.deepcopy(mock_work_item_data) updated_data["fields"]["System.Title"] = "New Title" - responses.add( - responses.PATCH, - f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + respx.patch( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + f"{mock_project_id}/_apis/wit/workitems/12345", - json=updated_data, - status=200, - ) + ).mock(return_value=httpx.Response(200, json=updated_data)) work_item = ADOWorkItem( data=copy.deepcopy(mock_work_item_data), @@ -196,15 +214,14 @@ def test_work_item_setitem( log=logging.getLogger("test"), ) - # Set using item assignment work_item["System.Title"] = "New Title" - # Should be updated assert work_item["System.Title"] == "New Title" def test_work_item_repr( - mock_work_item_data: dict[str, Any], mock_workitems_client: ADOWorkItemsClient + mock_work_item_data: dict[str, Any], + mock_workitems_client: ADOWorkItemsClient, ) -> None: """Test work item string representation.""" work_item = ADOWorkItem( @@ -220,9 +237,11 @@ def test_work_item_repr( assert "Bug" in repr_str -def test_work_item_no_id_patch_raises(mock_workitems_client: ADOWorkItemsClient) -> None: +def test_work_item_no_id_patch_raises( + mock_workitems_client: ADOWorkItemsClient, +) -> None: """Test that patching without an ID raises an exception.""" - work_item_data = {"fields": {"System.Title": "Test"}} + work_item_data: dict[str, Any] = {"fields": {"System.Title": "Test"}} work_item = ADOWorkItem( data=work_item_data, client=mock_workitems_client, @@ -234,9 +253,11 @@ def test_work_item_no_id_patch_raises(mock_workitems_client: ADOWorkItemsClient) work_item.patch("System.State", "Active") -def test_work_item_no_id_refresh_raises(mock_workitems_client: ADOWorkItemsClient) -> None: +def test_work_item_no_id_refresh_raises( + mock_workitems_client: ADOWorkItemsClient, +) -> None: """Test that refreshing without an ID raises an exception.""" - work_item_data = {"fields": {"System.Title": "Test"}} + work_item_data: dict[str, Any] = {"fields": {"System.Title": "Test"}} work_item = ADOWorkItem( data=work_item_data, client=mock_workitems_client, diff --git a/tests/unit/test_work_item_get_field.py b/tests/unit/test_work_item_get_field.py new file mode 100644 index 0000000..be2b2af --- /dev/null +++ b/tests/unit/test_work_item_get_field.py @@ -0,0 +1,104 @@ +# THIS FILE IS AUTO-GENERATED FROM tests/unit/_async/test_work_item_get_field.py. DO NOT EDIT. + +"""Unit tests for ADOWorkItem.get_field (auto-refresh).""" + +import copy +import logging +from typing import Any + +import httpx +import pytest +import respx +from simple_ado import ADOClient +from simple_ado.work_item import ADOWorkItem + + +@pytest.fixture(name="work_item_data") +def fixture_work_item_data() -> dict[str, Any]: + """Return mock work item data.""" + return { + "id": 12345, + "rev": 1, + "fields": { + "System.Title": "Test Work Item", + "System.State": "Active", + "System.WorkItemType": "Bug", + }, + "url": "https://test.visualstudio.com/_apis/wit/workitems/12345", + } + + +@respx.mock +def test_get_field_existing_field( + work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """get_field should return field value without refresh when field exists.""" + work_item = ADOWorkItem( + data=copy.deepcopy(work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + result = work_item.get_field("System.Title") + assert result == "Test Work Item" + + +@respx.mock +def test_get_field_missing_triggers_refresh( + work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """get_field should refresh from server when field is missing.""" + refreshed_data = copy.deepcopy(work_item_data) + refreshed_data["fields"]["Custom.Field"] = "found after refresh" + + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=refreshed_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + result = work_item.get_field("Custom.Field") + assert result == "found after refresh" + + +@respx.mock +def test_get_field_missing_still_raises_after_refresh( + work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """get_field should raise KeyError if field is still missing after refresh.""" + respx.get( + url__startswith=f"https://{mock_client.http_client.tenant}.visualstudio.com/DefaultCollection/" + + f"{mock_project_id}/_apis/wit/workitems/12345", + ).mock(return_value=httpx.Response(200, json=work_item_data)) + + work_item = ADOWorkItem( + data=copy.deepcopy(work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + work_item.get_field("NonExistent.Field") + + +def test_get_field_no_auto_refresh( + work_item_data: dict[str, Any], mock_client: ADOClient, mock_project_id: str +) -> None: + """get_field with auto_refresh=False should raise immediately.""" + work_item = ADOWorkItem( + data=copy.deepcopy(work_item_data), + client=mock_client.workitems, + project_id=mock_project_id, + log=logging.getLogger("test"), + ) + + with pytest.raises(KeyError): + work_item.get_field("NonExistent.Field", auto_refresh=False)