From cbccf1edfd0d15782e5196850f839700a286466a Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Mon, 18 May 2026 20:24:46 -0700 Subject: [PATCH 1/4] feat: add managed database helpers to HotdataClient Expose create/list/load/delete managed database operations through the runtime contract, mirroring hotdata-cli and the latest SDK managed-table APIs. --- CONTRACT.md | 17 +++ README.md | 1 + hotdata_runtime/__init__.py | 18 +++ hotdata_runtime/client.py | 147 ++++++++++++++++++++++++ hotdata_runtime/databases.py | 95 ++++++++++++++++ pyproject.toml | 8 +- tests/test_contract.py | 8 ++ tests/test_databases.py | 209 +++++++++++++++++++++++++++++++++++ uv.lock | 16 ++- 9 files changed, 513 insertions(+), 6 deletions(-) create mode 100644 hotdata_runtime/databases.py create mode 100644 tests/test_databases.py diff --git a/CONTRACT.md b/CONTRACT.md index ae60b53..cca5d1f 100644 --- a/CONTRACT.md +++ b/CONTRACT.md @@ -30,6 +30,14 @@ The supported import surface is: - `ResultSummary` - `RunHistoryItem` - `WorkspaceSelection` +- `ManagedDatabase` +- `ManagedTable` +- `LoadManagedTableResult` +- `MANAGED_SOURCE_TYPE` +- `DEFAULT_SCHEMA` +- `build_managed_config` +- `create_connection_request` +- `is_parquet_path` Adapters should import from `hotdata_runtime` and treat this surface as the stable API. @@ -49,6 +57,15 @@ Adapters should import from `hotdata_runtime` and treat this surface as the stab - `list_qualified_table_names(...)` returns sorted fully qualified table names. - `columns_for_qualified(qualified, connection_id=...)` resolves table columns, and adapters should pass `connection_id` when known. +- `uploads()` returns the uploads API wrapper for parquet staging. +- `list_managed_databases()` returns managed-catalog connections (`source_type: managed`). +- `resolve_managed_database(name_or_id)` resolves a managed database by name or id. +- `create_managed_database(name, schema=..., tables=...)` creates a managed database and optionally declares tables up front. +- `delete_managed_database(name_or_id)` deletes a managed database connection. +- `list_managed_tables(database, schema=...)` lists tables in a managed database. +- `upload_parquet(path)` uploads a local parquet file and returns an upload id. +- `load_managed_table(database, table, schema=..., upload_id=..., file=...)` publishes parquet data into a declared managed table. +- `delete_managed_table(database, table, schema=...)` deletes a managed table. ### `QueryResult` diff --git a/README.md b/README.md index f4e41ad..71ffc84 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ Runtime boundary and guarantees are defined in `CONTRACT.md`. - **SQL execution helper** — run SQL through `POST /v1/query`, poll async query runs when needed, and return a `QueryResult`. - **Result utilities** — convert query results to records, pandas DataFrames, or metadata dictionaries for adapter display layers. - **History helpers** — list recent results and query run history with normalized dataclasses. +- **Managed databases** — create Hotdata-owned catalogs, declare tables, upload parquet, and load managed tables (mirrors `hotdata databases` in the CLI). - **Health helpers** — build compact API/workspace health summaries for UI integrations. Install: diff --git a/hotdata_runtime/__init__.py b/hotdata_runtime/__init__.py index ef6bd9c..b9d3118 100644 --- a/hotdata_runtime/__init__.py +++ b/hotdata_runtime/__init__.py @@ -8,6 +8,16 @@ RunHistoryItem, from_env, ) +from hotdata_runtime.databases import ( + DEFAULT_SCHEMA, + LoadManagedTableResult, + ManagedDatabase, + ManagedTable, + MANAGED_SOURCE_TYPE, + build_managed_config, + create_connection_request, + is_parquet_path, +) from hotdata_runtime.env import ( default_api_key, default_host, @@ -29,8 +39,16 @@ __all__ = [ "__version__", + "DEFAULT_SCHEMA", "HotdataClient", + "LoadManagedTableResult", + "MANAGED_SOURCE_TYPE", + "ManagedDatabase", + "ManagedTable", "QueryResult", + "build_managed_config", + "create_connection_request", + "is_parquet_path", "workspace_health_lines", "default_api_key", "default_host", diff --git a/hotdata_runtime/client.py b/hotdata_runtime/client.py index c648f78..1d2a7b8 100644 --- a/hotdata_runtime/client.py +++ b/hotdata_runtime/client.py @@ -13,10 +13,12 @@ from hotdata.api.query_api import QueryApi from hotdata.api.query_runs_api import QueryRunsApi from hotdata.api.results_api import ResultsApi +from hotdata.api.uploads_api import UploadsApi from hotdata.exceptions import ApiException from hotdata.models.async_query_response import AsyncQueryResponse from hotdata.models.query_request import QueryRequest from hotdata.models.query_response import QueryResponse +from hotdata.models.load_managed_table_request import LoadManagedTableRequest from hotdata.models.table_info import TableInfo from hotdata_runtime.env import ( @@ -26,6 +28,17 @@ normalize_host, pick_workspace, ) +from hotdata_runtime.databases import ( + DEFAULT_SCHEMA, + LoadManagedTableResult, + ManagedDatabase, + ManagedTable, + MANAGED_SOURCE_TYPE, + _api_error, + _managed_database, + create_connection_request, + is_parquet_path, +) from hotdata_runtime.http import default_http_retries from hotdata_runtime.result import QueryResult @@ -135,6 +148,140 @@ def query_runs(self) -> QueryRunsApi: def results(self) -> ResultsApi: return self._results_api() + def uploads(self) -> UploadsApi: + return UploadsApi(self._api) + + def list_managed_databases(self) -> list[ManagedDatabase]: + listing = self.connections().list_connections() + return [ + _managed_database(c) + for c in listing.connections + if c.source_type == MANAGED_SOURCE_TYPE + ] + + def resolve_managed_database(self, name_or_id: str) -> ManagedDatabase: + listing = self.connections().list_connections() + match = None + for c in listing.connections: + if c.id == name_or_id or c.name == name_or_id: + match = c + break + if match is None: + raise KeyError(f"No database named or with id {name_or_id!r}") + if match.source_type != MANAGED_SOURCE_TYPE: + raise ValueError( + f"{match.name!r} is not a managed database " + f"(source_type: {match.source_type})" + ) + return _managed_database(match) + + def create_managed_database( + self, + name: str, + *, + schema: str = DEFAULT_SCHEMA, + tables: list[str] | None = None, + ) -> ManagedDatabase: + request = create_connection_request(name, schema=schema, tables=tables) + try: + created = self.connections().create_connection(request) + except ApiException as e: + raise RuntimeError(_api_error(e)) from e + return _managed_database(created) + + def delete_managed_database(self, name_or_id: str) -> None: + db = self.resolve_managed_database(name_or_id) + try: + self.connections().delete_connection(db.id) + except ApiException as e: + raise RuntimeError(_api_error(e)) from e + + def list_managed_tables( + self, + database: str, + *, + schema: str | None = None, + ) -> list[ManagedTable]: + db = self.resolve_managed_database(database) + rows: list[ManagedTable] = [] + for t in self.iter_tables(connection_id=db.id): + if schema is not None and t.var_schema != schema: + continue + rows.append( + ManagedTable( + full_name=f"{db.name}.{t.var_schema}.{t.table}", + schema=t.var_schema, + table=t.table, + synced=t.synced, + last_sync=t.last_sync, + ) + ) + rows.sort(key=lambda row: (row.schema, row.table)) + return rows + + def upload_parquet(self, path: str) -> str: + if not is_parquet_path(path): + raise ValueError( + f"Managed table loads require a parquet file (got {path!r})" + ) + with open(path, "rb") as f: + data = f.read() + try: + uploaded = self.uploads().upload_file( + data, + _content_type="application/octet-stream", + ) + except ApiException as e: + raise RuntimeError(_api_error(e)) from e + return uploaded.id + + def load_managed_table( + self, + database: str, + table: str, + *, + schema: str = DEFAULT_SCHEMA, + upload_id: str | None = None, + file: str | None = None, + ) -> LoadManagedTableResult: + if (upload_id is None) == (file is None): + raise ValueError("Exactly one of upload_id or file is required") + db = self.resolve_managed_database(database) + resolved_upload_id = upload_id or self.upload_parquet(file or "") + request = LoadManagedTableRequest( + mode="replace", + upload_id=resolved_upload_id, + ) + try: + loaded = self.connections().load_managed_table( + db.id, + schema, + table, + request, + ) + except ApiException as e: + raise RuntimeError(_api_error(e)) from e + return LoadManagedTableResult( + connection_id=loaded.connection_id, + schema_name=loaded.schema_name, + table_name=loaded.table_name, + row_count=loaded.row_count, + full_name=f"{db.name}.{loaded.schema_name}.{loaded.table_name}", + ) + + def delete_managed_table( + self, + database: str, + table: str, + *, + schema: str = DEFAULT_SCHEMA, + ) -> None: + db = self.resolve_managed_database(database) + try: + self.connections().delete_managed_table(db.id, schema, table) + except ApiException as e: + raise RuntimeError(_api_error(e)) from e + def list_recent_results( self, *, diff --git a/hotdata_runtime/databases.py b/hotdata_runtime/databases.py new file mode 100644 index 0000000..e141bb6 --- /dev/null +++ b/hotdata_runtime/databases.py @@ -0,0 +1,95 @@ +"""Managed database helpers (Hotdata-owned catalogs with parquet table loads).""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +from hotdata.exceptions import ApiException +from hotdata.models.create_connection_request import CreateConnectionRequest +from hotdata.models.load_managed_table_request import LoadManagedTableRequest + +MANAGED_SOURCE_TYPE = "managed" +DEFAULT_SCHEMA = "public" + + +@dataclass(frozen=True) +class ManagedDatabase: + id: str + name: str + source_type: str + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass(frozen=True) +class ManagedTable: + full_name: str + schema: str + table: str + synced: bool + last_sync: str | None + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass(frozen=True) +class LoadManagedTableResult: + connection_id: str + schema_name: str + table_name: str + row_count: int + full_name: str + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +def is_parquet_path(path: str) -> bool: + lowered = path.lower() + if lowered.endswith(".parquet"): + return True + return Path(path).suffix.lower() == ".parquet" + + +def build_managed_config(schema: str, tables: list[str]) -> dict[str, Any]: + if not tables: + return {} + return { + "schemas": [ + { + "name": schema, + "tables": [{"name": table} for table in tables], + } + ] + } + + +def create_connection_request( + name: str, + *, + schema: str = DEFAULT_SCHEMA, + tables: list[str] | None = None, +) -> CreateConnectionRequest: + table_list = tables or [] + return CreateConnectionRequest( + name=name, + source_type=MANAGED_SOURCE_TYPE, + config=build_managed_config(schema, table_list), + skip_discovery=True, + ) + + +def _managed_database(conn: Any) -> ManagedDatabase: + return ManagedDatabase( + id=str(conn.id), + name=str(conn.name), + source_type=str(conn.source_type), + ) + + +def _api_error(exc: ApiException) -> str: + return exc.reason or str(exc) diff --git a/pyproject.toml b/pyproject.toml index 71a04fd..5e63b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,9 @@ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [project] name = "hotdata-runtime" version = "0.1.0" @@ -10,7 +13,7 @@ readme = "README.md" requires-python = ">=3.10" license = { text = "MIT" } dependencies = [ - "hotdata>=0.1.0", + "hotdata @ git+https://github.com/hotdata-dev/sdk-python.git", "pandas>=2.0", ] @@ -23,6 +26,9 @@ dev = [ [tool.uv] default-groups = ["dev"] +[tool.uv.sources] +hotdata = { path = "../sdk-python", editable = true } + [tool.hatch.build.targets.wheel] packages = ["hotdata_runtime"] diff --git a/tests/test_contract.py b/tests/test_contract.py index edfd55c..f324864 100644 --- a/tests/test_contract.py +++ b/tests/test_contract.py @@ -11,8 +11,16 @@ def test_public_exports_contract(): assert hr.__all__ == [ "__version__", + "DEFAULT_SCHEMA", "HotdataClient", + "LoadManagedTableResult", + "MANAGED_SOURCE_TYPE", + "ManagedDatabase", + "ManagedTable", "QueryResult", + "build_managed_config", + "create_connection_request", + "is_parquet_path", "workspace_health_lines", "default_api_key", "default_host", diff --git a/tests/test_databases.py b/tests/test_databases.py new file mode 100644 index 0000000..a9a53a3 --- /dev/null +++ b/tests/test_databases.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from hotdata.exceptions import ApiException +from hotdata_runtime.client import HotdataClient +from hotdata_runtime.databases import ( + build_managed_config, + create_connection_request, + is_parquet_path, +) + + +def _client() -> HotdataClient: + return HotdataClient("k", "ws", host="https://api.hotdata.dev") + + +def test_build_managed_config_empty_without_tables(): + assert build_managed_config("public", []) == {} + + +def test_build_managed_config_declares_tables(): + cfg = build_managed_config("public", ["orders", "customers"]) + assert cfg == { + "schemas": [ + { + "name": "public", + "tables": [{"name": "orders"}, {"name": "customers"}], + } + ] + } + + +def test_create_connection_request_uses_managed_source_type(): + req = create_connection_request("sales", schema="public", tables=["orders"]) + assert req.name == "sales" + assert req.source_type == "managed" + assert req.skip_discovery is True + assert req.config["schemas"][0]["tables"][0]["name"] == "orders" + + +@pytest.mark.parametrize( + ("path", "expected"), + [ + ("/data/orders.parquet", True), + ("/data/ORDERS.PARQUET", True), + ("/data/orders.csv", False), + ], +) +def test_is_parquet_path(path: str, expected: bool): + assert is_parquet_path(path) is expected + + +def test_list_managed_databases_filters_managed_only(): + client = _client() + listing = SimpleNamespace( + connections=[ + SimpleNamespace(id="c1", name="sales", source_type="managed"), + SimpleNamespace(id="c2", name="warehouse", source_type="postgres"), + ] + ) + with patch.object(client, "connections") as connections: + connections.return_value.list_connections.return_value = listing + dbs = client.list_managed_databases() + assert [db.name for db in dbs] == ["sales"] + + +def test_resolve_managed_database_by_name_and_id(): + client = _client() + listing = SimpleNamespace( + connections=[ + SimpleNamespace(id="conn_abc", name="sales", source_type="managed"), + ] + ) + with patch.object(client, "connections") as connections: + connections.return_value.list_connections.return_value = listing + by_name = client.resolve_managed_database("sales") + by_id = client.resolve_managed_database("conn_abc") + assert by_name.id == "conn_abc" + assert by_id.name == "sales" + + +def test_resolve_managed_database_rejects_non_managed(): + client = _client() + listing = SimpleNamespace( + connections=[ + SimpleNamespace(id="c1", name="warehouse", source_type="postgres"), + ] + ) + with patch.object(client, "connections") as connections: + connections.return_value.list_connections.return_value = listing + with pytest.raises(ValueError, match="not a managed database"): + client.resolve_managed_database("warehouse") + + +def test_create_managed_database_returns_summary(): + client = _client() + created = SimpleNamespace(id="conn_new", name="mydb", source_type="managed") + with patch.object(client, "connections") as connections: + connections.return_value.create_connection.return_value = created + db = client.create_managed_database("mydb", tables=["orders"]) + assert db.id == "conn_new" + assert db.name == "mydb" + req = connections.return_value.create_connection.call_args.args[0] + assert req.config["schemas"][0]["tables"][0]["name"] == "orders" + + +def test_create_managed_database_wraps_api_errors(): + client = _client() + with patch.object(client, "connections") as connections: + connections.return_value.create_connection.side_effect = ApiException( + status=400, + reason="bad request", + ) + with pytest.raises(RuntimeError, match="bad request"): + client.create_managed_database("mydb") + + +def test_list_managed_tables_builds_full_names(): + client = _client() + listing = SimpleNamespace( + connections=[ + SimpleNamespace(id="conn1", name="sales", source_type="managed"), + ] + ) + table = SimpleNamespace( + connection="sales", + var_schema="public", + table="orders", + synced=True, + last_sync="2026-05-19T00:00:00Z", + ) + with patch.object(client, "connections") as connections, patch.object( + client, "iter_tables", return_value=[table] + ): + connections.return_value.list_connections.return_value = listing + rows = client.list_managed_tables("sales") + assert len(rows) == 1 + assert rows[0].full_name == "sales.public.orders" + assert rows[0].synced is True + + +def test_upload_parquet_rejects_non_parquet(): + client = _client() + with pytest.raises(ValueError, match="parquet"): + client.upload_parquet("/tmp/data.csv") + + +def test_upload_parquet_returns_upload_id(): + client = _client() + uploaded = SimpleNamespace(id="upl_123") + with patch("builtins.open", mock_open(read_data=b"PAR1")), patch.object( + client, "uploads" + ) as uploads: + uploads.return_value.upload_file.return_value = uploaded + upload_id = client.upload_parquet("/tmp/data.parquet") + assert upload_id == "upl_123" + + +def test_load_managed_table_with_upload_id(): + client = _client() + db = SimpleNamespace(id="conn1", name="sales", source_type="managed") + loaded = SimpleNamespace( + connection_id="conn1", + schema_name="public", + table_name="orders", + row_count=42, + ) + with patch.object(client, "resolve_managed_database", return_value=db), patch.object( + client, "connections" + ) as connections: + connections.return_value.load_managed_table.return_value = loaded + result = client.load_managed_table( + "sales", + "orders", + upload_id="upl_123", + ) + assert result.row_count == 42 + assert result.full_name == "sales.public.orders" + + +def test_load_managed_table_requires_exactly_one_source(): + client = _client() + with pytest.raises(ValueError, match="Exactly one"): + client.load_managed_table("sales", "orders") + with pytest.raises(ValueError, match="Exactly one"): + client.load_managed_table( + "sales", + "orders", + upload_id="upl_1", + file="/tmp/x.parquet", + ) + + +def test_delete_managed_table_calls_sdk(): + client = _client() + db = SimpleNamespace(id="conn1", name="sales", source_type="managed") + with patch.object(client, "resolve_managed_database", return_value=db), patch.object( + client, "connections" + ) as connections: + client.delete_managed_table("sales", "orders") + connections.return_value.delete_managed_table.assert_called_once_with( + "conn1", + "public", + "orders", + ) diff --git a/uv.lock b/uv.lock index 8d81942..94d50e6 100644 --- a/uv.lock +++ b/uv.lock @@ -44,17 +44,23 @@ wheels = [ [[package]] name = "hotdata" version = "0.1.0" -source = { registry = "https://pypi.org/simple" } +source = { editable = "../sdk-python" } dependencies = [ { name = "pydantic" }, { name = "python-dateutil" }, { name = "typing-extensions" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/63/a2/7e997581dc23fca35330c355cd433135c4d18cc5506fb77fb35fd0180e97/hotdata-0.1.0.tar.gz", hash = "sha256:6795ff7381fb8f2f258ee3f0c31f9b1ba2f5908728c51fa399840fdf603acc46", size = 97691, upload-time = "2026-04-25T17:57:00.102Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/21/e04ca377e7e3db50215bf207867ef02a56af11f61022390b7689e6ff2db3/hotdata-0.1.0-py3-none-any.whl", hash = "sha256:304f46d7c7ed5b586a9102684ef42e45972955dfb66a492c5e0b016e8bc545fa", size = 242376, upload-time = "2026-04-25T17:56:58.126Z" }, + +[package.metadata] +requires-dist = [ + { name = "pyarrow", marker = "extra == 'arrow'", specifier = ">=14" }, + { name = "pydantic", specifier = ">=2" }, + { name = "python-dateutil", specifier = ">=2.8.2" }, + { name = "typing-extensions", specifier = ">=4.7.1" }, + { name = "urllib3", specifier = ">=2.1.0,<3.0.0" }, ] +provides-extras = ["arrow"] [[package]] name = "hotdata-runtime" @@ -74,7 +80,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "hotdata", specifier = ">=0.1.0" }, + { name = "hotdata", editable = "../sdk-python" }, { name = "pandas", specifier = ">=2.0" }, ] From af29d9c788b0750f1721eca2ffcc0d411a532ac2 Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Mon, 18 May 2026 20:43:30 -0700 Subject: [PATCH 2/4] chore: pin hotdata>=0.1.1 and bump to 0.1.1 Use PyPI version constraints instead of git URLs. Managed database APIs require hotdata 0.1.1; keep a sibling sdk-python path override for local development until that release is published. --- pyproject.toml | 8 +++----- uv.lock | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5e63b68..6e467ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,18 +2,15 @@ requires = ["hatchling"] build-backend = "hatchling.build" -[tool.hatch.metadata] -allow-direct-references = true - [project] name = "hotdata-runtime" -version = "0.1.0" +version = "0.1.1" description = "Workspace/session runtime primitives for Hotdata integrations" readme = "README.md" requires-python = ">=3.10" license = { text = "MIT" } dependencies = [ - "hotdata @ git+https://github.com/hotdata-dev/sdk-python.git", + "hotdata>=0.1.1", "pandas>=2.0", ] @@ -26,6 +23,7 @@ dev = [ [tool.uv] default-groups = ["dev"] +# Resolve hotdata from a sibling checkout until v0.1.1 is on PyPI. [tool.uv.sources] hotdata = { path = "../sdk-python", editable = true } diff --git a/uv.lock b/uv.lock index 94d50e6..04edf99 100644 --- a/uv.lock +++ b/uv.lock @@ -43,7 +43,7 @@ wheels = [ [[package]] name = "hotdata" -version = "0.1.0" +version = "0.1.1" source = { editable = "../sdk-python" } dependencies = [ { name = "pydantic" }, @@ -64,7 +64,7 @@ provides-extras = ["arrow"] [[package]] name = "hotdata-runtime" -version = "0.1.0" +version = "0.1.1" source = { editable = "." } dependencies = [ { name = "hotdata" }, From a370abbbb6b8309eaf626484cbbb40b4cab16400 Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Mon, 18 May 2026 20:58:20 -0700 Subject: [PATCH 3/4] chore: require hotdata>=0.2.0 --- pyproject.toml | 4 ++-- uv.lock | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6e467ba..376a1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" requires-python = ">=3.10" license = { text = "MIT" } dependencies = [ - "hotdata>=0.1.1", + "hotdata>=0.2.0", "pandas>=2.0", ] @@ -23,7 +23,7 @@ dev = [ [tool.uv] default-groups = ["dev"] -# Resolve hotdata from a sibling checkout until v0.1.1 is on PyPI. +# Resolve hotdata from a sibling checkout until v0.2.0 is on PyPI. [tool.uv.sources] hotdata = { path = "../sdk-python", editable = true } diff --git a/uv.lock b/uv.lock index 04edf99..610aa71 100644 --- a/uv.lock +++ b/uv.lock @@ -43,7 +43,7 @@ wheels = [ [[package]] name = "hotdata" -version = "0.1.1" +version = "0.2.0" source = { editable = "../sdk-python" } dependencies = [ { name = "pydantic" }, From 2d3d4106f3a5a08a6cacee3f54874284d4bea2ca Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Mon, 18 May 2026 21:50:18 -0700 Subject: [PATCH 4/4] fix: address PR review nits for managed database helpers Rename cross-module helpers, remove unused imports, simplify parquet detection, and clarify upload_id/file handling in load_managed_table. --- hotdata_runtime/client.py | 26 +++++++++++++++----------- hotdata_runtime/databases.py | 8 ++------ tests/test_databases.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/hotdata_runtime/client.py b/hotdata_runtime/client.py index 1d2a7b8..9c5ced5 100644 --- a/hotdata_runtime/client.py +++ b/hotdata_runtime/client.py @@ -34,10 +34,10 @@ ManagedDatabase, ManagedTable, MANAGED_SOURCE_TYPE, - _api_error, - _managed_database, + api_error_message, create_connection_request, is_parquet_path, + managed_database_from_connection, ) from hotdata_runtime.http import default_http_retries from hotdata_runtime.result import QueryResult @@ -154,7 +154,7 @@ def uploads(self) -> UploadsApi: def list_managed_databases(self) -> list[ManagedDatabase]: listing = self.connections().list_connections() return [ - _managed_database(c) + managed_database_from_connection(c) for c in listing.connections if c.source_type == MANAGED_SOURCE_TYPE ] @@ -173,7 +173,7 @@ def resolve_managed_database(self, name_or_id: str) -> ManagedDatabase: f"{match.name!r} is not a managed database " f"(source_type: {match.source_type})" ) - return _managed_database(match) + return managed_database_from_connection(match) def create_managed_database( self, @@ -186,15 +186,15 @@ def create_managed_database( try: created = self.connections().create_connection(request) except ApiException as e: - raise RuntimeError(_api_error(e)) from e - return _managed_database(created) + raise RuntimeError(api_error_message(e)) from e + return managed_database_from_connection(created) def delete_managed_database(self, name_or_id: str) -> None: db = self.resolve_managed_database(name_or_id) try: self.connections().delete_connection(db.id) except ApiException as e: - raise RuntimeError(_api_error(e)) from e + raise RuntimeError(api_error_message(e)) from e def list_managed_tables( self, @@ -232,7 +232,7 @@ def upload_parquet(self, path: str) -> str: _content_type="application/octet-stream", ) except ApiException as e: - raise RuntimeError(_api_error(e)) from e + raise RuntimeError(api_error_message(e)) from e return uploaded.id def load_managed_table( @@ -247,7 +247,11 @@ def load_managed_table( if (upload_id is None) == (file is None): raise ValueError("Exactly one of upload_id or file is required") db = self.resolve_managed_database(database) - resolved_upload_id = upload_id or self.upload_parquet(file or "") + if upload_id is not None: + resolved_upload_id = upload_id + else: + assert file is not None + resolved_upload_id = self.upload_parquet(file) request = LoadManagedTableRequest( mode="replace", upload_id=resolved_upload_id, @@ -260,7 +264,7 @@ def load_managed_table( request, ) except ApiException as e: - raise RuntimeError(_api_error(e)) from e + raise RuntimeError(api_error_message(e)) from e return LoadManagedTableResult( connection_id=loaded.connection_id, schema_name=loaded.schema_name, @@ -280,7 +284,7 @@ def delete_managed_table( try: self.connections().delete_managed_table(db.id, schema, table) except ApiException as e: - raise RuntimeError(_api_error(e)) from e + raise RuntimeError(api_error_message(e)) from e def list_recent_results( self, diff --git a/hotdata_runtime/databases.py b/hotdata_runtime/databases.py index e141bb6..f9e4b69 100644 --- a/hotdata_runtime/databases.py +++ b/hotdata_runtime/databases.py @@ -8,7 +8,6 @@ from hotdata.exceptions import ApiException from hotdata.models.create_connection_request import CreateConnectionRequest -from hotdata.models.load_managed_table_request import LoadManagedTableRequest MANAGED_SOURCE_TYPE = "managed" DEFAULT_SCHEMA = "public" @@ -49,9 +48,6 @@ def to_dict(self) -> dict[str, Any]: def is_parquet_path(path: str) -> bool: - lowered = path.lower() - if lowered.endswith(".parquet"): - return True return Path(path).suffix.lower() == ".parquet" @@ -83,7 +79,7 @@ def create_connection_request( ) -def _managed_database(conn: Any) -> ManagedDatabase: +def managed_database_from_connection(conn: Any) -> ManagedDatabase: return ManagedDatabase( id=str(conn.id), name=str(conn.name), @@ -91,5 +87,5 @@ def _managed_database(conn: Any) -> ManagedDatabase: ) -def _api_error(exc: ApiException) -> str: +def api_error_message(exc: ApiException) -> str: return exc.reason or str(exc) diff --git a/tests/test_databases.py b/tests/test_databases.py index a9a53a3..8673c64 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import SimpleNamespace -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import mock_open, patch import pytest