From 3966a0b1e9d53775e9d80cff3a96c42d3d40c2ef Mon Sep 17 00:00:00 2001 From: Matt Slack Date: Fri, 3 Apr 2026 23:55:36 +0100 Subject: [PATCH] [Cosmos] Add Spark 4.1 connector module (azure-cosmos-spark_4-1_2-13) Add support for Apache Spark 4.1 which reorganized internal streaming packages (SPARK-52787). The following classes moved from `o.a.s.sql.execution.streaming` to `o.a.s.sql.execution.streaming.checkpointing`: - HDFSMetadataLog - MetadataVersionUtil This module overrides ChangeFeedInitialOffsetWriter, CosmosCatalogBase, and CosmosCatalogITestBase with updated imports, and uses a source-copy approach with excludes to avoid duplicate class definitions from the shared azure-cosmos-spark_3 source. Co-authored-by: Isaac --- .../azure-cosmos-spark_4-1_2-13/CHANGELOG.md | 56 + .../CONTRIBUTING.md | 84 ++ .../azure-cosmos-spark_4-1_2-13/README.md | 233 +++++ .../azure-cosmos-spark_4-1_2-13/pom.xml | 262 +++++ .../scalastyle_config.xml | 130 +++ .../resources/azure-cosmos-spark.properties | 2 + .../spark/ChangeFeedInitialOffsetWriter.scala | 60 ++ .../spark/ChangeFeedMicroBatchStream.scala | 271 +++++ .../spark/CosmosBytesWrittenMetric.scala | 11 + .../azure/cosmos/spark/CosmosCatalog.scala | 59 ++ .../cosmos/spark/CosmosCatalogBase.scala | 727 +++++++++++++ .../spark/CosmosRecordsWrittenMetric.scala | 11 + .../cosmos/spark/CosmosRowConverter.scala | 127 +++ .../com/azure/cosmos/spark/CosmosWriter.scala | 109 ++ .../com/azure/cosmos/spark/ItemsScan.scala | 41 + .../azure/cosmos/spark/ItemsScanBuilder.scala | 137 +++ .../cosmos/spark/ItemsWriterBuilder.scala | 185 ++++ .../cosmos/spark/RowSerializerPool.scala | 29 + .../cosmos/spark/SparkInternalsBridge.scala | 107 ++ .../spark/TotalRequestChargeMetric.scala | 11 + ...osmos.spark.CosmosClientBuilderInterceptor | 1 + ...azure.cosmos.spark.CosmosClientInterceptor | 1 + ...cosmos.spark.WriteOnRetryCommitInterceptor | 1 + .../ChangeFeedMetricsListenerITest.scala | 157 +++ .../cosmos/spark/CosmosCatalogITest.scala | 103 ++ 
.../cosmos/spark/CosmosCatalogITestBase.scala | 973 ++++++++++++++++++ .../cosmos/spark/CosmosRowConverterTest.scala | 97 ++ .../azure/cosmos/spark/ItemsScanITest.scala | 256 +++++ .../cosmos/spark/RowSerializerPollTest.scala | 27 + .../cosmos/spark/SparkE2EQueryITest.scala | 70 ++ 30 files changed, 4338 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala create mode 100644 
sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/RowSerializerPollTest.scala create mode 100644 sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md 
b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md new file mode 100644 index 000000000000..3972ae6aeb98 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md @@ -0,0 +1,56 @@ +## Release History + +### 4.47.0-beta.1 (Unreleased) + +#### Features Added + +#### Breaking Changes + +#### Bugs Fixed + +#### Other Changes + +### 4.46.0 (2026-03-27) + +#### Bugs Fixed +* Fixed an issue where creating containers with hierarchical partition keys (multi-hash) through the Spark catalog on the AAD path would fail. - See [PR 48548](https://github.com/Azure/azure-sdk-for-java/pull/48548) + +### 4.45.0 (2026-03-13) + +#### Features Added +* Added `vectorEmbeddingPolicy` support in Spark catalog `TBLPROPERTIES` for creating vector-search-enabled containers. - See [PR 48349](https://github.com/Azure/azure-sdk-for-java/pull/48349) + +### 4.44.2 (2026-03-05) + +#### Other Changes +* Changed azure-resourcemanager-cosmos usage to a pinned version which is deployed across all public and non-public clouds - [PR 48268](https://github.com/Azure/azure-sdk-for-java/pull/48268) + +### 4.44.1 (2026-03-03) + +#### Other Changes +* Reduced noisy warning logs in Gateway mode - [PR 48189](https://github.com/Azure/azure-sdk-for-java/pull/48189) + +### 4.44.0 (2026-02-27) + +#### Features Added +* Added config entry `spark.cosmos.account.azureEnvironment.management.scope` to allow specifying the Entra ID scope/audience to be used when retrieving tokens to authenticate against the ARM/management endpoint of non-public clouds. - See [PR 48137](https://github.com/Azure/azure-sdk-for-java/pull/48137) + +### 4.43.1 (2026-02-25) + +#### Bugs Fixed +* Fixed an issue where `TransientIOErrorsRetryingIterator` would trigger extra query during retries and on close. - See [PR 47996](https://github.com/Azure/azure-sdk-for-java/pull/47996) + +#### Other Changes +* Added status code history in `BulkWriterNoProgressException` error message. 
- See [PR 48022](https://github.com/Azure/azure-sdk-for-java/pull/48022) +* Reduced the log noise level for frequent transient errors - for example throttling - in Gateway mode - [PR 48112](https://github.com/Azure/azure-sdk-for-java/pull/48112) + +### 4.43.0 (2026-02-10) + +#### Features Added +* Initial release of Spark 4.0 connector with Scala 2.13 support +* Added transactional batch support. See [PR 47478](https://github.com/Azure/azure-sdk-for-java/pull/47478) and [PR 47697](https://github.com/Azure/azure-sdk-for-java/pull/47697) and [47803](https://github.com/Azure/azure-sdk-for-java/pull/47803) +* Added support for throughput bucket. - See [47856](https://github.com/Azure/azure-sdk-for-java/pull/47856) + +#### Other Changes + +### NOTE: See CHANGELOG.md in 3.3, 3.4, and 3.5 projects for changes in prior Spark versions diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md new file mode 100644 index 000000000000..2435e3acead4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md @@ -0,0 +1,84 @@ +# Contributing +This instruction is guideline for building and code contribution. + +## Prerequisites +- JDK 17 or above (Spark 4.0 requires Java 17+) +- [Maven](https://maven.apache.org/) 3.0 and above + +## Build from source +To build the project, run maven commands. + +```bash +git clone https://github.com/Azure/azure-sdk-for-java.git +cd sdk/cosmos/azure-cosmos-spark_4-0_2-13 +mvn clean install +``` + +## Test +There are integration tests on azure and on emulator to trigger integration test execution +against Azure Cosmos DB and against +[Azure Cosmos DB Emulator](https://docs.microsoft.com/azure/cosmos-db/local-emulator), you need to +follow the link to set up emulator before test execution. 
+ +- Run unit tests +```bash +mvn clean install -Dgpg.skip +``` + +- Run integration tests + - on Azure + > **NOTE** Please note that integration test against Azure requires Azure Cosmos DB Document + API and will automatically create a Cosmos database in your Azure subscription, then there + will be **Azure usage fee.** + + Integration tests will require a Azure Subscription. If you don't already have an Azure + subscription, you can activate your + [MSDN subscriber benefits](https://azure.microsoft.com/pricing/member-offers/msdn-benefits-details/) + or sign up for a [free Azure account](https://azure.microsoft.com/free/). + + 1. Create an Azure Cosmos DB on Azure. + - Go to [Azure portal](https://portal.azure.com/) and click +New. + - Click Databases, and then click Azure Cosmos DB to create your database. + - Navigate to the database you have created, and click Access keys and copy your + URI and access keys for your database. + + 2. Set environment variables ACCOUNT_HOST, ACCOUNT_KEY and SECONDARY_ACCOUNT_KEY, where value + of them are Cosmos account URI, primary key and secondary key. + + So set the + second group environment variables NEW_ACCOUNT_HOST, NEW_ACCOUNT_KEY and + NEW_SECONDARY_ACCOUNT_KEY, the two group environment variables can be same. + 3. Run maven command with `integration-test-azure` profile. + + ```bash + set ACCOUNT_HOST=your-cosmos-account-uri + set ACCOUNT_KEY=your-cosmos-account-primary-key + set SECONDARY_ACCOUNT_KEY=your-cosmos-account-secondary-key + + set NEW_ACCOUNT_HOST=your-cosmos-account-uri + set NEW_ACCOUNT_KEY=your-cosmos-account-primary-key + set NEW_SECONDARY_ACCOUNT_KEY=your-cosmos-account-secondary-key + mvnw -P integration-test-azure clean install + ``` + + - on Emulator + + Setup Azure Cosmos DB Emulator by following + [this instruction](https://docs.microsoft.com/azure/cosmos-db/local-emulator), and set + associated environment variables. 
Then run test with: + ```bash + mvnw -P integration-test-emulator install + ``` + + +- Skip tests execution +```bash +mvn clean install -Dgpg.skip -DskipTests +``` + +## Version management +Developing version naming convention is like `0.1.2-beta.1`. Release version naming convention is like `0.1.2`. + +## Contribute to code +Contribution is welcome. Please follow +[this instruction](https://github.com/Azure/azure-sdk-for-java/blob/main/CONTRIBUTING.md) to contribute code. diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md new file mode 100644 index 000000000000..81869fdba085 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md @@ -0,0 +1,233 @@ +# Azure Cosmos DB OLTP Spark 4 connector + +## Azure Cosmos DB OLTP Spark 4 connector for Spark 4.0 +**Azure Cosmos DB OLTP Spark connector** provides Apache Spark support for Azure Cosmos DB using +the [SQL API][sql_api_query]. +[Azure Cosmos DB][cosmos_introduction] is a globally-distributed database service which allows +developers to work with data using a variety of standard APIs, such as SQL, MongoDB, Cassandra, Graph, and Table. 
+ +If you have any feedback or ideas on how to improve your experience please let us know here: +https://github.com/Azure/azure-sdk-for-java/issues/new + +### Documentation + +- [Getting started](https://aka.ms/azure-cosmos-spark-3-quickstart) +- [Catalog API](https://aka.ms/azure-cosmos-spark-3-catalog-api) +- [Configuration Parameter Reference](https://aka.ms/azure-cosmos-spark-3-config) + +### Version Compatibility + +#### azure-cosmos-spark_4-0_2-13 +| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes | +|-----------|--------------------------|----------------------|---------------------------|-------------------------------|---------------------------| +| 4.46.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.45.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.44.2 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.44.1 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.44.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.43.1 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | +| 4.43.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD | + +Note: Spark 4.0 requires Scala 2.13 and Java 17 or higher. When using the Scala API, it is necessary for applications +to use Scala 2.13 that Spark 4.0 was compiled for. 
+ +#### azure-cosmos-spark_3-3_2-12 +| Connector | Supported Spark Versions | Supported JVM Versions | Supported Scala Versions | Supported Databricks Runtimes | +|-----------|--------------------------|------------------------|--------------------------|-------------------------------| +| 4.46.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.45.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.44.2 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.44.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.44.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.43.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.43.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.42.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.41.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.40.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.39.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.38.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.37.2 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.37.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.37.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.36.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.36.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.35.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.34.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.33.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.33.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.32.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.32.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.31.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.30.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.29.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.28.4 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.28.3 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.28.2 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.28.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.28.0 | 3.3.0 | [8, 
11] | 2.12 | 11.\*, 12.\* | +| 4.27.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.27.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.26.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.26.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.25.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.25.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.24.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.24.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.23.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.22.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.21.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.21.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* | +| 4.20.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.19.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.18.2 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.18.1 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.18.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.17.2 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.17.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.16.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | +| 4.15.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* | + +#### azure-cosmos-spark_3-4_2-12 +| Connector | Supported Spark Versions | Supported JVM Versions | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes | +|-----------|--------------------------|------------------------|--------------------------|-------------------------------|---------------------------| +| 4.46.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.45.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.44.2 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.44.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.44.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.43.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.43.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.42.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.41.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.40.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.39.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 
13.\* | | +| 4.38.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.37.2 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.37.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.37.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.36.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | | +| 4.36.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.35.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.34.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.33.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.33.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.32.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.32.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.31.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.30.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.29.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.28.4 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.28.3 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.28.2 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.28.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.28.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.27.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.27.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.26.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.26.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.25.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.25.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.24.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.24.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.23.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.22.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.21.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | | +| 4.21.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | | + +#### azure-cosmos-spark_3-5_2-12 +| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes | +|-----------|--------------------------|-----------------------|---------------------------|-------------------------------|---------------------------| +| 4.46.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.45.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 
16.4 LTS | 1.3.\* | +| 4.44.2 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.44.1 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.44.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.43.1 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.43.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.42.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.41.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* | +| 4.40.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.39.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.38.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.37.2 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.37.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.37.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.36.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.36.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.35.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.34.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.33.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.33.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.32.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.32.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.31.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.30.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | +| 4.29.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | | + +Note: Java 8 prior to version 8u371 support is deprecated as of Spark 3.5.0. When using the Scala API, it is necessary for applications +to use the same version of Scala that Spark was compiled for. 
+ +#### azure-cosmos-spark_3-5_2-13 +| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes | +|-----------|--------------------------|-----------------------|---------------------------|-------------------------------|---------------------------| +| 4.46.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.45.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.44.2 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.44.1 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.44.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.43.1 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | +| 4.43.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD | + +### Download + +You can use the maven coordinate of the jar to auto install the Spark Connector to your Databricks Runtime from Maven: +`com.azure.cosmos.spark:azure-cosmos-spark_4-0_2-13:4.46.0` + +You can also integrate against Cosmos DB Spark Connector in your SBT project: +```scala +libraryDependencies += "com.azure.cosmos.spark" % "azure-cosmos-spark_4-0_2-13" % "4.46.0" +``` + +Cosmos DB Spark Connector is available on [Maven Central Repo](https://central.sonatype.com/search?namespace=com.azure.cosmos.spark). + +#### General + +If you encounter any bug, please file an issue [here](https://github.com/Azure/azure-sdk-for-java/issues/new). + +To suggest a new feature or changes that could be made, file an issue the same way you would for a bug. + +### License +This project is under MIT license and uses and repackages other third party libraries as an uber jar. +See [NOTICE.txt](https://github.com/Azure/azure-sdk-for-java/blob/main/NOTICE.txt). + +### Contributing + +This project welcomes contributions and suggestions. Most contributions require you to agree to a +[Contributor License Agreement (CLA)][cla] declaring that you have the right to, and actually do, grant us the rights +to use your contribution. 
+ +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate +the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to +do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct][coc]. For more information see the [Code of Conduct FAQ][coc_faq] +or contact [opencode@microsoft.com][coc_contact] with any additional questions or comments. + + +[source_code]: src +[cosmos_introduction]: https://learn.microsoft.com/azure/cosmos-db/ +[cosmos_docs]: https://learn.microsoft.com/azure/cosmos-db/introduction +[jdk]: https://learn.microsoft.com/java/azure/jdk/?view=azure-java-stable +[maven]: https://maven.apache.org/ +[cla]: https://cla.microsoft.com +[coc]: https://opensource.microsoft.com/codeofconduct/ +[coc_faq]: https://opensource.microsoft.com/codeofconduct/faq/ +[coc_contact]: mailto:opencode@microsoft.com +[azure_subscription]: https://azure.microsoft.com/free/ +[samples]: https://github.com/Azure/azure-sdk-for-java/tree/main/sdk/spring/azure-spring-data-cosmos/src/samples/java/com/azure/spring/data/cosmos +[sql_api_query]: https://learn.microsoft.com/azure/cosmos-db/sql-api-sql-query +[local_emulator]: https://learn.microsoft.com/azure/cosmos-db/local-emulator +[local_emulator_export_ssl_certificates]: https://learn.microsoft.com/azure/cosmos-db/local-emulator-export-ssl-certificates +[azure_cosmos_db_partition]: https://learn.microsoft.com/azure/cosmos-db/partition-data +[sql_queries_in_cosmos]: https://learn.microsoft.com/azure/cosmos-db/tutorial-query-sql-api +[sql_queries_getting_started]: https://learn.microsoft.com/azure/cosmos-db/sql-query-getting-started diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml new file mode 100644 index 000000000000..1499d93ee4d3 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml 
@@ -0,0 +1,262 @@ + + + 4.0.0 + + com.azure.cosmos.spark + azure-cosmos-spark_3 + 0.0.1-beta.1 + ../azure-cosmos-spark_3 + + com.azure.cosmos.spark + azure-cosmos-spark_4-1_2-13 + 4.47.0-beta.1 + jar + https://github.com/Azure/azure-sdk-for-java/tree/main/sdk/cosmos/azure-cosmos-spark_4-1_2-13 + OLTP Spark 4.1 Connector for Azure Cosmos DB SQL API + OLTP Spark 4.1 Connector for Azure Cosmos DB SQL API + + scm:git:https://github.com/Azure/azure-sdk-for-java.git/sdk/cosmos/azure-cosmos-spark_4-1_2-13 + + https://github.com/Azure/azure-sdk-for-java/sdk/cosmos/azure-cosmos-spark_4-1_2-13 + + + Microsoft Corporation + http://microsoft.com + + + + The MIT License (MIT) + http://opensource.org/licenses/MIT + repo + + + + + microsoft + Microsoft Corporation + + + + false + 4.1 + 2.13 + 2.13.17 + 0.9.1 + 0.8.0 + 3.2.2 + 3.2.3 + 3.2.3 + 5.0.0 + true + + + + + + org.apache.maven.plugins + maven-resources-plugin + 2.4.3 + + + copy-shared-sources + initialize + + copy-resources + + + ${project.build.directory}/shared-sources + + + ${basedir}/../azure-cosmos-spark_3/src/main/scala + + **/ChangeFeedInitialOffsetWriter.scala + **/CosmosCatalogBase.scala + + + + + + + copy-shared-test-sources + initialize + + copy-resources + + + ${project.build.directory}/shared-test-sources + + + ${basedir}/../azure-cosmos-spark_3/src/test/scala + + **/CosmosCatalogITestBase.scala + + + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.6.1 + + + add-sources + generate-sources + + add-source + + + + ${project.build.directory}/shared-sources + ${basedir}/src/main/scala + + + + + add-test-sources + generate-test-sources + + add-test-source + + + + ${project.build.directory}/shared-test-sources + ${basedir}/src/test/scala + + + + + add-resources + generate-resources + + add-resource + + + + ${basedir}/../azure-cosmos-spark_3/src/main/resources + ${basedir}/src/main/resources + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.6.1 + + + + + + + spark-e2e_4-1_2-13 + + + [17,) 
+ + ${basedir}/scalastyle_config.xml + + + spark-e2e_4-1_2-13 + true + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.5.3 + + + **/*.* + **/*Test.* + **/*Suite.* + **/*Spec.* + + true + + + + org.scalatest + scalatest-maven-plugin + 2.1.0 + + ${scalatest.argLine} + stdOut=true,verbose=true,stdErr=true + false + FDEF + FDEF + once + true + ${project.build.directory}/surefire-reports + . + SparkTestSuite.txt + (ITest|Test|Spec|Suite) + + + + test + + test + + + + + + + + + + spark-4-1-disable-tests-java-lt-17 + + (,17) + + + true + + + + java9-plus + + [9,) + + + --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false + + + + + + org.apache.spark + spark-sql_2.13 + 4.1.0 + + + io.netty + netty-all + + + org.slf4j + * + + + provided + + + com.fasterxml.jackson.core + jackson-databind + 2.18.4 + + + com.fasterxml.jackson.module + jackson-module-scala_2.13 + 2.18.4 + + + diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml new file mode 100644 index 000000000000..7a8ad2823fb8 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml @@ -0,0 +1,130 @@ + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties new file mode 100644 index 000000000000..ca812989b4f2 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties @@ -0,0 +1,2 @@ +name=${project.artifactId} +version=${project.version} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala new file mode 100644 index 000000000000..c4687f4102ed --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.checkpointing.{HDFSMetadataLog, MetadataVersionUtil} + +import java.io.{BufferedWriter, InputStream, InputStreamReader, OutputStream, OutputStreamWriter} +import java.nio.charset.StandardCharsets + +private class ChangeFeedInitialOffsetWriter +( + sparkSession: SparkSession, + metadataPath: String +) extends HDFSMetadataLog[String](sparkSession, metadataPath) { + + val VERSION = 1 + + override def serialize(offsetJson: String, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write(s"v$VERSION\n") + writer.write(offsetJson) + writer.flush() + } + + override def deserialize(in: InputStream): String = { + val content = readerToString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog would never create a partial file. + require(content.nonEmpty) + val indexOfNewLine = content.indexOf("\n") + if (content(0) != 'v' || indexOfNewLine < 0) { + throw new IllegalStateException( + "Log file was malformed: failed to detect the log file version line.") + } + + MetadataVersionUtil.validateVersion(content.substring(0, indexOfNewLine), VERSION) + content.substring(indexOfNewLine + 1) + } + + private def readerToString(reader: java.io.Reader): String = { + val writer = new StringBuilderWriter + val buffer = new Array[Char](4096) + Stream.continually(reader.read(buffer)).takeWhile(_ != -1).foreach(writer.write(buffer, 0, _)) + writer.toString + } + + private class StringBuilderWriter extends java.io.Writer { + private val stringBuilder = new StringBuilder + + override def write(cbuf: Array[Char], off: Int, len: Int): Unit = { + stringBuilder.appendAll(cbuf, off, len) + } + + override def flush(): Unit = {} + + override def close(): Unit = {} + + override def toString: String = stringBuilder.toString() + } +} diff --git 
a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala new file mode 100644 index 000000000000..bf4632cf609a --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala @@ -0,0 +1,271 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.changeFeedMetrics.{ChangeFeedMetricsListener, ChangeFeedMetricsTracker} +import com.azure.cosmos.implementation.SparkBridgeImplementationInternal +import com.azure.cosmos.implementation.guava25.collect.{HashBiMap, Maps} +import com.azure.cosmos.spark.CosmosPredicates.{assertNotNull, assertNotNullOrEmpty, assertOnSparkDriver} +import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, LoggerHelper} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl} +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.types.StructType + +import java.time.Duration +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +// scala style rule flaky - even complaining on partial log messages +// scalastyle:off multiple.string.literals +private class ChangeFeedMicroBatchStream +( + val session: SparkSession, + val schema: StructType, + val config: Map[String, String], + val cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + val checkpointLocation: String, + diagnosticsConfig: DiagnosticsConfig +) 
extends MicroBatchStream + with SupportsAdmissionControl { + + @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) + + private val correlationActivityId = UUID.randomUUID() + private val streamId = correlationActivityId.toString + log.logTrace(s"Instantiated ${this.getClass.getSimpleName}.$streamId") + + private val defaultParallelism = session.sparkContext.defaultParallelism + private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config) + private val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session)) + private val clientConfiguration = CosmosClientConfiguration.apply( + config, + readConfig.readConsistencyStrategy, + sparkEnvironmentInfo) + private val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(config) + private val partitioningConfig = CosmosPartitioningConfig.parseCosmosPartitioningConfig(config) + private val changeFeedConfig = CosmosChangeFeedConfig.parseCosmosChangeFeedConfig(config) + private val clientCacheItem = CosmosClientCache( + clientConfiguration, + Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), + s"ChangeFeedMicroBatchStream(streamId $streamId)") + private val throughputControlClientCacheItemOpt = + ThroughputControlHelper.getThroughputControlClientCacheItem( + config, clientCacheItem.context, Some(cosmosClientStateHandles), sparkEnvironmentInfo) + private val container = + ThroughputControlHelper.getContainer( + config, + containerConfig, + clientCacheItem, + throughputControlClientCacheItemOpt) + + private var latestOffsetSnapshot: Option[ChangeFeedOffset] = None + + private val partitionIndex = new AtomicLong(0) + private val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]()) + private val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]() + + if (changeFeedConfig.performanceMonitoringEnabled) { + log.logInfo("ChangeFeed performance monitoring is 
enabled, registering ChangeFeedMetricsListener") + session.sparkContext.addSparkListener(new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap)) + } else { + log.logInfo("ChangeFeed performance monitoring is disabled") + } + + override def latestOffset(): Offset = { + // For Spark data streams implementing SupportsAdmissionControl trait + // latestOffset(Offset, ReadLimit) is called instead + throw new UnsupportedOperationException( + "latestOffset(Offset, ReadLimit) should be called instead of this method") + } + + /** + * Returns a list of `InputPartition` given the start and end offsets. Each + * `InputPartition` represents a data split that can be processed by one Spark task. The + * number of input partitions returned here is the same as the number of RDD partitions this scan + * outputs. + *

+ * If the `Scan` supports filter push down, this stream is likely configured with a filter + * and is responsible for creating splits for that filter, which is not a full scan. + *

+ *

+ * This method will be called multiple times, to launch one Spark job for each micro-batch in this + * data stream. + *

+ */ + override def planInputPartitions(startOffset: Offset, endOffset: Offset): Array[InputPartition] = { + assertNotNull(startOffset, "startOffset") + assertNotNull(endOffset, "endOffset") + assert(startOffset.isInstanceOf[ChangeFeedOffset], "Argument 'startOffset' is not a change feed offset.") + assert(endOffset.isInstanceOf[ChangeFeedOffset], "Argument 'endOffset' is not a change feed offset.") + + log.logDebug(s"--> planInputPartitions.$streamId, startOffset: ${startOffset.json()} - endOffset: ${endOffset.json()}") + val start = startOffset.asInstanceOf[ChangeFeedOffset] + val end = endOffset.asInstanceOf[ChangeFeedOffset] + + val startChangeFeedState = new String(java.util.Base64.getUrlDecoder.decode(start.changeFeedState)) + log.logDebug(s"Start-ChangeFeedState.$streamId: $startChangeFeedState") + + val endChangeFeedState = new String(java.util.Base64.getUrlDecoder.decode(end.changeFeedState)) + log.logDebug(s"End-ChangeFeedState.$streamId: $endChangeFeedState") + + assert(end.inputPartitions.isDefined, "Argument 'endOffset.inputPartitions' must not be null or empty.") + + val parsedStartChangeFeedState = SparkBridgeImplementationInternal.parseChangeFeedState(start.changeFeedState) + end + .inputPartitions + .get + .map(partition => { + val index = partitionIndexMap.asScala.getOrElseUpdate(partition.feedRange, partitionIndex.incrementAndGet()) + partition + .withContinuationState( + SparkBridgeImplementationInternal + .extractChangeFeedStateForRange(parsedStartChangeFeedState, partition.feedRange), + clearEndLsn = false) + .withIndex(index) + }) + } + + /** + * Returns a factory to create a `PartitionReader` for each `InputPartition`. 
+ */ + override def createReaderFactory(): PartitionReaderFactory = { + log.logDebug(s"--> createReaderFactory.$streamId") + ChangeFeedScanPartitionReaderFactory( + config, + schema, + DiagnosticsContext(correlationActivityId, checkpointLocation), + cosmosClientStateHandles, + diagnosticsConfig, + CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session))) + } + + /** + * Returns the most recent offset available given a read limit. The start offset can be used + * to figure out how much new data should be read given the limit. Users should implement this + * method instead of latestOffset for a MicroBatchStream or getOffset for Source. + * + * When this method is called on a `Source`, the source can return `null` if there is no + * data to process. In addition, for the very first micro-batch, the `startOffset` will be + * null as well. + * + * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset` + * for the very first micro-batch. The source can return `null` if there is no data to process. 
+ */ + // This method is doing all the heavy lifting - after calculating the latest offset + // all information necessary to plan partitions is available - so we plan partitions here and + // serialize them in the end offset returned to avoid any IO calls for the actual partitioning + override def latestOffset(startOffset: Offset, readLimit: ReadLimit): Offset = { + + log.logDebug(s"--> latestOffset.$streamId") + + val startChangeFeedOffset = startOffset.asInstanceOf[ChangeFeedOffset] + val offset = CosmosPartitionPlanner.getLatestOffset( + config, + startChangeFeedOffset, + readLimit, + Duration.ZERO, + this.clientConfiguration, + this.cosmosClientStateHandles, + this.containerConfig, + this.partitioningConfig, + this.defaultParallelism, + this.container, + Some(this.partitionMetricsMap) + ) + + if (offset.changeFeedState != startChangeFeedOffset.changeFeedState) { + log.logDebug(s"<-- latestOffset.$streamId - new offset ${offset.json()}") + this.latestOffsetSnapshot = Some(offset) + offset + } else { + log.logDebug(s"<-- latestOffset.$streamId - Finished returning null") + + this.latestOffsetSnapshot = None + + // scalastyle:off null + // null means no more data to process + // null is used here because the DataSource V2 API is defined in Java + null + // scalastyle:on null + } + } + + /** + * Returns the initial offset for a streaming query to start reading from. Note that the + * streaming data source should not assume that it will start reading from its initial offset: + * if Spark is restarting an existing query, it will restart from the check-pointed offset rather + * than the initial one. 
+ */ + // Mapping start from settings to the initial offset/LSNs + override def initialOffset(): Offset = { + assertOnSparkDriver() + + val metadataLog = new ChangeFeedInitialOffsetWriter( + assertNotNull(session, "session"), + assertNotNullOrEmpty(checkpointLocation, "checkpointLocation")) + val offsetJson = metadataLog.get(0).getOrElse { + val newOffsetJson = CosmosPartitionPlanner.createInitialOffset( + container, containerConfig, changeFeedConfig, partitioningConfig, Some(streamId)) + metadataLog.add(0, newOffsetJson) + newOffsetJson + } + + log.logDebug(s"MicroBatch stream $streamId: Initial offset '$offsetJson'.") + ChangeFeedOffset(offsetJson, None) + } + + /** + * Returns the read limits potentially passed to the data source through options when creating + * the data source. + */ + override def getDefaultReadLimit: ReadLimit = { + this.changeFeedConfig.toReadLimit + } + + /** + * Returns the most recent offset available. + * + * The source can return `null`, if there is no data to process or the source does not support + * to this method. + */ + override def reportLatestOffset(): Offset = { + this.latestOffsetSnapshot.orNull + } + + /** + * Deserialize a JSON string into an Offset of the implementation-defined offset type. + * + * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader + */ + override def deserializeOffset(s: String): Offset = { + log.logDebug(s"MicroBatch stream $streamId: Deserialized offset '$s'.") + ChangeFeedOffset.fromJson(s) + } + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + override def commit(offset: Offset): Unit = { + log.logDebug(s"MicroBatch stream $streamId: Committed offset '${offset.json()}'.") + } + + /** + * Stop this source and free any resources it has allocated. 
+ */ + override def stop(): Unit = { + clientCacheItem.close() + if (throughputControlClientCacheItemOpt.isDefined) { + throughputControlClientCacheItemOpt.get.close() + } + log.logDebug(s"MicroBatch stream $streamId: stopped.") + } +} +// scalastyle:on multiple.string.literals diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala new file mode 100644 index 000000000000..9d7f645227bf --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import org.apache.spark.sql.connector.metric.CustomSumMetric + +private[cosmos] class CosmosBytesWrittenMetric extends CustomSumMetric { + override def name(): String = CosmosConstants.MetricNames.BytesWritten + + override def description(): String = CosmosConstants.MetricNames.BytesWritten +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala new file mode 100644 index 000000000000..778c2311e2e0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +package com.azure.cosmos.spark + +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException + +import java.util +// scalastyle:off underscore.import +// scalastyle:on underscore.import +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException} +import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespaces} + +// scalastyle:off underscore.import + +class CosmosCatalog + extends CosmosCatalogBase + with SupportsNamespaces { + + override def listNamespaces(): Array[Array[String]] = { + super.listNamespacesBase() + } + + @throws(classOf[NoSuchNamespaceException]) + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + super.listNamespacesBase(namespace) + } + + @throws(classOf[NoSuchNamespaceException]) + override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = { + super.loadNamespaceMetadataBase(namespace) + } + + @throws(classOf[NamespaceAlreadyExistsException]) + override def createNamespace(namespace: Array[String], + metadata: util.Map[String, String]): Unit = { + super.createNamespaceBase(namespace, metadata) + } + + @throws(classOf[UnsupportedOperationException]) + override def alterNamespace(namespace: Array[String], + changes: NamespaceChange*): Unit = { + super.alterNamespaceBase(namespace, changes) + } + + @throws(classOf[NoSuchNamespaceException]) + @throws(classOf[NonEmptyNamespaceException]) + override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = { + if (!cascade) { + if (this.listTables(namespace).length > 0) { + throw new NonEmptyNamespaceException(namespace) + } + } + super.dropNamespaceBase(namespace) + } +} +// scalastyle:on multiple.string.literals +// scalastyle:on number.of.methods +// scalastyle:on file.size.limit diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala 
b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala new file mode 100644 index 000000000000..6aed2016bba1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala @@ -0,0 +1,727 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.spark.catalog.{CosmosCatalogConflictException, CosmosCatalogException, CosmosCatalogNotFoundException, CosmosThroughputProperties} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, NamespaceChange, Table, TableCatalog, TableChange} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.execution.streaming.checkpointing.HDFSMetadataLog +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +import java.util +import scala.annotation.tailrec +import scala.collection.mutable.ArrayBuffer + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +// CosmosCatalog provides a meta data store for Cosmos database, container control plane +// This will be required for hive integration +// relevant interfaces to implement: +// - SupportsNamespaces (Cosmos Database and Cosmos Container can be modeled as namespace) +// - SupportsCatalogOptions // TODO moderakh +// - CatalogPlugin - A marker interface to provide a catalog implementation for Spark. +// Implementations can provide catalog functions by implementing additional interfaces +// for tables, views, and functions. +// - TableCatalog Catalog methods for working with Tables. 
+ +// All Hive keywords are case-insensitive, including the names of Hive operators and functions. +// scalastyle:off multiple.string.literals +// scalastyle:off number.of.methods +// scalastyle:off file.size.limit +class CosmosCatalogBase + extends CatalogPlugin + with TableCatalog + with BasicLoggingTrait { + + private lazy val sparkSession = SparkSession.active + private lazy val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(SparkSession.getActiveSession) + + // mutable but only expected to be changed from within initialize method + private var catalogName: String = _ + //private var client: CosmosAsyncClient = _ + private var config: Map[String, String] = _ + private var readConfig: CosmosReadConfig = _ + private var tableOptions: Map[String, String] = _ + private var viewRepository: Option[HDFSMetadataLog[String]] = None + + /** + * Called to initialize configuration. + *
+ * This method is called once, just after the provider is instantiated. + * + * @param name the name used to identify and load this catalog + * @param options a case-insensitive string map of configuration + */ + override def initialize(name: String, + options: CaseInsensitiveStringMap): Unit = { + this.config = CosmosConfig.getEffectiveConfig( + None, + None, + options.asCaseSensitiveMap().asScala.toMap) + this.readConfig = CosmosReadConfig.parseCosmosReadConfig(config) + + tableOptions = toTableConfig(options) + this.catalogName = name + + val viewRepositoryConfig = CosmosViewRepositoryConfig.parseCosmosViewRepositoryConfig(config) + if (viewRepositoryConfig.metaDataPath.isDefined) { + this.viewRepository = Some(new HDFSMetadataLog[String]( + this.sparkSession, + viewRepositoryConfig.metaDataPath.get)) + } + } + + /** + * Catalog implementations are registered to a name by adding a configuration option to Spark: + * spark.sql.catalog.catalog-name=com.example.YourCatalogClass. + * All configuration properties in the Spark configuration that share the catalog name prefix, + * spark.sql.catalog.catalog-name.(key)=(value) will be passed in the case insensitive + * string map of options in initialization with the prefix removed. + * name, is also passed and is the catalog's name; in this case, "catalog-name". + * + * @return catalog name + */ + override def name(): String = catalogName + + /** + * List top-level namespaces from the catalog. + *
+ * If an object such as a table, view, or function exists, its parent namespaces must also exist + * and must be returned by this discovery method. For example, if table a.t exists, this method + * must return ["a"] in the result array. + * + * @return an array of multi-part namespace names. + */ + def listNamespacesBase(): Array[Array[String]] = { + logDebug("catalog:listNamespaces") + + TransientErrorsRetryPolicy.executeWithRetry(() => listNamespacesImpl()) + } + + private[this] def listNamespacesImpl(): Array[Array[String]] = { + logDebug("catalog:listNamespaces") + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).listNamespaces" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0) + .get + .sparkCatalogClient + .readAllDatabases() + .map(Array(_)) + .collectSeq() + .block() + .toArray + }) + } + + /** + * List namespaces in a namespace. + *
+ * Cosmos supports only single depth database. Hence we always return an empty list of namespaces. + * or throw if the root namespace doesn't exist + */ + @throws(classOf[NoSuchNamespaceException]) + def listNamespacesBase(namespace: Array[String]): Array[Array[String]] = { + loadNamespaceMetadataBase(namespace) // throws NoSuchNamespaceException if namespace doesn't exist + // Cosmos DB only has one single level depth databases + Array.empty[Array[String]] + } + + /** + * Load metadata properties for a namespace. + * + * @param namespace a multi-part namespace + * @return a string map of properties for the given namespace + * @throws NoSuchNamespaceException If the namespace does not exist (optional) + */ + @throws(classOf[NoSuchNamespaceException]) + def loadNamespaceMetadataBase(namespace: Array[String]): util.Map[String, String] = { + + TransientErrorsRetryPolicy.executeWithRetry(() => loadNamespaceMetadataImpl(namespace)) + } + + private[this] def loadNamespaceMetadataImpl( + namespace: Array[String]): util.Map[String, String] = { + + checkNamespace(namespace) + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).loadNamespaceMetadata([${namespace.mkString(", ")}])" + )) + )) + .to(clientCacheItems => { + try { + clientCacheItems(0) + .get + .sparkCatalogClient + .readDatabaseThroughput(toCosmosDatabaseName(namespace.head)) + .block() + .asJava + } catch { + case _: CosmosCatalogNotFoundException => + throw new NoSuchNamespaceException(namespace) + } + }) + } + + @throws(classOf[NamespaceAlreadyExistsException]) + def createNamespaceBase(namespace: Array[String], + metadata: util.Map[String, String]): Unit = { + TransientErrorsRetryPolicy.executeWithRetry(() => createNamespaceImpl(namespace, metadata)) + } + + @throws(classOf[NamespaceAlreadyExistsException]) + private[this] def 
createNamespaceImpl(namespace: Array[String], + metadata: util.Map[String, String]): Unit = { + checkNamespace(namespace) + val databaseName = toCosmosDatabaseName(namespace.head) + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).createNamespace([${namespace.mkString(", ")}])" + )) + )) + .to(cosmosClientCacheItems => { + try { + cosmosClientCacheItems(0) + .get + .sparkCatalogClient + .createDatabase(databaseName, metadata.asScala.toMap) + .block() + } catch { + case _: CosmosCatalogConflictException => + throw new NamespaceAlreadyExistsException(namespace) + } + }) + } + + @throws(classOf[UnsupportedOperationException]) + def alterNamespaceBase(namespace: Array[String], + changes: Seq[NamespaceChange]): Unit = { + checkNamespace(namespace) + + if (changes.size > 0) { + val invalidChangesCount = changes + .count(change => !CosmosThroughputProperties.isThroughputProperty(change)) + if (invalidChangesCount > 0) { + throw new UnsupportedOperationException("ALTER NAMESPACE contains unsupported changes.") + } + + val finalThroughputProperty = changes.last.asInstanceOf[NamespaceChange.SetProperty] + + val databaseName = toCosmosDatabaseName(namespace.head) + + alterNamespaceImpl(databaseName, finalThroughputProperty) + } + } + + //scalastyle:off method.length + private def alterNamespaceImpl(databaseName: String, finalThroughputProperty: NamespaceChange.SetProperty): Unit = { + logInfo(s"alterNamespace DB:$databaseName") + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).alterNamespace($databaseName)" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0).get + .sparkCatalogClient + .alterDatabase(databaseName, 
finalThroughputProperty) + .block() + }) + } + //scalastyle:on method.length + + /** + * Drop a namespace from the catalog, recursively dropping all objects within the namespace. + * + * @param namespace - a multi-part namespace + * @return true if the namespace was dropped + */ + @throws(classOf[NoSuchNamespaceException]) + def dropNamespaceBase(namespace: Array[String]): Boolean = { + TransientErrorsRetryPolicy.executeWithRetry(() => dropNamespaceImpl(namespace)) + } + + @throws(classOf[NoSuchNamespaceException]) + private[this] def dropNamespaceImpl(namespace: Array[String]): Boolean = { + checkNamespace(namespace) + try { + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).dropNamespace([${namespace.mkString(", ")}])" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0) + .get + .sparkCatalogClient + .deleteDatabase(toCosmosDatabaseName(namespace.head)) + .block() + }) + true + } catch { + case _: CosmosCatalogNotFoundException => + throw new NoSuchNamespaceException(namespace) + } + } + + override def listTables(namespace: Array[String]): Array[Identifier] = { + TransientErrorsRetryPolicy.executeWithRetry(() => listTablesImpl(namespace)) + } + + private[this] def listTablesImpl(namespace: Array[String]): Array[Identifier] = { + checkNamespace(namespace) + val databaseName = toCosmosDatabaseName(namespace.head) + + try { + val cosmosTables = + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).listTables([${namespace.mkString(", ")}])" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0).get + .sparkCatalogClient + .readAllContainers(databaseName) + .map(containerId => getContainerIdentifier(namespace.head, 
containerId)) + .collectSeq() + .block() + .toList + }) + + val tableIdentifiers = this.tryGetViewDefinitions(databaseName) match { + case Some(viewDefinitions) => + cosmosTables ++ viewDefinitions.map(viewDef => getContainerIdentifier(namespace.head, viewDef)).toIterable + case None => cosmosTables + } + + tableIdentifiers.toArray + } catch { + case _: CosmosCatalogNotFoundException => + throw new NoSuchNamespaceException(namespace) + } + } + + override def loadTable(ident: Identifier): Table = { + TransientErrorsRetryPolicy.executeWithRetry(() => loadTableImpl(ident)) + } + + private[this] def loadTableImpl(ident: Identifier): Table = { + checkNamespace(ident.namespace()) + val databaseName = toCosmosDatabaseName(ident.namespace().head) + val containerName = toCosmosContainerName(ident.name()) + logInfo(s"loadTable DB:$databaseName, Container: $containerName") + + this.tryGetContainerMetadata(databaseName, containerName) match { + case Some(tableProperties) => + new ItemsTable( + sparkSession, + Array[Transform](), + Some(databaseName), + Some(containerName), + tableOptions.asJava, + None, + tableProperties) + case None => + this.tryGetViewDefinition(databaseName, containerName) match { + case Some(viewDefinition) => + val effectiveOptions = tableOptions ++ viewDefinition.options + new ItemsReadOnlyTable( + sparkSession, + Array[Transform](), + None, + None, + effectiveOptions.asJava, + viewDefinition.userProvidedSchema) + case None => + throw new NoSuchTableException(ident) + } + } + } + + override def createTable(ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + TransientErrorsRetryPolicy.executeWithRetry(() => + createTableImpl(ident, schema, partitions, properties)) + } + + private[this] def createTableImpl(ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + checkNamespace(ident.namespace()) + + val 
databaseName = toCosmosDatabaseName(ident.namespace().head) + val containerName = toCosmosContainerName(ident.name()) + val containerProperties = properties.asScala.toMap + + if (CosmosViewRepositoryConfig.isCosmosView(containerProperties)) { + createViewTable(ident, databaseName, containerName, schema, partitions, containerProperties) + } else { + createPhysicalTable(databaseName, containerName, schema, partitions, containerProperties) + } + } + + @throws(classOf[UnsupportedOperationException]) + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + checkNamespace(ident.namespace()) + + if (changes.size > 0) { + val invalidChangesCount = changes + .count(change => !CosmosThroughputProperties.isThroughputProperty(change)) + if (invalidChangesCount > 0) { + throw new UnsupportedOperationException("ALTER TABLE contains unsupported changes.") + } + + val finalThroughputProperty = changes.last.asInstanceOf[TableChange.SetProperty] + + val tableBeforeModification = loadTableImpl(ident) + if (!tableBeforeModification.isInstanceOf[ItemsTable]) { + throw new UnsupportedOperationException("ALTER TABLE cannot be applied to Cosmos views.") + } + + val databaseName = toCosmosDatabaseName(ident.namespace().head) + val containerName = toCosmosContainerName(ident.name()) + + alterPhysicalTable(databaseName, containerName, finalThroughputProperty) + } + + loadTableImpl(ident) + } + + override def dropTable(ident: Identifier): Boolean = { + TransientErrorsRetryPolicy.executeWithRetry(() => dropTableImpl(ident)) + } + + private[this] def dropTableImpl(ident: Identifier): Boolean = { + checkNamespace(ident.namespace()) + + val databaseName = toCosmosDatabaseName(ident.namespace().head) + val containerName = toCosmosContainerName(ident.name()) + + if (deleteViewTable(databaseName, containerName)) { + true + } else { + this.deletePhysicalTable(databaseName, containerName) + } + } + + @throws(classOf[UnsupportedOperationException]) + override def 
renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { + throw new UnsupportedOperationException("renaming table not supported") + } + + //scalastyle:off method.length + private def createPhysicalTable(databaseName: String, + containerName: String, + schema: StructType, + partitions: Array[Transform], + containerProperties: Map[String, String]): Table = { + logInfo(s"createPhysicalTable DB:$databaseName, Container: $containerName") + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).createPhysicalTable($databaseName, $containerName)" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0).get + .sparkCatalogClient + .createContainer(databaseName, containerName, containerProperties) + .block() + }) + + val effectiveOptions = tableOptions ++ containerProperties + + new ItemsTable( + sparkSession, + partitions, + Some(databaseName), + Some(containerName), + effectiveOptions.asJava, + Option.apply(schema)) + } + //scalastyle:on method.length + + //scalastyle:off method.length + @tailrec + private def createViewTable(ident: Identifier, + databaseName: String, + viewName: String, + schema: StructType, + partitions: Array[Transform], + containerProperties: Map[String, String]): Table = { + + logInfo(s"createViewTable DB:$databaseName, View: $viewName") + + this.viewRepository match { + case Some(viewRepositorySnapshot) => + val userProvidedSchema = if (schema != null && schema.length > 0) { + Some(schema) + } else { + None + } + val viewDefinition = ViewDefinition( + databaseName, viewName, userProvidedSchema, redactAuthInfo(containerProperties)) + var lastBatchId = 0L + val newViewDefinitionsSnapshot = viewRepositorySnapshot.getLatest() match { + case Some(viewDefinitionsEnvelopeSnapshot) => + lastBatchId = viewDefinitionsEnvelopeSnapshot._1 + val alreadyExistingViews = 
ViewDefinitionEnvelopeSerializer.fromJson(viewDefinitionsEnvelopeSnapshot._2) + + if (alreadyExistingViews.exists(v => v.databaseName.equals(databaseName) && + v.viewName.equals(viewName))) { + + throw new IllegalArgumentException(s"View '$viewName' already exists in database '$databaseName'") + } + + alreadyExistingViews ++ Array(viewDefinition) + case None => Array(viewDefinition) + } + + if (viewRepositorySnapshot.add( + lastBatchId + 1, + ViewDefinitionEnvelopeSerializer.toJson(newViewDefinitionsSnapshot))) { + + logInfo(s"LatestBatchId: ${viewRepositorySnapshot.getLatestBatchId().getOrElse(-1)}") + viewRepositorySnapshot.purge(lastBatchId) + logInfo(s"LatestBatchId: ${viewRepositorySnapshot.getLatestBatchId().getOrElse(-1)}") + val effectiveOptions = tableOptions ++ viewDefinition.options + + new ItemsReadOnlyTable( + sparkSession, + partitions, + None, + None, + effectiveOptions.asJava, + userProvidedSchema) + } else { + createViewTable(ident, databaseName, viewName, schema, partitions, containerProperties) + } + case None => + throw new IllegalArgumentException( + s"Catalog configuration for '${CosmosViewRepositoryConfig.MetaDataPathKeyName}' must " + + "be set when creating views'") + } + } + //scalastyle:on method.length + + //scalastyle:off method.length + private def alterPhysicalTable(databaseName: String, + containerName: String, + finalThroughputProperty: TableChange.SetProperty): Unit = { + logInfo(s"alterPhysicalTable DB:$databaseName, Container: $containerName") + + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).alterPhysicalTable($databaseName, $containerName)" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0).get + .sparkCatalogClient + .alterContainer(databaseName, containerName, finalThroughputProperty) + .block() + }) + } + //scalastyle:on method.length + + 
private def deletePhysicalTable(databaseName: String, containerName: String): Boolean = { + try { + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).deletePhysicalTable($databaseName, $containerName)" + )) + )) + .to (cosmosClientCacheItems => + cosmosClientCacheItems(0).get + .sparkCatalogClient + .deleteContainer(databaseName, containerName)) + .block() + true + } catch { + case _: CosmosCatalogNotFoundException => false + } + } + + @tailrec + private def deleteViewTable(databaseName: String, viewName: String): Boolean = { + logInfo(s"deleteViewTable DB:$databaseName, View: $viewName") + + this.viewRepository match { + case Some(viewRepositorySnapshot) => + viewRepositorySnapshot.getLatest() match { + case Some(viewDefinitionsEnvelopeSnapshot) => + val lastBatchId = viewDefinitionsEnvelopeSnapshot._1 + val viewDefinitions = ViewDefinitionEnvelopeSerializer.fromJson(viewDefinitionsEnvelopeSnapshot._2) + + viewDefinitions.find(v => v.databaseName.equals(databaseName) && + v.viewName.equals(viewName)) match { + case Some(existingView) => + val updatedViewDefinitionsSnapshot: Array[ViewDefinition] = + ArrayBuffer(viewDefinitions: _*).filterNot(_ == existingView).toArray + + if (viewRepositorySnapshot.add( + lastBatchId + 1, + ViewDefinitionEnvelopeSerializer.toJson(updatedViewDefinitionsSnapshot))) { + + viewRepositorySnapshot.purge(lastBatchId) + true + } else { + deleteViewTable(databaseName, viewName) + } + case None => false + } + case None => false + } + case None => + false + } + } + + //scalastyle:off method.length + private def tryGetContainerMetadata + ( + databaseName: String, + containerName: String + ): Option[util.HashMap[String, String]] = { + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache( + CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, 
sparkEnvironmentInfo), + None, + s"CosmosCatalog(name $catalogName).tryGetContainerMetadata($databaseName, $containerName)" + )) + )) + .to(cosmosClientCacheItems => { + cosmosClientCacheItems(0) + .get + .sparkCatalogClient + .readContainerMetadata(databaseName, containerName) + .block() + }) + } + //scalastyle:on method.length + + private def tryGetViewDefinition(databaseName: String, + containerName: String): Option[ViewDefinition] = { + + this.tryGetViewDefinitions(databaseName) match { + case Some(viewDefinitions) => + viewDefinitions.find(v => databaseName.equals(v.databaseName) && + containerName.equals(v.viewName)) + case None => None + } + } + + private def tryGetViewDefinitions(databaseName: String): Option[Array[ViewDefinition]] = { + + this.viewRepository match { + case Some(viewRepositorySnapshot) => + viewRepositorySnapshot.getLatest() match { + case Some(latestMetadataSnapshot) => + val viewDefinitions = ViewDefinitionEnvelopeSerializer.fromJson(latestMetadataSnapshot._2) + .filter(v => databaseName.equals(v.databaseName)) + if (viewDefinitions.length > 0) { + Some(viewDefinitions) + } else { + None + } + case None => None + } + case None => None + } + } + + private def getContainerIdentifier( + namespaceName: String, + containerId: String): Identifier = { + Identifier.of(Array(namespaceName), containerId) + } + + private def getContainerIdentifier + ( + namespaceName: String, + viewDefinition: ViewDefinition + ): Identifier = { + + Identifier.of(Array(namespaceName), viewDefinition.viewName) + } + + private def checkNamespace(namespace: Array[String]): Unit = { + if (namespace == null || namespace.length != 1) { + throw new CosmosCatalogException( + s"invalid namespace ${namespace.mkString("Array(", ", ", ")")}." 
+ + s" Cosmos DB already support single depth namespace.") + } + } + + private def toCosmosDatabaseName(namespace: String): String = { + namespace + } + + private def toCosmosContainerName(tableIdent: String): String = { + tableIdent + } + + private def toTableConfig(options: CaseInsensitiveStringMap): Map[String, String] = { + options.asCaseSensitiveMap().asScala.toMap + } + + + private def redactAuthInfo(cfg: Map[String, String]): Map[String, String] = { + cfg.filter((kvp) => !CosmosConfigNames.AccountEndpoint.equalsIgnoreCase(kvp._1) && + !CosmosConfigNames.AccountKey.equalsIgnoreCase(kvp._1) && + !kvp._1.toLowerCase.contains(CosmosConfigNames.AccountEndpoint.toLowerCase()) && + !kvp._1.toLowerCase.contains(CosmosConfigNames.AccountKey.toLowerCase()) + ) + } +} +// scalastyle:on multiple.string.literals +// scalastyle:on number.of.methods +// scalastyle:on file.size.limit diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala new file mode 100644 index 000000000000..8814c59d0c7d --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import org.apache.spark.sql.connector.metric.CustomSumMetric + +private[cosmos] class CosmosRecordsWrittenMetric extends CustomSumMetric { + override def name(): String = CosmosConstants.MetricNames.RecordsWritten + + override def description(): String = CosmosConstants.MetricNames.RecordsWritten +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala new file mode 100644 index 000000000000..fb4e9db760a0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode +import com.fasterxml.jackson.annotation.JsonInclude.Include +// scalastyle:off underscore.import +import com.fasterxml.jackson.databind.node._ +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} +import java.time.format.DateTimeFormatter +import java.time.LocalDateTime +import scala.collection.concurrent.TrieMap + +// scalastyle:off underscore.import +import org.apache.spark.sql.types._ +// scalastyle:on underscore.import + +import scala.util.{Try, Success, Failure} + +// scalastyle:off +private[cosmos] object CosmosRowConverter { + + // TODO: Expose configuration to handle duplicate fields + // See: https://github.com/Azure/azure-sdk-for-java/pull/18642#discussion_r558638474 + private val rowConverterMap = new TrieMap[CosmosSerializationConfig, CosmosRowConverter] + + def get(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = { + rowConverterMap.get(serializationConfig) match { + case Some(existingRowConverter) => existingRowConverter + case None => + val newRowConverterCandidate = 
createRowConverter(serializationConfig) + rowConverterMap.putIfAbsent(serializationConfig, newRowConverterCandidate) match { + case Some(existingConcurrentlyCreatedRowConverter) => existingConcurrentlyCreatedRowConverter + case None => newRowConverterCandidate + } + } + } + + private def createRowConverter(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = { + val objectMapper = new ObjectMapper() + import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule + objectMapper.registerModule(new JavaTimeModule) + serializationConfig.serializationInclusionMode match { + case SerializationInclusionModes.NonNull => objectMapper.setSerializationInclusion(Include.NON_NULL) + case SerializationInclusionModes.NonEmpty => objectMapper.setSerializationInclusion(Include.NON_EMPTY) + case SerializationInclusionModes.NonDefault => objectMapper.setSerializationInclusion(Include.NON_DEFAULT) + case _ => objectMapper.setSerializationInclusion(Include.ALWAYS) + } + + new CosmosRowConverter(objectMapper, serializationConfig) + } +} + +private[cosmos] class CosmosRowConverter(private val objectMapper: ObjectMapper, private val serializationConfig: CosmosSerializationConfig) + extends CosmosRowConverterBase(objectMapper, serializationConfig) { + + override def convertSparkDataTypeToJsonNodeConditionallyForSparkRuntimeSpecificDataType + ( + fieldType: DataType, + rowData: Any + ): Option[JsonNode] = { + fieldType match { + case TimestampNTZType if rowData.isInstanceOf[java.time.LocalDateTime] => convertToJsonNodeConditionally(rowData.asInstanceOf[java.time.LocalDateTime].toString) + case _ => + throw new Exception(s"Cannot cast $rowData into a Json value. 
$fieldType has no matching Json value.") + } + } + + override def convertSparkDataTypeToJsonNodeNonNullForSparkRuntimeSpecificDataType(fieldType: DataType, rowData: Any): JsonNode = { + fieldType match { + case TimestampNTZType if rowData.isInstanceOf[java.time.LocalDateTime] => objectMapper.convertValue(rowData.asInstanceOf[java.time.LocalDateTime].toString, classOf[JsonNode]) + case _ => + throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.") + } + } + + override def convertToSparkDataTypeForSparkRuntimeSpecificDataType + (dataType: DataType, + value: JsonNode, + schemaConversionMode: SchemaConversionMode): Any = + (value, dataType) match { + case (_, _: TimestampNTZType) => handleConversionErrors(() => toTimestampNTZ(value), schemaConversionMode) + case _ => + throw new IllegalArgumentException( + s"Unsupported datatype conversion [Value: $value] of ${value.getClass}] to $dataType]") + } + + + def toTimestampNTZ(value: JsonNode): LocalDateTime = { + value match { + case isJsonNumber() => LocalDateTime.parse(value.asText()) + case textNode: TextNode => + parseDateTimeNTZFromString(textNode.asText()) match { + case Some(odt) => odt + case None => + throw new IllegalArgumentException( + s"Value '${textNode.asText()} cannot be parsed as LocalDateTime (TIMESTAMP_NTZ).") + } + case _ => LocalDateTime.parse(value.asText()) + } + } + + private def handleConversionErrors[A] = (conversion: () => A, + schemaConversionMode: SchemaConversionMode) => { + Try(conversion()) match { + case Success(convertedValue) => convertedValue + case Failure(error) => + if (schemaConversionMode == SchemaConversionModes.Relaxed) { + null + } + else { + throw error + } + } + } + + def parseDateTimeNTZFromString(value: String): Option[LocalDateTime] = { + try { + val odt = LocalDateTime.parse(value, DateTimeFormatter.ISO_DATE_TIME) + Some(odt) + } + catch { + case _: Exception => None + } + } + +} diff --git 
a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala new file mode 100644 index 000000000000..042c6ca5636e --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.CosmosDiagnosticsContext +import com.azure.cosmos.implementation.ImplementationBridgeHelpers +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.connector.metric.CustomTaskMetric +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.apache.spark.sql.execution.metric.CustomMetrics +import org.apache.spark.sql.types.StructType + +import java.util.concurrent.atomic.AtomicLong + +private class CosmosWriter( + userConfig: Map[String, String], + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + inputSchema: StructType, + partitionId: Int, + taskId: Long, + epochId: Option[Long], + sparkEnvironmentInfo: String) + extends CosmosWriterBase( + userConfig, + cosmosClientStateHandles, + diagnosticsConfig, + inputSchema, + partitionId, + taskId, + epochId, + sparkEnvironmentInfo + ) with OutputMetricsPublisherTrait { + + private val recordsWritten = new AtomicLong(0) + private val bytesWritten = new AtomicLong(0) + private val totalRequestCharge = new AtomicLong(0) + + private val recordsWrittenMetric = new CustomTaskMetric { + override def name(): String = CosmosConstants.MetricNames.RecordsWritten + override def value(): Long = recordsWritten.get() + } + + private val bytesWrittenMetric = new CustomTaskMetric { + override def name(): String = CosmosConstants.MetricNames.BytesWritten + + override def value(): Long = bytesWritten.get() 
+ } + + private val totalRequestChargeMetric = new CustomTaskMetric { + override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge + + // Internally we capture RU/s up to 2 fractional digits to have more precise rounding + override def value(): Long = totalRequestCharge.get() / 100L + } + + private val metrics = Array(recordsWrittenMetric, bytesWrittenMetric, totalRequestChargeMetric) + + override def currentMetricsValues(): Array[CustomTaskMetric] = { + metrics + } + + override def getOutputMetricsPublisher(): OutputMetricsPublisherTrait = this + + override def trackWriteOperation(recordCount: Long, diagnostics: Option[CosmosDiagnosticsContext]): Unit = { + if (recordCount > 0) { + recordsWritten.addAndGet(recordCount) + } + + diagnostics match { + case Some(ctx) => + // Capturing RU/s with 2 fractional digits internally + totalRequestCharge.addAndGet((ctx.getTotalRequestCharge * 100L).toLong) + bytesWritten.addAndGet( + if (ImplementationBridgeHelpers + .CosmosDiagnosticsContextHelper + .getCosmosDiagnosticsContextAccessor + .getOperationType(ctx) + .isReadOnlyOperation) { + + ctx.getMaxRequestPayloadSizeInBytes + ctx.getMaxResponsePayloadSizeInBytes + } else { + ctx.getMaxRequestPayloadSizeInBytes + } + ) + case None => + } + } + + override def commit(): WriterCommitMessage = { + val commitMessage = super.commit() + + // TODO @fabianm - this is a workaround - it shouldn't be necessary to do this here + // Unfortunately WriteToDataSourceV2Exec.scala is not updating custom metrics after the + // call to commit - meaning DataSources which asynchronously write data and flush in commit + // won't get accurate metrics because updates between the last call to write and flushing the + // writes are lost. 
See https://issues.apache.org/jira/browse/SPARK-45759 + // Once above issue is addressed (probably in Spark 3.4.1 or 3.5 - this needs to be changed + // + // NOTE: This also means that the RU/s metrics cannot be updated in commit - so the + // RU/s metric at the end of a task will be slightly outdated/behind + CustomMetrics.updateMetrics( + currentMetricsValues(), + SparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(CosmosConstants.MetricNames.KnownCustomMetricNames)) + + commitMessage + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala new file mode 100644 index 000000000000..1e193b9e6959 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.models.PartitionKeyDefinition +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType + +private[spark] class ItemsScan(session: SparkSession, + schema: StructType, + config: Map[String, String], + readConfig: CosmosReadConfig, + analyzedFilters: AnalyzedAggregatedFilters, + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + sparkEnvironmentInfo: String, + partitionKeyDefinition: PartitionKeyDefinition) + extends ItemsScanBase( + session, + schema, + config, + readConfig, + analyzedFilters, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo, + partitionKeyDefinition) + with SupportsRuntimeFiltering 
{ // SupportsRuntimeFiltering extends scan + override def filterAttributes(): Array[NamedReference] = { + runtimeFilterAttributesCore() + } + + override def filter(filters: Array[Filter]): Unit = { + runtimeFilterCore(filters) + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala new file mode 100644 index 000000000000..340a40585eb0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.cosmos.spark + +import com.azure.cosmos.SparkBridgeInternal +import com.azure.cosmos.models.PartitionKeyDefinition +import com.azure.cosmos.spark.diagnostics.LoggerHelper +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private case class ItemsScanBuilder(session: SparkSession, + config: CaseInsensitiveStringMap, + inputSchema: StructType, + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + sparkEnvironmentInfo: String) + extends ScanBuilder + with SupportsPushDownFilters + with SupportsPushDownRequiredColumns { + + @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) + log.logTrace(s"Instantiated ${this.getClass.getSimpleName}") + + private val configMap = config.asScala.toMap + private val 
readConfig = CosmosReadConfig.parseCosmosReadConfig(configMap) + private var processedPredicates : Option[AnalyzedAggregatedFilters] = Option.empty + + private val clientConfiguration = CosmosClientConfiguration.apply( + configMap, + readConfig.readConsistencyStrategy, + CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session)) + ) + private val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(configMap) + private val description = { + s"""Cosmos ItemsScanBuilder: ${containerConfig.database}.${containerConfig.container}""".stripMargin + } + + private val partitionKeyDefinition: PartitionKeyDefinition = { + TransientErrorsRetryPolicy.executeWithRetry(() => { + val calledFrom = s"ItemsScan($description()).getPartitionKeyDefinition" + Loan( + List[Option[CosmosClientCacheItem]]( + Some(CosmosClientCache.apply( + clientConfiguration, + Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), + calledFrom + )), + ThroughputControlHelper.getThroughputControlClientCacheItem( + configMap, calledFrom, Some(cosmosClientStateHandles), sparkEnvironmentInfo) + )) + .to(clientCacheItems => { + val container = + ThroughputControlHelper.getContainer( + configMap, + containerConfig, + clientCacheItems(0).get, + clientCacheItems(1)) + + SparkBridgeInternal + .getContainerPropertiesFromCollectionCache(container) + .getPartitionKeyDefinition() + }) + }) + } + + private val filterAnalyzer = FilterAnalyzer(readConfig, partitionKeyDefinition) + + /** + * Pushes down filters, and returns filters that need to be evaluated after scanning. + * @param filters pushed down filters. 
+ * @return the filters that spark need to evaluate + */ + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + this.processedPredicates = Option.apply(filterAnalyzer.analyze(filters)) + + // return the filters that spark need to evaluate + this.processedPredicates.get.filtersNotSupportedByCosmos + } + + /** + * Returns the filters that are pushed to Cosmos as query predicates + * @return filters to be pushed to cosmos db. + */ + override def pushedFilters: Array[Filter] = { + if (this.processedPredicates.isDefined) { + this.processedPredicates.get.filtersToBePushedDownToCosmos + } else { + Array[Filter]() + } + } + + override def build(): Scan = { + val effectiveAnalyzedFilters = this.processedPredicates match { + case Some(analyzedFilters) => analyzedFilters + case None => filterAnalyzer.analyze(Array.empty[Filter]) + } + + // TODO when inferring schema we should consolidate the schema from pruneColumns + new ItemsScan( + session, + inputSchema, + this.configMap, + this.readConfig, + effectiveAnalyzedFilters, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo, + partitionKeyDefinition) + } + + /** + * Applies column pruning w.r.t. the given requiredSchema. + * + * Implementation should try its best to prune the unnecessary columns or nested fields, but it's + * also OK to do the pruning partially, e.g., a data source may not be able to prune nested + * fields, and only prune top-level columns. + * + * Note that, `Scan` implementation should take care of the column + * pruning applied here. + */ + override def pruneColumns(requiredSchema: StructType): Unit = { + // TODO: we need to decide whether do a push down or not on the projection + // spark will do column pruning on the returned data. 
+ // pushing down projection to cosmos has tradeoffs: + // - it increases consumed RU in cosmos query engine + // - it decrease the networking layer latency + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala new file mode 100644 index 000000000000..ea759335091b --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.{CosmosAsyncClient, ReadConsistencyStrategy, SparkBridgeInternal} +import com.azure.cosmos.spark.diagnostics.LoggerHelper +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder} +import org.apache.spark.sql.connector.metric.CustomMetric +import org.apache.spark.sql.connector.write.streaming.StreamingWrite +import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +private class ItemsWriterBuilder +( + userConfig: CaseInsensitiveStringMap, + inputSchema: StructType, + cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots], + diagnosticsConfig: DiagnosticsConfig, + sparkEnvironmentInfo: String +) + extends WriteBuilder { + @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass) + log.logTrace(s"Instantiated 
${this.getClass.getSimpleName}") + + override def build(): Write = { + new CosmosWrite + } + + override def buildForBatch(): BatchWrite = + new ItemsBatchWriter( + userConfig.asCaseSensitiveMap().asScala.toMap, + inputSchema, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo) + + override def buildForStreaming(): StreamingWrite = + new ItemsBatchWriter( + userConfig.asCaseSensitiveMap().asScala.toMap, + inputSchema, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo) + + private class CosmosWrite extends Write with RequiresDistributionAndOrdering { + + private[this] val supportedCosmosMetrics: Array[CustomMetric] = { + Array( + new CosmosBytesWrittenMetric(), + new CosmosRecordsWrittenMetric(), + new TotalRequestChargeMetric() + ) + } + + // Extract userConfig conversion to avoid repeated calls + private[this] val userConfigMap = userConfig.asCaseSensitiveMap().asScala.toMap + + private[this] val writeConfig = CosmosWriteConfig.parseWriteConfig( + userConfigMap, + inputSchema + ) + + private[this] val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig( + userConfigMap + ) + + override def toBatch(): BatchWrite = + new ItemsBatchWriter( + userConfigMap, + inputSchema, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo) + + override def toStreaming: StreamingWrite = + new ItemsBatchWriter( + userConfigMap, + inputSchema, + cosmosClientStateHandles, + diagnosticsConfig, + sparkEnvironmentInfo) + + override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics + + override def requiredDistribution(): Distribution = { + if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) { + log.logInfo("Transactional batch mode enabled - configuring data distribution by partition key columns") + // For transactional writes, partition by all partition key columns + val partitionKeyPaths = getPartitionKeyColumnNames() + if (partitionKeyPaths.nonEmpty) { + // Use public 
Expressions.column() factory - returns NamedReference + val clustering = partitionKeyPaths.map(path => Expressions.column(path): Expression).toArray + Distributions.clustered(clustering) + } else { + Distributions.unspecified() + } + } else { + Distributions.unspecified() + } + } + + override def requiredOrdering(): Array[SortOrder] = { + if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) { + // For transactional writes, order by all partition key columns (ascending) + val partitionKeyPaths = getPartitionKeyColumnNames() + if (partitionKeyPaths.nonEmpty) { + partitionKeyPaths.map { path => + // Use public Expressions.sort() factory for creating SortOrder + Expressions.sort( + Expressions.column(path), + SortDirection.ASCENDING, + NullOrdering.NULLS_FIRST + ) + }.toArray + } else { + Array.empty[SortOrder] + } + } else { + Array.empty[SortOrder] + } + } + + private def getPartitionKeyColumnNames(): Seq[String] = { + try { + Loan( + List[Option[CosmosClientCacheItem]]( + Some(createClientForPartitionKeyLookup()) + )) + .to(clientCacheItems => { + val container = ThroughputControlHelper.getContainer( + userConfigMap, + containerConfig, + clientCacheItems(0).get, + None + ) + + // Simplified retrieval using SparkBridgeInternal directly + val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container) + val partitionKeyDefinition = containerProperties.getPartitionKeyDefinition + + extractPartitionKeyPaths(partitionKeyDefinition) + }) + } catch { + case ex: Exception => + log.logWarning(s"Failed to get partition key definition for transactional writes: ${ex.getMessage}") + Seq.empty[String] + } + } + + private def createClientForPartitionKeyLookup(): CosmosClientCacheItem = { + CosmosClientCache( + CosmosClientConfiguration( + userConfigMap, + ReadConsistencyStrategy.EVENTUAL, + sparkEnvironmentInfo + ), + Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches), + "ItemsWriterBuilder-PKLookup" + ) + } + + private def 
extractPartitionKeyPaths(partitionKeyDefinition: com.azure.cosmos.models.PartitionKeyDefinition): Seq[String] = { + if (partitionKeyDefinition != null && partitionKeyDefinition.getPaths != null) { + val paths = partitionKeyDefinition.getPaths.asScala + if (paths.isEmpty) { + log.logError("Partition key definition has 0 columns - this should not happen for modern containers") + } + paths.map(path => { + // Remove leading '/' from partition key path (e.g., "/pk" -> "pk") + if (path.startsWith("/")) path.substring(1) else path + }).toSeq + } else { + log.logError("Partition key definition is null - this should not happen for modern containers") + Seq.empty[String] + } + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala new file mode 100644 index 000000000000..427b8757e3e5 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.types.StructType + +/** + * Spark serializers are not thread-safe - and expensive to create (dynamic code generation) + * So we will use this object pool to allow reusing serializers based on the targeted schema. + * The main purpose for pooling serializers (vs. 
creating new ones in each PartitionReader) is for Structured + * Streaming scenarios where PartitionReaders for the same schema could be created every couple of 100 + * milliseconds + * A clean-up task is used to purge serializers for schemas which weren't used anymore + * For each schema we have an object pool that will use a soft-limit to limit the memory footprint + */ +private object RowSerializerPool { + private val serializerFactorySingletonInstance = + new RowSerializerPoolInstance((schema: StructType) => ExpressionEncoder.apply(schema).createSerializer()) + + def getOrCreateSerializer(schema: StructType): ExpressionEncoder.Serializer[Row] = { + serializerFactorySingletonInstance.getOrCreateSerializer(schema) + } + + def returnSerializerToPool(schema: StructType, serializer: ExpressionEncoder.Serializer[Row]): Boolean = { + serializerFactorySingletonInstance.returnSerializerToPool(schema, serializer) + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala new file mode 100644 index 000000000000..45d7bacef995 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull +import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import org.apache.spark.TaskContext +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.util.AccumulatorV2 + +import java.lang.reflect.Method +import java.util.Locale +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +class SparkInternalsBridge { + // Only used in ChangeFeedMetricsListener, which is easier for test validation + def getInternalCustomTaskMetricsAsSQLMetric( + knownCosmosMetricNames: Set[String], + taskMetrics: TaskMetrics) : Map[String, SQLMetric] = { + SparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics) + } +} + +object SparkInternalsBridge extends BasicLoggingTrait { + private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED" + private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED" + + private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true + private val accumulatorsMethod : AtomicReference[Method] = new AtomicReference[Method]() + + private def getSparkReflectionAccessAllowed: Boolean = { + val allowedText = System.getProperty( + SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY, + firstNonNull( + emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)), + String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED))) + + try { + java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT)) + } + catch { + case e: Exception => + logError(s"Parsing spark reflection access allowed $allowedText failed. 
Using the default $DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED.", e) + DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED + } + } + + private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed) + + def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String]) : Map[String, SQLMetric] = { + Option.apply(TaskContext.get()) match { + case Some(taskCtx) => getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames, taskCtx.taskMetrics()) + case None => Map.empty[String, SQLMetric] + } + } + + def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String], taskMetrics: TaskMetrics) : Map[String, SQLMetric] = { + + if (!reflectionAccessAllowed.get) { + Map.empty[String, SQLMetric] + } else { + getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics) + } + } + + private def getAccumulators(taskMetrics: TaskMetrics): Option[Seq[AccumulatorV2[_, _]]] = { + try { + val method = Option(accumulatorsMethod.get) match { + case Some(existing) => existing + case None => + val newMethod = taskMetrics.getClass.getMethod("accumulators") + newMethod.setAccessible(true) + accumulatorsMethod.set(newMethod) + newMethod + } + + val accums = method.invoke(taskMetrics).asInstanceOf[Seq[AccumulatorV2[_, _]]] + + Some(accums) + } catch { + case e: Exception => + logInfo(s"Could not invoke getAccumulators via reflection - Error ${e.getMessage}", e) + + // reflection failed - disabling it for the future + reflectionAccessAllowed.set(false) + None + } + } + + private def getInternalCustomTaskMetricsAsSQLMetricInternal( + knownCosmosMetricNames: Set[String], + taskMetrics: TaskMetrics): Map[String, SQLMetric] = { + getAccumulators(taskMetrics) match { + case Some(accumulators) => accumulators + .filter(accumulable => accumulable.isInstanceOf[SQLMetric] + && accumulable.name.isDefined + && knownCosmosMetricNames.contains(accumulable.name.get)) + .map(accumulable => { + val sqlMetric = 
accumulable.asInstanceOf[SQLMetric] + sqlMetric.name.get -> sqlMetric + }) + .toMap[String, SQLMetric] + case None => Map.empty[String, SQLMetric] + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala new file mode 100644 index 000000000000..56d1f0ba2b78 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import org.apache.spark.sql.connector.metric.CustomSumMetric + +private[cosmos] class TotalRequestChargeMetric extends CustomSumMetric { + override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge + + override def description(): String = CosmosConstants.MetricNames.TotalRequestCharge +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor new file mode 100644 index 000000000000..0d43a5bfc657 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestCosmosClientBuilderInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor new file mode 100644 index 000000000000..e2239720776d --- /dev/null +++ 
b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestFaultInjectionClientInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor new file mode 100644 index 000000000000..c60cbf2f14e4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor @@ -0,0 +1 @@ +com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala new file mode 100644 index 000000000000..6b9de815ea9c --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+// scalastyle:off magic.number +// scalastyle:off multiple.string.literals + +package com.azure.cosmos.spark + +import com.azure.cosmos.changeFeedMetrics.{ChangeFeedMetricsListener, ChangeFeedMetricsTracker} +import com.azure.cosmos.implementation.guava25.collect.{HashBiMap, Maps} +import org.apache.spark.Success +import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.scheduler.{SparkListenerTaskEnd, TaskInfo} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.mockito.ArgumentMatchers +import org.mockito.Mockito.{mock, when} + +import java.lang.reflect.Field +import java.util.concurrent.ConcurrentHashMap + +class ChangeFeedMetricsListenerITest extends IntegrationSpec with SparkWithJustDropwizardAndNoSlf4jMetrics { + "ChangeFeedMetricsListener" should "be able to capture changeFeed performance metrics" in { + val taskEnd = SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "ResultTask", + reason = Success, + taskInfo = mock(classOf[TaskInfo]), + taskExecutorMetrics = mock(classOf[ExecutorMetrics]), + taskMetrics = mock(classOf[TaskMetrics]) + ) + + val indexMetric = SQLMetrics.createMetric(spark.sparkContext, "index") + indexMetric.set(1) + val lsnMetric = SQLMetrics.createMetric(spark.sparkContext, "lsn") + lsnMetric.set(100) + val itemsMetric = SQLMetrics.createMetric(spark.sparkContext, "items") + itemsMetric.set(100) + + val metrics = Map[String, SQLMetric]( + CosmosConstants.MetricNames.ChangeFeedPartitionIndex -> indexMetric, + CosmosConstants.MetricNames.ChangeFeedLsnRange -> lsnMetric, + CosmosConstants.MetricNames.ChangeFeedItemsCnt -> itemsMetric + ) + + // create sparkInternalsBridge mock + val sparkInternalsBridge = mock(classOf[SparkInternalsBridge]) + when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric( + ArgumentMatchers.any[Set[String]], + ArgumentMatchers.any[TaskMetrics] + )).thenReturn(metrics) + + val partitionIndexMap = 
Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]()) + partitionIndexMap.put(NormalizedRange("0", "FF"), 1) + + val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]() + val changeFeedMetricsListener = new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap) + + // set the internal sparkInternalsBridgeField + val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge") + sparkInternalsBridgeField.setAccessible(true) + sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge) + + // verify that metrics will be properly tracked + changeFeedMetricsListener.onTaskEnd(taskEnd) + partitionMetricsMap.size() shouldBe 1 + partitionMetricsMap.containsKey(NormalizedRange("0", "FF")) shouldBe true + partitionMetricsMap.get(NormalizedRange("0", "FF")).getWeightedChangeFeedItemsPerLsn.get shouldBe 1 + } + + it should "ignore metrics for unknown partition index" in { + val taskEnd = SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "ResultTask", + reason = Success, + taskInfo = mock(classOf[TaskInfo]), + taskExecutorMetrics = mock(classOf[ExecutorMetrics]), + taskMetrics = mock(classOf[TaskMetrics]) + ) + + val indexMetric2 = SQLMetrics.createMetric(spark.sparkContext, "index") + indexMetric2.set(10) + val lsnMetric2 = SQLMetrics.createMetric(spark.sparkContext, "lsn") + lsnMetric2.set(100) + val itemsMetric2 = SQLMetrics.createMetric(spark.sparkContext, "items") + itemsMetric2.set(100) + + val metrics = Map[String, SQLMetric]( + CosmosConstants.MetricNames.ChangeFeedPartitionIndex -> indexMetric2, + CosmosConstants.MetricNames.ChangeFeedLsnRange -> lsnMetric2, + CosmosConstants.MetricNames.ChangeFeedItemsCnt -> itemsMetric2 + ) + + // create sparkInternalsBridge mock + val sparkInternalsBridge = mock(classOf[SparkInternalsBridge]) + when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric( + 
ArgumentMatchers.any[Set[String]], + ArgumentMatchers.any[TaskMetrics] + )).thenReturn(metrics) + + val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]()) + partitionIndexMap.put(NormalizedRange("0", "FF"), 1) + + val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]() + val changeFeedMetricsListener = new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap) + + // set the internal sparkInternalsBridgeField + val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge") + sparkInternalsBridgeField.setAccessible(true) + sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge) + + // because partition index 10 does not exist in the partitionIndexMap, it will be ignored + changeFeedMetricsListener.onTaskEnd(taskEnd) + partitionMetricsMap shouldBe empty + } + + it should "ignore unrelated metrics" in { + val taskEnd = SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "ResultTask", + reason = Success, + taskInfo = mock(classOf[TaskInfo]), + taskExecutorMetrics = mock(classOf[ExecutorMetrics]), + taskMetrics = mock(classOf[TaskMetrics]) + ) + + val unknownMetric3 = SQLMetrics.createMetric(spark.sparkContext, "unknown") + unknownMetric3.set(10) + + val metrics = Map[String, SQLMetric]( + "unknownMetrics" -> unknownMetric3 + ) + + // create sparkInternalsBridge mock + val sparkInternalsBridge = mock(classOf[SparkInternalsBridge]) + when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric( + ArgumentMatchers.any[Set[String]], + ArgumentMatchers.any[TaskMetrics] + )).thenReturn(metrics) + + val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]()) + partitionIndexMap.put(NormalizedRange("0", "FF"), 1) + + val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]() + val changeFeedMetricsListener = new 
ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap)
+
+    // set the internal sparkInternalsBridgeField
+    val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge")
+    sparkInternalsBridgeField.setAccessible(true)
+    sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge)
+
+    // because "unknownMetrics" is not one of the known cosmos change feed metric names, it will be ignored
+    changeFeedMetricsListener.onTaskEnd(taskEnd)
+    partitionMetricsMap shouldBe empty
+  }
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala
new file mode 100644
index 000000000000..c9fc02a6482d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
+
+class CosmosCatalogITest
+  extends CosmosCatalogITestBase(skipHive = true) {
+
+  //scalastyle:off magic.number
+
+  // TODO: spark on windows has issue with this test. 
+ // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive; + // once we move Linux CI re-enable the test: + it can "drop an empty database" in { + assume(!Platform.isWindows) + + for (cascade <- Array(true, false)) { + val databaseName = getAutoCleanableDatabaseName + spark.catalog.databaseExists(databaseName) shouldEqual false + + createDatabase(spark, databaseName) + databaseExists(databaseName) shouldEqual true + + dropDatabase(spark, databaseName, cascade) + spark.catalog.databaseExists(databaseName) shouldEqual false + } + } + + // TODO: spark on windows has issue with this test. + // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive; + // once we move Linux CI re-enable the test: + it can "drop an non-empty database with cascade true" in { + assume(!Platform.isWindows) + + val databaseName = getAutoCleanableDatabaseName + spark.catalog.databaseExists(databaseName) shouldEqual false + + createDatabase(spark, databaseName) + databaseExists(databaseName) shouldEqual true + + val containerName = RandomStringUtils.randomAlphabetic(5) + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + dropDatabase(spark, databaseName, true) + spark.catalog.databaseExists(databaseName) shouldEqual false + } + + // TODO: spark on windows has issue with this test. 
+ // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive; + // once we move Linux CI re-enable the test: + "drop an non-empty database with cascade false" should "throw NonEmptyNamespaceException" in { + assume(!Platform.isWindows) + + try { + val databaseName = getAutoCleanableDatabaseName + spark.catalog.databaseExists(databaseName) shouldEqual false + + createDatabase(spark, databaseName) + databaseExists(databaseName) shouldEqual true + + val containerName = RandomStringUtils.randomAlphabetic(5) + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + dropDatabase(spark, databaseName, false) + fail("Expected NonEmptyNamespaceException is not thrown") + } + catch { + case expectedError: NonEmptyNamespaceException => { + logInfo(s"Expected NonEmptyNamespaceException: $expectedError") + succeed + } + } + } + + it can "list all databases" in { + val databaseName1 = getAutoCleanableDatabaseName + val databaseName2 = getAutoCleanableDatabaseName + + // creating those databases ahead of time + cosmosClient.createDatabase(databaseName1).block() + cosmosClient.createDatabase(databaseName2).block() + + val databases = spark.sql("SHOW DATABASES IN testCatalog").collect() + databases.size should be >= 2 + //validate databases has the above database name1 + databases + .filter( + row => row.getAs[String]("namespace").equals(databaseName1) + || row.getAs[String]("namespace").equals(databaseName2)) should have size 2 + } + + private def dropDatabase(spark: SparkSession, databaseName: String, cascade: Boolean) = { + if (cascade) { + spark.sql(s"DROP DATABASE testCatalog.$databaseName CASCADE;") + } else { + spark.sql(s"DROP DATABASE testCatalog.$databaseName;") + } + } +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala 
b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala new file mode 100644 index 000000000000..f10fddac7138 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala @@ -0,0 +1,973 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.cosmos.spark + +import com.azure.cosmos.CosmosException +import com.azure.cosmos.implementation.{TestConfigurations, Utils} +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import org.apache.commons.lang3.RandomStringUtils +import org.apache.spark.sql.execution.streaming.checkpointing.HDFSMetadataLog +import org.apache.spark.sql.{DataFrame, SparkSession} + +import java.util.UUID +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +abstract class CosmosCatalogITestBase(val skipHive: Boolean = false) extends IntegrationSpec with CosmosClient with BasicLoggingTrait { + //scalastyle:off multiple.string.literals + //scalastyle:off magic.number + + var spark : SparkSession = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + + var sparkBuilder = SparkSession.builder() + .appName("spark connector sample") + .master("local") + + if (!skipHive) { + sparkBuilder = sparkBuilder.enableHiveSupport() + } + + spark = sparkBuilder.getOrCreate() + + LocalJavaFileSystem.applyToSparkSession(spark) + + spark.conf.set(s"spark.sql.catalog.testCatalog", "com.azure.cosmos.spark.CosmosCatalog") + spark.conf.set(s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint) + spark.conf.set(s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey", cosmosMasterKey) + spark.conf.set( + "spark.sql.catalog.testCatalog.spark.cosmos.views.repositoryPath", + 
s"/viewRepository/${UUID.randomUUID().toString}") + spark.conf.set( + "spark.sql.catalog.testCatalog.spark.cosmos.read.partitioning.strategy", + "Restrictive") + } + + override def afterAll(): Unit = { + try spark.close() + finally super.afterAll() + } + + it can "create a database with shared throughput" in { + val databaseName = getAutoCleanableDatabaseName + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');") + + cosmosClient.getDatabase(databaseName).read().block() + val throughput = cosmosClient.getDatabase(databaseName).readThroughput().block() + + throughput.getProperties.getManualThroughput shouldEqual 1000 + } + + it can "create a table with customized properties and hierarchical partition keys, without partition kind and version" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', manualThroughput = '1100')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/tenantId", "/userId", "/sessionId")) + // scalastyle:off null + containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null + // scalastyle:on null + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + } + + it can "create a table with customized properties and hierarchical partition keys, with correct partition kind" in { + val databaseName = getAutoCleanableDatabaseName + val 
containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+    spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+    spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+      s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', partitionKeyVersion = 'V2', partitionKeyKind = 'MultiHash', manualThroughput = '1100')")
+
+    val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+    containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/tenantId", "/userId", "/sessionId"))
+    // scalastyle:off null
+    containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null
+    // scalastyle:on null
+
+    // validate throughput
+    val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+    throughput.getManualThroughput shouldEqual 1100
+  }
+
+  it can "create a table with customized properties and hierarchical partition keys, with wrong partition kind" in {
+    val databaseName = getAutoCleanableDatabaseName
+    val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+    spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+    try {
+      spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+        s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', partitionKeyVersion = 'V1', partitionKeyKind = 'Hash', manualThroughput = '1100')")
+      fail("Expected IllegalArgumentException not thrown")
+    }
+    catch
+    {
+      case expectedError: IllegalArgumentException =>
+        logInfo(s"Expected IllegalArgumentException: $expectedError")
+        succeed // expected error
+    }
+
+  }
+
+  it can "create a database with shared throughput and alter throughput afterwards" in {
+    val databaseName = getAutoCleanableDatabaseName
+
+    
spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');") + + cosmosClient.getDatabase(databaseName).read().block() + var throughput = cosmosClient.getDatabase(databaseName).readThroughput().block() + + throughput.getProperties.getManualThroughput shouldEqual 1000 + + spark.sql(s"ALTER DATABASE testCatalog.$databaseName SET DBPROPERTIES ('manualThroughput' = '4000');") + + cosmosClient.getDatabase(databaseName).read().block() + throughput = cosmosClient.getDatabase(databaseName).readThroughput().block() + + throughput.getProperties.getManualThroughput shouldEqual 4000 + } + + it can "create a table with defaults" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + val tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null" + tblProperties("CosmosPartitionCount") shouldEqual "1" + tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}" + tblProperties("DefaultTtlInSeconds") shouldEqual "null" + tblProperties("VectorEmbeddingPolicy") shouldEqual "null" + 
tblProperties("IndexingPolicy") shouldEqual + "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," + + "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}" + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("ProvisionedThroughput").startsWith("Manual|400|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 31 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length shouldEqual 20 + } + + it can "create a table and alter throughput afterwards" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + var throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + var tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("ProvisionedThroughput").startsWith("Manual|400|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 31 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length 
shouldEqual 20 + + spark.sql(s"ALTER TABLE testCatalog.$databaseName.$containerName SET TBLPROPERTIES ('manualThroughput' = '4000');") + + // validate throughput + throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 4000 + + tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("ProvisionedThroughput").startsWith("Manual|4000|") shouldEqual true + } + + it can "create a table with shared throughput and Hash V2" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');") + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + // TODO @fabianm Emulator doesn't seem to support analytical store - needs to be tested separately + // s"TBLPROPERTIES(partitionKeyVersion = 'V2', analyticalStoreTtlInSeconds = '3000000')") + s"TBLPROPERTIES(partitionKeyVersion = 'V2')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + try { + // validate that container uses shared database throughput as default + cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + + fail("Expected CosmosException not thrown") + } + catch { + case expectedError: CosmosException => + expectedError.getStatusCode shouldEqual 400 
+ logInfo(s"Expected CosmosException: $expectedError") + } + + val tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + // tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "3000000" + tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null" + tblProperties("CosmosPartitionCount") shouldEqual "1" + tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\",\"version\":2}" + tblProperties("DefaultTtlInSeconds") shouldEqual "null" + tblProperties("VectorEmbeddingPolicy") shouldEqual "null" + tblProperties("IndexingPolicy") shouldEqual + "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," + + "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}" + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + logInfo(s"ProvisionedThroughput: ${tblProperties("ProvisionedThroughput")}") + tblProperties("ProvisionedThroughput").startsWith("Shared.Manual|1000|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 39 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length shouldEqual 20 + } + + it can "create a table with defaults but shared autoscale throughput" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('autoScaleMaxThroughput' = '16000');") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + + // verify default partition key path is used + 
containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + try { + // validate that container uses shared database throughput as default + cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + + fail("Expected CosmosException not thrown") + } + catch { + case expectedError: CosmosException => + expectedError.getStatusCode shouldEqual 400 + logInfo(s"Expected CosmosException: $expectedError") + } + + val tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null" + tblProperties("CosmosPartitionCount") shouldEqual "2" + tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}" + tblProperties("DefaultTtlInSeconds") shouldEqual "null" + tblProperties("VectorEmbeddingPolicy") shouldEqual "null" + tblProperties("IndexingPolicy") shouldEqual + "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," + + "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}" + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + logInfo(s"ProvisionedThroughput: ${tblProperties("ProvisionedThroughput")}") + tblProperties("ProvisionedThroughput").startsWith("Shared.AutoScale|1600|16000|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 48 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length shouldEqual 20 + } + + it can "create a table with customized properties" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word 
STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + // scalastyle:off null + containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null + // scalastyle:on null + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + } + + it can "create a table with well known indexing policy 'AllProperties'" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = 'AllProperties')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + containerProperties + .getIndexingPolicy + .getIncludedPaths + .asScala + .map(p => p.getPath) + .toArray should equal(Array("/*")) + containerProperties + .getIndexingPolicy + .getExcludedPaths + .asScala + .map(p => p.getPath) + .toArray should equal(Array(raw"""/"_etag"/?""")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + } + + it can "create a table with well known indexing policy 'OnlySystemProperties'" in { + val 
databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = 'ONLYSystemproperties')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + containerProperties + .getIndexingPolicy + .getIncludedPaths + .asScala.map(p => p.getPath) + .toArray.length shouldEqual 0 + containerProperties + .getIndexingPolicy + .getExcludedPaths + .asScala + .map(p => p.getPath) + .toArray should equal(Array("/*", raw"""/"_etag"/?""")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + } + + it can "create a table with custom indexing policy" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + val indexPolicyJson = raw"""{"indexingMode":"consistent","automatic":true,"includedPaths":""" + + raw"""[{"path":"\/helloWorld\/?"},{"path":"\/mypk\/?"}],"excludedPaths":[{"path":"\/*"}]}""" + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = '$indexPolicyJson')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + 
containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + containerProperties + .getIndexingPolicy + .getIncludedPaths + .asScala + .map(p => p.getPath) + .toArray should equal(Array("/helloWorld/?", "/mypk/?")) + containerProperties + .getIndexingPolicy + .getExcludedPaths + .asScala + .map(p => p.getPath) + .toArray should equal(Array("/*", raw"""/"_etag"/?""")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + + val tblProperties = getTblProperties(spark, databaseName, containerName) + + tblProperties should have size 8 + + tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null" + tblProperties("CosmosPartitionCount") shouldEqual "1" + tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/mypk\"],\"kind\":\"Hash\"}" + tblProperties("DefaultTtlInSeconds") shouldEqual "null" + tblProperties("VectorEmbeddingPolicy") shouldEqual "null" + + // indexPolicyJson will be normalized by the backend - so not be the same as the input json + // for the purpose of this test I just want to make sure that the custom indexing options + // are included - correctness of json serialization of indexing policy is tested elsewhere + tblProperties("IndexingPolicy").contains("helloWorld") shouldEqual true + tblProperties("IndexingPolicy").contains("mypk") shouldEqual true + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("ProvisionedThroughput").startsWith("Manual|1100|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 32 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length shouldEqual 20 + } + + it can "create a table with TTL -1" in { + val databaseName = getAutoCleanableDatabaseName + val 
containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', defaultTtlInSeconds = '-1')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + containerProperties.getDefaultTimeToLiveInSeconds shouldEqual -1 + + val tblProperties = getTblProperties(spark, databaseName, containerName) + tblProperties("DefaultTtlInSeconds") shouldEqual "-1" + } + + it can "create a table with positive TTL" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', defaultTtlInSeconds = '5')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + containerProperties.getDefaultTimeToLiveInSeconds shouldEqual 5 + + val tblProperties = getTblProperties(spark, databaseName, containerName) + tblProperties("DefaultTtlInSeconds") shouldEqual "5" + } + + it can "create a table with vector embedding policy" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + val vectorEmbeddingPolicyJson = + 
raw"""{"vectorEmbeddings":[{"path":"/vector1","dataType":"float32","distanceFunction":"cosine","dimensions":500}]}""" + + val indexingPolicyJson = + raw"""{"indexingMode":"consistent","automatic":true,"includedPaths":[{"path":"\/mypk\/?"}],""" + + raw""""excludedPaths":[{"path":"\/*"}],"vectorIndexes":[{"path":"\/vector1","type":"flat"}]}""" + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " + + s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', " + + s"indexingPolicy = '$indexingPolicyJson', " + + s"vectorEmbeddingPolicy = '$vectorEmbeddingPolicyJson')") + + val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk")) + + // validate vector embedding policy + val vectorEmbeddingPolicy = containerProperties.getVectorEmbeddingPolicy + vectorEmbeddingPolicy should not be null + vectorEmbeddingPolicy.getVectorEmbeddings should have size 1 + val embedding = vectorEmbeddingPolicy.getVectorEmbeddings.get(0) + embedding.getPath shouldEqual "/vector1" + embedding.getDataType.toString shouldEqual "float32" + embedding.getDistanceFunction.toString shouldEqual "cosine" + embedding.getEmbeddingDimensions shouldEqual 500 + + // validate vector indexes are in indexing policy + val vectorIndexes = containerProperties.getIndexingPolicy.getVectorIndexes + vectorIndexes should have size 1 + vectorIndexes.get(0).getPath shouldEqual "/vector1" + vectorIndexes.get(0).getType shouldEqual "flat" + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 1100 + + val tblProperties = getTblProperties(spark, databaseName, containerName) + + 
tblProperties should have size 8 + + tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/mypk\"],\"kind\":\"Hash\"}" + tblProperties("DefaultTtlInSeconds") shouldEqual "null" + tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null" + + // validate vector embedding policy is in table properties (structured check) + val vepObjectMapper = Utils.getSimpleObjectMapper + val vepNode = vepObjectMapper.readTree(tblProperties("VectorEmbeddingPolicy")) + val vepEmbeddings = vepNode.get("vectorEmbeddings") + vepEmbeddings.size() shouldEqual 1 + vepEmbeddings.get(0).get("path").asText() shouldEqual "/vector1" + vepEmbeddings.get(0).get("dataType").asText() shouldEqual "float32" + vepEmbeddings.get(0).get("distanceFunction").asText() shouldEqual "cosine" + + // validate vector indexes are in indexing policy (structured check) + val ipNode = vepObjectMapper.readTree(tblProperties("IndexingPolicy")) + val vectorIndexesNode = ipNode.get("vectorIndexes") + vectorIndexesNode.size() shouldEqual 1 + vectorIndexesNode.get(0).get("path").asText() shouldEqual "/vector1" + vectorIndexesNode.get(0).get("type").asText() shouldEqual "flat" + + // would look like Manual|RUProvisioned|LastOfferModification + // - last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("ProvisionedThroughput").startsWith("Manual|1100|") shouldEqual true + tblProperties("ProvisionedThroughput").length shouldEqual 32 + + // last modified as iso datetime like 2021-12-07T10:33:44Z + tblProperties("LastModified").length shouldEqual 20 + } + + it can "select from a catalog table with default TBLPROPERTIES" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cleanupDatabaseLater(databaseName) + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;") + 
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName) + val containerProperties = container.read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + for (state <- Array(true, false)) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put("name", "Shrodigner's mouse") + objectNode.put("type", "mouse") + objectNode.put("age", 20) + objectNode.put("isAlive", state) + objectNode.put("id", UUID.randomUUID().toString) + container.createItem(objectNode).block() + } + + val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$containerName") + val rowsArrayUnfiltered= dfWithInference.collect() + rowsArrayUnfiltered should have size 2 + val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'mouse'").collect() + rowsArrayWithInference should have size 1 + + val rowWithInference = rowsArrayWithInference(0) + rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's mouse" + rowWithInference.getAs[String]("type") shouldEqual "mouse" + rowWithInference.getAs[Integer]("age") shouldEqual 20 + rowWithInference.getAs[Boolean]("isAlive") shouldEqual true + + val fieldNames = rowWithInference.schema.fields.map(field => field.name) + fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false + } 
+ + it can "select from a catalog Cosmos view" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + val viewName = containerName + "view" + RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + val container = cosmosClient.getDatabase(databaseName).getContainer(containerName) + val containerProperties = container.read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + for (state <- Array(true, false)) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put("name", "Shrodigner's mouse") + objectNode.put("type", "mouse") + objectNode.put("age", 20) + objectNode.put("isAlive", state) + objectNode.put("id", UUID.randomUUID().toString) + container.createItem(objectNode).block() + } + + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " + + s"TBLPROPERTIES(isCosmosView = 'True') " + + s"OPTIONS (" + + s"spark.cosmos.database = '$databaseName', " + + s"spark.cosmos.container = '$containerName', " + + "spark.cosmos.read.inferSchema.enabled = 'True', " + + "spark.cosmos.read.inferSchema.includeSystemProperties = 'True', " + + "spark.cosmos.read.partitioning.strategy = 'Restrictive');") + val tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;") + + tables.collect() should have size 2 + + tables + .where(s"tableName = '$viewName' and namespace = '$databaseName'") + .collect() should have size 1 + + 
tables + .where(s"tableName = '$containerName' and namespace = '$databaseName'") + .collect() should have size 1 + + val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$viewName") + val rowsArrayUnfiltered= dfWithInference.collect() + rowsArrayUnfiltered should have size 2 + + val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'mouse'").collect() + rowsArrayWithInference should have size 1 + + val rowWithInference = rowsArrayWithInference(0) + rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's mouse" + rowWithInference.getAs[String]("type") shouldEqual "mouse" + rowWithInference.getAs[Integer]("age") shouldEqual 20 + rowWithInference.getAs[Boolean]("isAlive") shouldEqual true + + val fieldNames = rowWithInference.schema.fields.map(field => field.name) + fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe true + fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe true + fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe true + fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe true + fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe true + } + + it can "manage Cosmos view metadata in the catalog" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + val viewNameRaw = containerName + + "view" + + RandomStringUtils.randomAlphabetic(6).toLowerCase + + System.currentTimeMillis() + val viewNameWithSchemaInference = containerName + + "view" + + RandomStringUtils.randomAlphabetic(6).toLowerCase + + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + val container = cosmosClient.getDatabase(databaseName).getContainer(containerName) + val 
containerProperties = container.read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + for (state <- Array(true, false)) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put("name", "Shrodigner's snake") + objectNode.put("type", "snake") + objectNode.put("age", 20) + objectNode.put("isAlive", state) + objectNode.put("id", UUID.randomUUID().toString) + container.createItem(objectNode).block() + } + + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$viewNameRaw using cosmos.oltp " + + s"TBLPROPERTIES(isCosmosView = 'True') " + + s"OPTIONS (" + + s"spark.cosmos.database = '$databaseName', " + + s"spark.cosmos.container = '$containerName', " + + s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " + + s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " + + s"spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " + + s"spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " + + "spark.cosmos.read.inferSchema.enabled = 'False', " + + "spark.cosmos.read.partitioning.strategy = 'Restrictive');") + + var tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;") + tables.collect() should have size 2 + + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$viewNameWithSchemaInference using cosmos.oltp " + + s"TBLPROPERTIES(isCosmosView = 'True') " + + s"OPTIONS (" + + s"spark.cosmos.database = '$databaseName', " + + s"spark.cosmos.container = '$containerName', " + + s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " + + 
s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " + + s"spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " + + s"spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " + + "spark.cosmos.read.inferSchema.enabled = 'True', " + + "spark.cosmos.read.inferSchema.includeSystemProperties = 'False', " + + "spark.cosmos.read.partitioning.strategy = 'Restrictive');") + + tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;") + tables.collect() should have size 3 + + val filePath = spark.conf.get("spark.sql.catalog.testCatalog.spark.cosmos.views.repositoryPath") + val hdfsMetadataLog = new HDFSMetadataLog[String](spark, filePath) + + hdfsMetadataLog.getLatest() match { + case None => throw new IllegalStateException("HDFS metadata file should have been written") + case Some((batchId, json)) => + + logInfo(s"BatchId: $batchId, Json: $json") + + // Validate the master key is not stored anywhere + json.contains(TestConfigurations.MASTER_KEY) shouldEqual false + json.contains(TestConfigurations.SECONDARY_MASTER_KEY) shouldEqual false + json.contains(TestConfigurations.HOST) shouldEqual false + + // validate that we can deserialize the persisted json + val deserializedViews = ViewDefinitionEnvelopeSerializer.fromJson(json) + deserializedViews.length >= 2 shouldBe true + deserializedViews + .exists(vd => vd.databaseName == databaseName && vd.viewName == viewNameRaw) shouldEqual true + deserializedViews + .exists(vd => vd.databaseName == databaseName && + vd.viewName == viewNameWithSchemaInference) shouldEqual true + } + + tables + .where(s"tableName = '$containerName' and namespace = '$databaseName'") + .collect() should have size 1 + tables + .where(s"tableName = '$viewNameRaw' and namespace = '$databaseName'") + .collect() should have size 1 + tables + .where(s"tableName = '$viewNameWithSchemaInference' and namespace = '$databaseName'") + .collect() should have size 1 + + val dfRaw = spark.sql(s"SELECT * 
FROM testCatalog.$databaseName.$viewNameRaw") + val rowsArrayUnfilteredRaw= dfRaw.collect() + rowsArrayUnfilteredRaw should have size 2 + + val fieldNamesRaw = dfRaw.schema.fields.map(field => field.name) + fieldNamesRaw.contains(CosmosTableSchemaInferrer.IdAttributeName) shouldBe true + fieldNamesRaw.contains(CosmosTableSchemaInferrer.RawJsonBodyAttributeName) shouldBe true + fieldNamesRaw.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe true + fieldNamesRaw.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false + fieldNamesRaw.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false + fieldNamesRaw.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false + fieldNamesRaw.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false + + val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$viewNameWithSchemaInference") + val rowsArrayUnfiltered= dfWithInference.collect() + rowsArrayUnfiltered should have size 2 + + val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'snake'").collect() + rowsArrayWithInference should have size 1 + + val rowWithInference = rowsArrayWithInference(0) + rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's snake" + rowWithInference.getAs[String]("type") shouldEqual "snake" + rowWithInference.getAs[Integer]("age") shouldEqual 20 + rowWithInference.getAs[Boolean]("isAlive") shouldEqual true + + val fieldNames = rowWithInference.schema.fields.map(field => field.name) + fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false + fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false + + 
spark.sql(s"DROP TABLE testCatalog.$databaseName.$viewNameRaw;") + tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;") + tables.collect() should have size 2 + + spark.sql(s"DROP TABLE testCatalog.$databaseName.$viewNameWithSchemaInference;") + tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;") + tables.collect() should have size 1 + } + + "creating a view without specifying isCosmosView table property" should "throw IllegalArgumentException" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + val viewName = containerName + + "view" + + RandomStringUtils.randomAlphabetic(6).toLowerCase + + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + val container = cosmosClient.getDatabase(databaseName).getContainer(containerName) + val containerProperties = container.read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + for (state <- Array(true, false)) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put("name", "Shrodigner's snake") + objectNode.put("type", "snake") + objectNode.put("age", 20) + objectNode.put("isAlive", state) + objectNode.put("id", UUID.randomUUID().toString) + container.createItem(objectNode).block() + } + + try { + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " + + s"TBLPROPERTIES(isCosmosViewWithTypo = 'True') " + + s"OPTIONS (" + + s"spark.cosmos.database = '$databaseName', " + + 
s"spark.cosmos.container = '$containerName', " + + "spark.cosmos.read.inferSchema.enabled = 'False', " + + "spark.cosmos.read.partitioning.strategy = 'Restrictive');") + + fail("Expected IllegalArgumentException not thrown") + } + catch { + case expectedError: IllegalArgumentException => + logInfo(s"Expected IllegalArgumentException: $expectedError") + succeed + } + } + + "creating a view with specifying isCosmosView==False table property" should "throw IllegalArgumentException" in { + val databaseName = getAutoCleanableDatabaseName + val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + val viewName = containerName + + "view" + + RandomStringUtils.randomAlphabetic(6).toLowerCase + + System.currentTimeMillis() + + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;") + + val container = cosmosClient.getDatabase(databaseName).getContainer(containerName) + val containerProperties = container.read().block().getProperties + + // verify default partition key path is used + containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id")) + + // validate throughput + val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties + throughput.getManualThroughput shouldEqual 400 + + for (state <- Array(true, false)) { + val objectNode = Utils.getSimpleObjectMapper.createObjectNode() + objectNode.put("name", "Shrodigner's snake") + objectNode.put("type", "snake") + objectNode.put("age", 20) + objectNode.put("isAlive", state) + objectNode.put("id", UUID.randomUUID().toString) + container.createItem(objectNode).block() + } + + try { + spark.sql( + s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " + + s"TBLPROPERTIES(isCosmosView = 'False') " + + s"OPTIONS (" + + s"spark.cosmos.database = '$databaseName', " + + 
s"spark.cosmos.container = '$containerName', " + + "spark.cosmos.read.inferSchema.enabled = 'False', " + + "spark.cosmos.read.partitioning.strategy = 'Restrictive');") + + fail("Expected IllegalArgumentException not thrown") + } + catch { + case expectedError: IllegalArgumentException => + logInfo(s"Expected IllegalArgumentException: $expectedError") + succeed + } + } + + it can "list all containers in a database" in { + val databaseName = getAutoCleanableDatabaseName + cosmosClient.createDatabase(databaseName).block() + + // create multiple containers under the same database + val containerName1 = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + val containerName2 = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis() + cosmosClient.getDatabase(databaseName).createContainer(containerName1, "/id").block() + cosmosClient.getDatabase(databaseName).createContainer(containerName2, "/id").block() + + val containers = spark.sql(s"SHOW TABLES FROM testCatalog.$databaseName").collect() + containers should have size 2 + containers + .filter( + row => row.getAs[String]("tableName").equals(containerName1) + || row.getAs[String]("tableName").equals(containerName2)) should have size 2 + } + + private def getTblProperties(spark: SparkSession, databaseName: String, containerName: String) = { + val descriptionDf = spark.sql(s"DESCRIBE TABLE EXTENDED testCatalog.$databaseName.$containerName;") + val tblPropertiesRowsArray = descriptionDf + .where("col_name = 'Table Properties'") + .collect() + + for (row <- tblPropertiesRowsArray) { + logInfo(row.mkString) + } + tblPropertiesRowsArray should have size 1 + + // Output will look something like this + // [key1='value1',key2='value2',...] 
+ val tblPropertiesText = tblPropertiesRowsArray(0).getAs[String]("data_type") + // parsing this into dictionary + + val keyValuePairs = tblPropertiesText.substring(1, tblPropertiesText.length - 2).split("',") + keyValuePairs + .map(kvp => { + val columns = kvp.split("='") + (columns(0), columns(1)) + }) + .toMap + } + + def createDatabase(spark: SparkSession, databaseName: String): DataFrame = { + spark.sql(s"CREATE DATABASE testCatalog.$databaseName;") + } + + //scalastyle:on magic.number + //scalastyle:on multiple.string.literals +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala new file mode 100644 index 000000000000..a5bdc9df94c1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.types.TimestampNTZType + +import java.sql.{Date, Timestamp} +import java.time.format.DateTimeFormatter +import java.time.{LocalDateTime, OffsetDateTime} + +// scalastyle:off underscore.import +import org.apache.spark.sql.types._ +// scalastyle:on underscore.import + +class CosmosRowConverterTest extends UnitSpec with BasicLoggingTrait { + //scalastyle:off null + //scalastyle:off multiple.string.literals + //scalastyle:off file.size.limit + + val objectMapper = new ObjectMapper() + private[this] val defaultRowConverter = + CosmosRowConverter.get( + new CosmosSerializationConfig( + SerializationInclusionModes.Always, + SerializationDateTimeConversionModes.Default + ) + ) + + + "date and time and TimestampNTZType in spark row" should "translate to ObjectNode" in { + val colName1 = "testCol1" + val colName2 = "testCol2" + val colName3 = "testCol3" + val colName4 = "testCol4" + val currentMillis = System.currentTimeMillis() + val colVal1 = new Date(currentMillis) + val timestampNTZType = "2021-07-01T08:43:28.037" + val colVal2 = LocalDateTime.parse(timestampNTZType, DateTimeFormatter.ISO_DATE_TIME) + val colVal3 = currentMillis.toInt + + val row = new GenericRowWithSchema( + Array(colVal1, colVal2, colVal3, colVal3), + StructType(Seq(StructField(colName1, DateType), + StructField(colName2, TimestampNTZType), + StructField(colName3, DateType), + StructField(colName4, TimestampType)))) + + val objectNode = defaultRowConverter.fromRowToObjectNode(row) + objectNode.get(colName1).asLong() shouldEqual currentMillis + objectNode.get(colName2).asText() shouldEqual "2021-07-01T08:43:28.037" + objectNode.get(colName3).asInt() shouldEqual colVal3 + objectNode.get(colName4).asInt() 
shouldEqual colVal3 + } + + "time and TimestampNTZType in ObjectNode" should "translate to Row" in { + val colName1 = "testCol1" + val colName2 = "testCol2" + val colName3 = "testCol3" + val colName4 = "testCol4" + val colVal1 = System.currentTimeMillis() + val colVal1AsTime = new Timestamp(colVal1) + val colVal2 = System.currentTimeMillis() + val colVal2AsTime = new Timestamp(colVal2) + val colVal3 = "2021-01-20T20:10:15+01:00" + val colVal3AsTime = Timestamp.valueOf(OffsetDateTime.parse(colVal3, DateTimeFormatter.ISO_OFFSET_DATE_TIME).toLocalDateTime) + val colVal4 = "2021-07-01T08:43:28.037" + val colVal4AsTime = LocalDateTime.parse(colVal4, DateTimeFormatter.ISO_DATE_TIME) + + val objectNode: ObjectNode = objectMapper.createObjectNode() + objectNode.put(colName1, colVal1) + objectNode.put(colName2, colVal2) + objectNode.put(colName3, colVal3) + objectNode.put(colName4, colVal4) + val schema = StructType(Seq( + StructField(colName1, TimestampType), + StructField(colName2, TimestampType), + StructField(colName3, TimestampType), + StructField(colName4, TimestampNTZType))) + val row = defaultRowConverter.fromObjectNodeToRow(schema, objectNode, SchemaConversionModes.Relaxed) + val asTime = row.get(0).asInstanceOf[Timestamp] + asTime.compareTo(colVal1AsTime) shouldEqual 0 + val asTime2 = row.get(1).asInstanceOf[Timestamp] + asTime2.compareTo(colVal2AsTime) shouldEqual 0 + val asTime3 = row.get(2).asInstanceOf[Timestamp] + asTime3.compareTo(colVal3AsTime) shouldEqual 0 + val asTime4 = row.get(3).asInstanceOf[LocalDateTime] + asTime4.compareTo(colVal4AsTime) shouldEqual 0 + } + + //scalastyle:on null + //scalastyle:on multiple.string.literals + //scalastyle:on file.size.limit +} diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala new file mode 100644 index 000000000000..b6433c6d7b2f --- /dev/null +++ 
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.cosmos.spark

import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, SparkBridgeImplementationInternal, TestConfigurations, Utils}
import com.azure.cosmos.models.PartitionKey
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.sources.{Filter, In}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

import java.util.UUID
import scala.collection.mutable.ListBuffer

// Integration tests for ItemsScan runtime (readMany) filtering: filter
// attributes are only exposed - and input partitions only pruned - when both
// runtime filtering and readMany filtering are enabled.
class ItemsScanITest
  extends IntegrationSpec
    with Spark
    with AutoCleanableCosmosContainersWithPkAsPartitionKey {

  //scalastyle:off multiple.string.literals
  //scalastyle:off magic.number

  private val idProperty = "id"
  private val pkProperty = "pk"
  private val itemIdentityProperty = "_itemIdentity"

  // Represents "no filters pushed down at query-analysis time"; readMany
  // filters are applied at runtime via ItemsScan.filter instead.
  private val analyzedAggregatedFilters =
    AnalyzedAggregatedFilters(
      QueryFilterAnalyzer.rootParameterizedQuery,
      false,
      Array.empty[Filter],
      Array.empty[Filter],
      Option.empty[List[ReadManyFilter]])

  // Builds the spark read configuration shared by the tests below. The
  // restrictive partitioning strategy (used only by the pruning test)
  // guarantees exactly one planned input partition per feed range.
  private def buildConfigMap(
      containerName: String,
      runTimeFilteringEnabled: Boolean,
      readManyFilteringEnabled: Boolean,
      restrictivePartitioning: Boolean = false): Map[String, String] = {
    val baseConfig = Map(
      "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST,
      "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY,
      "spark.cosmos.database" -> cosmosDatabase,
      "spark.cosmos.container" -> containerName,
      "spark.cosmos.read.inferSchema.enabled" -> "true",
      "spark.cosmos.applicationName" -> "ItemsScan",
      "spark.cosmos.read.runtimeFiltering.enabled" -> runTimeFilteringEnabled.toString,
      "spark.cosmos.read.readManyFiltering.enabled" -> readManyFilteringEnabled.toString)

    if (restrictivePartitioning) {
      baseConfig + ("spark.cosmos.read.partitioning.strategy" -> "Restrictive")
    } else {
      baseConfig
    }
  }

  it should "only return readMany filtering property when runTimeFiltering is enabled and readMany filtering is enabled" in {
    val clientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots()

    val testCases = Array(
      // containerName, partitionKey property, expected readMany filtering property
      (cosmosContainer, idProperty, idProperty),
      (cosmosContainersWithPkAsPartitionKey, pkProperty, itemIdentityProperty)
    )

    for ((containerName, partitionKeyProperty, expectedFilteringProperty) <- testCases) {
      val partitionKeyDefinition =
        cosmosClient
          .getDatabase(cosmosDatabase)
          .getContainer(containerName)
          .read()
          .block()
          .getProperties
          .getPartitionKeyDefinition

      for (runTimeFilteringEnabled <- Array(true, false);
           readManyFilteringEnabled <- Array(true, false)) {
        logInfo(s"TestCase: containerName $containerName, partitionKeyProperty $partitionKeyProperty, " +
          s"runtimeFilteringEnabled $runTimeFilteringEnabled, readManyFilteringEnabled $readManyFilteringEnabled")

        val config = buildConfigMap(containerName, runTimeFilteringEnabled, readManyFilteringEnabled)
        val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
        val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config)
        val schema = getDefaultSchema(partitionKeyProperty)

        val itemScan = new ItemsScan(
          spark,
          schema,
          config,
          readConfig,
          analyzedAggregatedFilters,
          clientMetadataCachesSnapshots,
          diagnosticsConfig,
          "",
          partitionKeyDefinition)
        val filterAttributes = itemScan.filterAttributes()

        if (runTimeFilteringEnabled && readManyFilteringEnabled) {
          // Exactly the readMany filtering column is advertised for runtime filtering.
          filterAttributes.size shouldBe 1
          filterAttributes should contain theSameElementsAs Array(Expressions.column(expectedFilteringProperty))
        } else {
          filterAttributes shouldBe empty
        }
      }
    }
  }

  it should "only prune partitions when runTimeFiltering is enabled and readMany filtering is enabled" in {
    val clientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots()

    val testCases = Array(
      // containerName, partitionKeyProperty, expected readManyFiltering property
      (cosmosContainer, idProperty, idProperty),
      (cosmosContainersWithPkAsPartitionKey, pkProperty, itemIdentityProperty)
    )

    for ((containerName, partitionKeyProperty, readManyFilteringProperty) <- testCases) {
      val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(containerName)
      val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition

      // Pruning can only be observed on a container with more than one feed range.
      val feedRanges = container.getFeedRanges.block()
      feedRanges.size() should be > 1

      // First inject a few items.
      val matchingItemList = ListBuffer[ObjectNode]()
      for (_ <- 1 to 20) {
        val objectNode = getNewItem(partitionKeyProperty)
        container.createItem(objectNode).block()
        matchingItemList += objectNode
        logInfo(s"ID of test doc: ${objectNode.get(idProperty).asText()}")
      }

      // Choose one of the items created above and filter by it.
      val runtimeFilters =
        getReadManyFilters(Array(matchingItemList.head), partitionKeyProperty, readManyFilteringProperty)

      for (runTimeFilteringEnabled <- Array(true, false);
           readManyFilteringEnabled <- Array(true, false)) {
        logInfo(s"TestCase: containerName $containerName, partitionKeyProperty $partitionKeyProperty, " +
          s"runtimeFilteringEnabled $runTimeFilteringEnabled, readManyFilteringEnabled $readManyFilteringEnabled")

        val config =
          buildConfigMap(containerName, runTimeFilteringEnabled, readManyFilteringEnabled, restrictivePartitioning = true)
        val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
        val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config)

        val schema = getDefaultSchema(partitionKeyProperty)
        val itemScan = new ItemsScan(
          spark,
          schema,
          config,
          readConfig,
          analyzedAggregatedFilters,
          clientMetadataCachesSnapshots,
          diagnosticsConfig,
          "",
          partitionKeyDefinition)

        val plannedInputPartitions = itemScan.planInputPartitions()
        // Restrictive strategy: one planned partition per feed range.
        plannedInputPartitions.length shouldBe feedRanges.size()

        itemScan.filter(runtimeFilters)
        val plannedInputPartitionAfterFiltering = itemScan.planInputPartitions()

        if (runTimeFilteringEnabled && readManyFilteringEnabled) {
          // Only the partition owning the filtered item's partition key survives.
          plannedInputPartitionAfterFiltering.length shouldBe 1
          val filterItemFeedRange =
            SparkBridgeImplementationInternal.partitionKeyToNormalizedRange(
              new PartitionKey(getPartitionKeyValue(matchingItemList.head, s"/$partitionKeyProperty")),
              partitionKeyDefinition)

          val rangesOverlap =
            SparkBridgeImplementationInternal.doRangesOverlap(
              filterItemFeedRange,
              plannedInputPartitionAfterFiltering(0).asInstanceOf[CosmosInputPartition].feedRange)

          rangesOverlap shouldBe true
        } else {
          // No partition will be pruned.
          plannedInputPartitionAfterFiltering.length shouldBe plannedInputPartitions.length
          plannedInputPartitionAfterFiltering should contain theSameElementsAs plannedInputPartitions
        }
      }
    }
  }

  // Serializes the shared cosmos client's metadata caches and broadcasts the
  // snapshot, as the connector would do for executors.
  private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = {
    val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot()
    cosmosClientMetadataCachesSnapshot.serialize(cosmosClient)

    spark.sparkContext.broadcast(
      CosmosClientMetadataCachesSnapshots(
        cosmosClientMetadataCachesSnapshot,
        Option.empty[CosmosClientMetadataCachesSnapshot]))
  }

  // Builds the runtime In-filter over either the id column (id == partition key)
  // or the synthetic _itemIdentity column.
  private def getReadManyFilters(
      filteringItems: Array[ObjectNode],
      partitionKeyProperty: String,
      readManyFilteringProperty: String): Array[Filter] = {
    val readManyFilterValues =
      filteringItems
        .map(filteringItem => getReadManyFilteringValue(filteringItem, partitionKeyProperty, readManyFilteringProperty))

    if (partitionKeyProperty.equalsIgnoreCase(idProperty)) {
      Array[Filter](In(idProperty, readManyFilterValues.map(_.asInstanceOf[Any])))
    } else {
      Array[Filter](In(readManyFilteringProperty, readManyFilterValues.map(_.asInstanceOf[Any])))
    }
  }

  // The filter value is either the bare id, or the encoded id/partition-key
  // pair when filtering on _itemIdentity.
  private def getReadManyFilteringValue(
      objectNode: ObjectNode,
      partitionKeyProperty: String,
      readManyFilteringProperty: String): String = {

    if (readManyFilteringProperty.equals(itemIdentityProperty)) {
      CosmosItemIdentityHelper
        .getCosmosItemIdentityValueString(
          objectNode.get(idProperty).asText(),
          List(objectNode.get(partitionKeyProperty).asText()))
    } else {
      objectNode.get(idProperty).asText()
    }
  }

  // Creates a new document with a random id and, when the partition key is not
  // the id itself, a random partition-key value.
  private def getNewItem(partitionKeyProperty: String): ObjectNode = {
    val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
    val id = UUID.randomUUID().toString
    objectNode.put(idProperty, id)

    if (!partitionKeyProperty.equalsIgnoreCase(idProperty)) {
      val pk = UUID.randomUUID().toString
      objectNode.put(partitionKeyProperty, pk)
    }

    objectNode
  }

  // Schema used for the scan; includes pk and _itemIdentity only when the
  // partition key differs from the id.
  private def getDefaultSchema(partitionKeyProperty: String): StructType = {
    if (!partitionKeyProperty.equalsIgnoreCase(idProperty)) {
      StructType(Seq(
        StructField(idProperty, StringType),
        StructField(pkProperty, StringType),
        StructField(itemIdentityProperty, StringType)
      ))
    } else {
      StructType(Seq(
        StructField(idProperty, StringType)
      ))
    }
  }

  //scalastyle:on multiple.string.literals
  //scalastyle:on magic.number
}
package com.azure.cosmos.spark

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Verifies that RowSerializerPool caps how many serializers it retains for a
// given schema.
class RowSerializerPollTest extends RowSerializerPollSpec {
  //scalastyle:off multiple.string.literals

  "RowSerializer " should "be returned to the pool only a limited number of times" in {
    // Skip on platforms where direct byte buffers cannot be accessed.
    val canRun = Platform.canRunTestAccessingDirectByteBuffer
    assume(canRun._1, canRun._2)

    val schema = StructType(Seq(StructField("column01", IntegerType), StructField("column02", StringType)))
    def freshSerializer() = ExpressionEncoder.apply(schema).createSerializer()

    // The pool accepts up to 256 serializers for one schema ...
    (1 to 256).foreach { _ =>
      RowSerializerPool.returnSerializerToPool(schema, freshSerializer()) shouldBe true
    }

    logInfo("First 256 attempt to pool succeeded")

    // ... and rejects the 257th.
    RowSerializerPool.returnSerializerToPool(schema, freshSerializer()) shouldBe false
  }
  //scalastyle:on multiple.string.literals
}
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.TestConfigurations
import com.fasterxml.jackson.databind.node.ObjectNode

import java.util.UUID

// End-to-end query test verifying the Cosmos-specific portion of the Spark
// query plan surfaced by `explain` when schema inference forces nullable
// properties.
class SparkE2EQueryITest
  extends SparkE2EQueryITestBase {

  // Captures the console output produced by the given `explain` call and
  // normalizes Spark's non-deterministic expression ids (#12, #34, ...) to #x
  // so the plan text can be asserted on across runs.
  private def capturedQueryPlan(explain: => Unit): String = {
    val output = new java.io.ByteArrayOutputStream()
    Console.withOut(output) {
      explain
    }
    output.toString.replaceAll("#\\d+", "#x")
  }

  "spark query" can "return proper Cosmos specific query plan on explain with nullable properties" in {
    val cosmosEndpoint = TestConfigurations.HOST
    val cosmosMasterKey = TestConfigurations.MASTER_KEY

    val id = UUID.randomUUID().toString

    // Single document with a nested object so that a filter on a nested,
    // force-nullable property gets pushed down into the Cosmos query.
    val rawItem =
      s"""
         | {
         | "id" : "$id",
         | "nestedObject" : {
         | "prop1" : 5,
         | "prop2" : "6"
         | }
         | }
         |""".stripMargin

    val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])

    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
    container.createItem(objectNode).block()

    val cfg = Map("spark.cosmos.accountEndpoint" -> cosmosEndpoint,
      "spark.cosmos.accountKey" -> cosmosMasterKey,
      "spark.cosmos.database" -> cosmosDatabase,
      "spark.cosmos.container" -> cosmosContainer,
      "spark.cosmos.read.inferSchema.forceNullableProperties" -> "true",
      "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
    )

    val df = spark.read.format("cosmos.oltp").options(cfg).load()
    val rowsArray = df.where("nestedObject.prop2 = '6'").collect()
    rowsArray should have size 1

    // Without a pushed-down predicate the Cosmos query is a plain SELECT *.
    val queryPlan = capturedQueryPlan(df.explain())
    logInfo(s"Query Plan: $queryPlan")
    queryPlan.contains("Cosmos Query: SELECT * FROM r") shouldEqual true

    // With the nested-property predicate, the plan must contain the
    // null-guarded, parameterized WHERE clause.
    val filteredQueryPlan = capturedQueryPlan(df.where("nestedObject.prop2 = '6'").explain())
    logInfo(s"Query Plan: $filteredQueryPlan")
    val expected = s"Cosmos Query: SELECT * FROM r WHERE (NOT(IS_NULL(r['nestedObject']['prop2'])) AND IS_DEFINED(r['nestedObject']['prop2'])) " +
      s"AND r['nestedObject']['prop2']=" +
      s"@param0${System.getProperty("line.separator")} > param: @param0 = 6"
    filteredQueryPlan.contains(expected) shouldEqual true

    val item = rowsArray(0)
    item.getAs[String]("id") shouldEqual id
  }
}