diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py index da435efe38..eb8d083525 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/connector.py @@ -3,11 +3,17 @@ from typing import Dict, Optional from flyteidl.core.execution_pb2 import TaskExecution, TaskLog +from google.api_core.client_info import ClientInfo from google.cloud import bigquery from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_connector import AsyncConnectorBase, ConnectorRegistry, Resource, ResourceMeta +from flytekit.extend.backend.base_connector import ( + AsyncConnectorBase, + ConnectorRegistry, + Resource, + ResourceMeta, +) from flytekit.extend.backend.utils import convert_to_flyte_phase from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate @@ -59,9 +65,17 @@ def create( ) custom = task_template.custom + + domain = custom.get("Domain") + sdk_version = task_template.metadata.runtime.version + + user_agent = f"Flytekit/{sdk_version} (GPN:Union;{domain or ''})" + cinfo = ClientInfo(user_agent=user_agent) + project = custom["ProjectID"] location = custom["Location"] - client = bigquery.Client(project=project, location=location) + + client = bigquery.Client(project=project, location=location, client_info=cinfo) query_job = client.query(task_template.sql.statement, job_config=job_config) return BigQueryMetadata(job_id=str(query_job.job_id), location=location, project=project) diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index b2fe5a9d42..1f3529c9dc 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -73,6 +73,7 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: config = { "Location": self.task_config.Location, "ProjectID": self.task_config.ProjectID, + "Domain": settings.domain, } if self.task_config.QueryJobConfig is not None: config.update(self.task_config.QueryJobConfig.to_api_repr()["query"]) diff --git a/plugins/flytekit-bigquery/tests/test_bigquery.py b/plugins/flytekit-bigquery/tests/test_bigquery.py index 7f4837ae0d..dc1eee9189 100644 --- a/plugins/flytekit-bigquery/tests/test_bigquery.py +++ b/plugins/flytekit-bigquery/tests/test_bigquery.py @@ -43,7 +43,7 @@ def my_wf(ds: str) -> StructuredDataset: assert "@version" in task_spec.template.sql.statement assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI s = Struct() - s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True}) + s.update({"Domain": "dom", "ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True}) assert task_spec.template.custom == json_format.MessageToDict(s) assert len(task_spec.template.interface.inputs) == 1 assert len(task_spec.template.interface.outputs) == 1 diff --git a/plugins/flytekit-bigquery/tests/test_connector.py b/plugins/flytekit-bigquery/tests/test_connector.py index 02e0ff380a..2e99e9ad4c 100644 --- a/plugins/flytekit-bigquery/tests/test_connector.py +++ b/plugins/flytekit-bigquery/tests/test_connector.py @@ -59,6 +59,7 @@ def __init__(self): task_config = { "Location": "us-central1", "ProjectID": "dummy_project", + "Domain": "dev", } int_type = types.LiteralType(types.SimpleType.INTEGER)