From d0bdc49871e12fbabc201725ed2069b20840b7ac Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Thu, 17 Apr 2025 18:00:49 +0530 Subject: [PATCH 1/4] add user agent to bigquery connector Signed-off-by: Samhita Alla --- .../flytekitplugins/bigquery/connector.py | 16 ++++++++++++++-- .../flytekitplugins/bigquery/task.py | 1 + 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py index da435efe38..cc492f5dfc 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py @@ -3,11 +3,17 @@ from typing import Dict, Optional from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from google.api_core.client_info import ClientInfo from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_connector import AsyncConnectorBase, ConnectorRegistry, Resource, ResourceMeta +from flytekit.extend.backend.base_connector import ( + AsyncConnectorBase, + ConnectorRegistry, + Resource, + ResourceMeta, +) from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -59,9 +65,15 @@ def create( ) custom = task_template.custom + + domain = custom["Domain"] + sdk_version = task_template.metadata.runtime.version + cinfo = ClientInfo(user_agent=f"Flytekit/{sdk_version} (GPN:Union;{domain})") + project = custom["ProjectID"] location = custom["Location"] - client = bigquery.Client(project=project, location=location) + + client = bigquery.Client(project=project, location=location, client_info=cinfo) query_job = client.query(task_template.sql.statement, job_config=job_config) return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index b2fe5a9d42..1f3529c9dc 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -73,6 +73,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: config = { "Location": self.task_config.Location, "ProjectID": self.task_config.ProjectID, + "Domain": settings.domain, } if self.task_config.QueryJobConfig is not None: config.update(self.task_config.QueryJobConfig.to_api_repr()["query"]) From abcb371af42dff69faf9804d209b7fe5e82dceed Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 22 Apr 2025 16:12:30 +0530 Subject: [PATCH 2/4] fix test Signed-off-by: Samhita Alla --- plugins/flytekit-bigquery/tests/test_bigquery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-bigquery/tests/test_bigquery.py b/plugins/flytekit-bigquery/tests/test_bigquery.py index 7f4837ae0d..dc1eee9189 100644 --- a/plugins/flytekit-bigquery/tests/test_bigquery.py +++ b/plugins/flytekit-bigquery/tests/test_bigquery.py @@ -43,7 +43,7 @@ def my_wf(ds: str) -> StructuredDataset: assert "@version" in task_spec.template.sql.statement assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI s = Struct() - s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True}) + s.update({"Domain": "dom", "ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True}) assert task_spec.template.custom == json_format.MessageToDict(s) assert len(task_spec.template.interface.inputs) == 1 assert len(task_spec.template.interface.outputs) == 1 From 2174efe8e4f5c2c7aed6e530e862b949932daeb5 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 22 Apr 2025 16:40:06 +0530 Subject: [PATCH 3/4] fix test Signed-off-by: Samhita Alla --- plugins/flytekit-bigquery/tests/test_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-bigquery/tests/test_connector.py b/plugins/flytekit-bigquery/tests/test_connector.py index 02e0ff380a..2e99e9ad4c 100644 --- a/plugins/flytekit-bigquery/tests/test_connector.py +++ b/plugins/flytekit-bigquery/tests/test_connector.py @@ -59,6 +59,7 @@ def __init__(self): task_config = { "Location": "us-central1", "ProjectID": "dummy_project", + "Domain": "dev", } int_type = types.LiteralType(types.SimpleType.INTEGER) From 6e574ab40b9c430b00e2ca623b23354b937b182b Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Tue, 22 Apr 2025 22:13:43 +0530 Subject: [PATCH 4/4] make domain fetch optional Signed-off-by: Samhita Alla --- .../flytekit-bigquery/flytekitplugins/bigquery/connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py index cc492f5dfc..eb8d083525 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py @@ -66,9 +66,11 @@ def create( custom = task_template.custom - domain = custom["Domain"] + domain = custom.get("Domain") sdk_version = task_template.metadata.runtime.version - cinfo = ClientInfo(user_agent=f"Flytekit/{sdk_version} (GPN:Union;{domain})") + + user_agent = f"Flytekit/{sdk_version} (GPN:Union;{domain or ''})" + cinfo = ClientInfo(user_agent=user_agent) project = custom["ProjectID"] location = custom["Location"]