diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md
new file mode 100644
index 000000000000..3972ae6aeb98
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CHANGELOG.md
@@ -0,0 +1,56 @@
+## Release History
+
+### 4.47.0-beta.1 (Unreleased)
+
+#### Features Added
+
+#### Breaking Changes
+
+#### Bugs Fixed
+
+#### Other Changes
+
+### 4.46.0 (2026-03-27)
+
+#### Bugs Fixed
+* Fixed an issue where creating containers with hierarchical partition keys (multi-hash) through the Spark catalog on the AAD path would fail. - See [PR 48548](https://github.com/Azure/azure-sdk-for-java/pull/48548)
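+
+  As a hedged illustration of what this fix enables (catalog name, database, container, and paths below are placeholders; the comma-separated `partitionKeyPath` form for multi-hash keys is an assumption based on the connector's catalog conventions):
+
+  ```scala
+  // Create a container with a hierarchical (multi-hash) partition key through the Spark catalog.
+  // Assumes a Cosmos catalog is registered, e.g.:
+  //   spark.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
+  spark.sql(
+    """CREATE TABLE cosmosCatalog.myDatabase.myContainer
+      |USING cosmos.oltp
+      |TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId', manualThroughput = '400')
+      |""".stripMargin)
+  ```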
+
+### 4.45.0 (2026-03-13)
+
+#### Features Added
+* Added `vectorEmbeddingPolicy` support in Spark catalog `TBLPROPERTIES` for creating vector-search-enabled containers. - See [PR 48349](https://github.com/Azure/azure-sdk-for-java/pull/48349)
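+
+  A minimal sketch of how such a container might be created (the exact value format for `vectorEmbeddingPolicy` is defined by the PR; the JSON below is an assumption modeled on the Cosmos DB vector embedding policy, and all names are placeholders):
+
+  ```scala
+  // Create a vector-search-enabled container through the Spark catalog.
+  spark.sql(
+    """CREATE TABLE cosmosCatalog.myDatabase.myVectorContainer
+      |USING cosmos.oltp
+      |TBLPROPERTIES(
+      |  partitionKeyPath = '/id',
+      |  vectorEmbeddingPolicy = '{"vectorEmbeddings":[{"path":"/embedding","dataType":"float32","dimensions":1536,"distanceFunction":"cosine"}]}'
+      |)""".stripMargin)
+  ```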
+
+### 4.44.2 (2026-03-05)
+
+#### Other Changes
+* Changed azure-resourcemanager-cosmos usage to a pinned version which is deployed across all public and non-public clouds - [PR 48268](https://github.com/Azure/azure-sdk-for-java/pull/48268)
+
+### 4.44.1 (2026-03-03)
+
+#### Other Changes
+* Reduced noisy warning logs in Gateway mode - [PR 48189](https://github.com/Azure/azure-sdk-for-java/pull/48189)
+
+### 4.44.0 (2026-02-27)
+
+#### Features Added
+* Added config entry `spark.cosmos.account.azureEnvironment.management.scope` to allow specifying the Entra ID scope/audience to be used when retrieving tokens to authenticate against the ARM/management endpoint of non-public clouds. - See [PR 48137](https://github.com/Azure/azure-sdk-for-java/pull/48137)
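+
+  A minimal sketch of using this setting (the endpoint and scope values are placeholders; the actual management scope depends on the target cloud):
+
+  ```scala
+  // Configuration for authenticating against a non-public cloud, overriding the
+  // Entra ID scope used for tokens against the ARM/management endpoint.
+  val cosmosConfig = Map(
+    "spark.cosmos.accountEndpoint" -> "<account-endpoint>",
+    "spark.cosmos.auth.type" -> "ServicePrincipal",
+    "spark.cosmos.account.azureEnvironment.management.scope" -> "https://management.example.cloud/.default"
+  )
+  ```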
+
+### 4.43.1 (2026-02-25)
+
+#### Bugs Fixed
+* Fixed an issue where `TransientIOErrorsRetryingIterator` would trigger extra query during retries and on close. - See [PR 47996](https://github.com/Azure/azure-sdk-for-java/pull/47996)
+
+#### Other Changes
+* Added status code history in `BulkWriterNoProgressException` error message. - See [PR 48022](https://github.com/Azure/azure-sdk-for-java/pull/48022)
+* Reduced the log noise level for frequent transient errors - for example throttling - in Gateway mode - [PR 48112](https://github.com/Azure/azure-sdk-for-java/pull/48112)
+
+### 4.43.0 (2026-02-10)
+
+#### Features Added
+* Initial release of Spark 4.0 connector with Scala 2.13 support
+* Added transactional batch support. - See [PR 47478](https://github.com/Azure/azure-sdk-for-java/pull/47478), [PR 47697](https://github.com/Azure/azure-sdk-for-java/pull/47697), and [PR 47803](https://github.com/Azure/azure-sdk-for-java/pull/47803)
+* Added support for throughput buckets. - See [PR 47856](https://github.com/Azure/azure-sdk-for-java/pull/47856)
+
+#### Other Changes
+
+### NOTE: See CHANGELOG.md in 3.3, 3.4, and 3.5 projects for changes in prior Spark versions
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md
new file mode 100644
index 000000000000..2435e3acead4
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/CONTRIBUTING.md
@@ -0,0 +1,84 @@
+# Contributing
+This guide describes how to build the project and contribute code.
+
+## Prerequisites
+- JDK 17 or above (Spark 4.0 requires Java 17+)
+- [Maven](https://maven.apache.org/) 3.0 or above
+
+## Build from source
+To build the project, run the following Maven commands:
+
+```bash
+git clone https://github.com/Azure/azure-sdk-for-java.git
+cd azure-sdk-for-java/sdk/cosmos/azure-cosmos-spark_4-1_2-13
+mvn clean install
+```
+
+## Test
+There are integration tests that run against Azure Cosmos DB and against the
+[Azure Cosmos DB Emulator](https://docs.microsoft.com/azure/cosmos-db/local-emulator); follow the
+link to set up the emulator before running the emulator tests.
+
+- Run unit tests
+```bash
+mvn clean install -Dgpg.skip
+```
+
+- Run integration tests
+ - on Azure
+    > **NOTE**: Integration tests against Azure require an Azure Cosmos DB account and will
+    automatically create a Cosmos database in your Azure subscription, so running them incurs an
+    **Azure usage fee.**
+
+    Integration tests require an Azure subscription. If you don't already have an Azure
+ subscription, you can activate your
+ [MSDN subscriber benefits](https://azure.microsoft.com/pricing/member-offers/msdn-benefits-details/)
+ or sign up for a [free Azure account](https://azure.microsoft.com/free/).
+
+ 1. Create an Azure Cosmos DB on Azure.
+ - Go to [Azure portal](https://portal.azure.com/) and click +New.
+ - Click Databases, and then click Azure Cosmos DB to create your database.
+ - Navigate to the database you have created, and click Access keys and copy your
+ URI and access keys for your database.
+
+    2. Set the environment variables ACCOUNT_HOST, ACCOUNT_KEY, and SECONDARY_ACCOUNT_KEY to your
+       Cosmos account URI, primary key, and secondary key, respectively. Also set a second group
+       of environment variables, NEW_ACCOUNT_HOST, NEW_ACCOUNT_KEY, and NEW_SECONDARY_ACCOUNT_KEY;
+       the two groups may reference the same account.
+ 3. Run maven command with `integration-test-azure` profile.
+
+ ```bash
+ set ACCOUNT_HOST=your-cosmos-account-uri
+ set ACCOUNT_KEY=your-cosmos-account-primary-key
+ set SECONDARY_ACCOUNT_KEY=your-cosmos-account-secondary-key
+
+ set NEW_ACCOUNT_HOST=your-cosmos-account-uri
+ set NEW_ACCOUNT_KEY=your-cosmos-account-primary-key
+ set NEW_SECONDARY_ACCOUNT_KEY=your-cosmos-account-secondary-key
+ mvnw -P integration-test-azure clean install
+ ```
+
+ - on Emulator
+
+ Setup Azure Cosmos DB Emulator by following
+ [this instruction](https://docs.microsoft.com/azure/cosmos-db/local-emulator), and set
+ associated environment variables. Then run test with:
+ ```bash
+ mvnw -P integration-test-emulator install
+ ```
+
+
+- Skip tests execution
+```bash
+mvn clean install -Dgpg.skip -DskipTests
+```
+
+## Version management
+Development versions follow the naming convention `0.1.2-beta.1`; release versions follow `0.1.2`.
+
+## Contribute to code
+Contributions are welcome. Please follow
+[these instructions](https://github.com/Azure/azure-sdk-for-java/blob/main/CONTRIBUTING.md) to contribute code.
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md
new file mode 100644
index 000000000000..81869fdba085
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/README.md
@@ -0,0 +1,233 @@
+# Azure Cosmos DB OLTP Spark 4 connector
+
+## Azure Cosmos DB OLTP Spark 4 connector for Spark 4.0
+**Azure Cosmos DB OLTP Spark connector** provides Apache Spark support for Azure Cosmos DB using
+the [SQL API][sql_api_query].
+[Azure Cosmos DB][cosmos_introduction] is a globally-distributed database service which allows
+developers to work with data using a variety of standard APIs, such as SQL, MongoDB, Cassandra, Graph, and Table.
+
+If you have any feedback or ideas on how to improve your experience please let us know here:
+https://github.com/Azure/azure-sdk-for-java/issues/new
+
+### Documentation
+
+- [Getting started](https://aka.ms/azure-cosmos-spark-3-quickstart)
+- [Catalog API](https://aka.ms/azure-cosmos-spark-3-catalog-api)
+- [Configuration Parameter Reference](https://aka.ms/azure-cosmos-spark-3-config)
+
+### Version Compatibility
+
+#### azure-cosmos-spark_4-0_2-13
+| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes |
+|-----------|--------------------------|----------------------|---------------------------|-------------------------------|---------------------------|
+| 4.46.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.45.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.44.2 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.44.1 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.44.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.43.1 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+| 4.43.0 | 4.0.0 | [17, 21] | 2.13 | 17.\* | TBD |
+
+Note: Spark 4.0 requires Scala 2.13 and Java 17 or higher. When using the Scala API, applications must
+use the same Scala 2.13 version that Spark 4.0 was compiled for.
+
+#### azure-cosmos-spark_3-3_2-12
+| Connector | Supported Spark Versions | Supported JVM Versions | Supported Scala Versions | Supported Databricks Runtimes |
+|-----------|--------------------------|------------------------|--------------------------|-------------------------------|
+| 4.46.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.45.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.44.2 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.44.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.44.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.43.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.43.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.42.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.41.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.40.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.39.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.38.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.37.2 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.37.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.37.0 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.36.1 | 3.3.0 - 3.3.2 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.36.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.35.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.34.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.33.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.33.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.32.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.32.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.31.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.30.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.29.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.28.4 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.28.3 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.28.2 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.28.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.28.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.27.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.27.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.26.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.26.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.25.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.25.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.24.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.24.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.23.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.22.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.21.1 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.21.0 | 3.3.0 | [8, 11] | 2.12 | 11.\*, 12.\* |
+| 4.20.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.19.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.18.2 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.18.1 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.18.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.17.2 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.17.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.16.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+| 4.15.0 | 3.3.0 | [8, 11] | 2.12 | 11.\* |
+
+#### azure-cosmos-spark_3-4_2-12
+| Connector | Supported Spark Versions | Supported JVM Versions | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes |
+|-----------|--------------------------|------------------------|--------------------------|-------------------------------|---------------------------|
+| 4.46.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.45.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.44.2 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.44.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.44.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.43.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.43.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.42.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.41.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.40.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.39.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.38.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.37.2 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.37.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.37.0 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.36.1 | 3.4.0 - 3.4.1 | [8, 11] | 2.12 | 13.\* | |
+| 4.36.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.35.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.34.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.33.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.33.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.32.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.32.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.31.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.30.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.29.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.28.4 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.28.3 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.28.2 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.28.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.28.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.27.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.27.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.26.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.26.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.25.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.25.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.24.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.24.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.23.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.22.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.21.1 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+| 4.21.0 | 3.4.0 | [8, 11] | 2.12 | 13.* | |
+
+#### azure-cosmos-spark_3-5_2-12
+| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes |
+|-----------|--------------------------|-----------------------|---------------------------|-------------------------------|---------------------------|
+| 4.46.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.45.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.44.2 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.44.1 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.44.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.43.1 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.43.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.42.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.41.0 | 3.5.0 | [8, 11, 17] | 2.12 | 14.\*, 15.\*, 16.4 LTS | 1.3.\* |
+| 4.40.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.39.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.38.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.37.2 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.37.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.37.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.36.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.36.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.35.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.34.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.33.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.33.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.32.1 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.32.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.31.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.30.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+| 4.29.0 | 3.5.0 | [8, 11] | 2.12 | 14.\*, 15.\* | |
+
+Note: Support for Java 8 versions prior to 8u371 is deprecated as of Spark 3.5.0. When using the Scala API, applications
+must use the same version of Scala that Spark was compiled for.
+
+#### azure-cosmos-spark_3-5_2-13
+| Connector | Supported Spark Versions | Minimum Java Version | Supported Scala Versions | Supported Databricks Runtimes | Supported Fabric Runtimes |
+|-----------|--------------------------|-----------------------|---------------------------|-------------------------------|---------------------------|
+| 4.46.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.45.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.44.2 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.44.1 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.44.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.43.1 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+| 4.43.0 | 3.5.0 | [17] | 2.13 | 16.4 LTS | TBD |
+
+### Download
+
+You can use the Maven coordinate of the jar to install the Spark connector into your Databricks Runtime from Maven:
+`com.azure.cosmos.spark:azure-cosmos-spark_4-0_2-13:4.46.0`
+
+You can also add the Cosmos DB Spark connector as a dependency in your SBT project:
+```scala
+libraryDependencies += "com.azure.cosmos.spark" % "azure-cosmos-spark_4-0_2-13" % "4.46.0"
+```
+
+Cosmos DB Spark Connector is available on [Maven Central Repo](https://central.sonatype.com/search?namespace=com.azure.cosmos.spark).
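+
+Once the connector is on the classpath, a minimal read from a container can be sketched as follows (the endpoint, key, and database/container names below are placeholders):
+
+```scala
+// Read a Cosmos DB container into a DataFrame using the connector's data source.
+val df = spark.read
+  .format("cosmos.oltp")
+  .option("spark.cosmos.accountEndpoint", "<account-endpoint>")
+  .option("spark.cosmos.accountKey", "<account-key>")
+  .option("spark.cosmos.database", "<database>")
+  .option("spark.cosmos.container", "<container>")
+  .load()
+df.show()
+```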
+
+#### General
+
+If you encounter any bug, please file an issue [here](https://github.com/Azure/azure-sdk-for-java/issues/new).
+
+To suggest a new feature or changes that could be made, file an issue the same way you would for a bug.
+
+### License
+This project is under the MIT license; it uses and repackages other third-party libraries as an uber jar.
+See [NOTICE.txt](https://github.com/Azure/azure-sdk-for-java/blob/main/NOTICE.txt).
+
+### Contributing
+
+This project welcomes contributions and suggestions. Most contributions require you to agree to a
+[Contributor License Agreement (CLA)][cla] declaring that you have the right to, and actually do, grant us the rights
+to use your contribution.
+
+When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate
+the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to
+do this once across all repos using our CLA.
+
+This project has adopted the [Microsoft Open Source Code of Conduct][coc]. For more information see the [Code of Conduct FAQ][coc_faq]
+or contact [opencode@microsoft.com][coc_contact] with any additional questions or comments.
+
+
+[source_code]: src
+[cosmos_introduction]: https://learn.microsoft.com/azure/cosmos-db/
+[cosmos_docs]: https://learn.microsoft.com/azure/cosmos-db/introduction
+[jdk]: https://learn.microsoft.com/java/azure/jdk/?view=azure-java-stable
+[maven]: https://maven.apache.org/
+[cla]: https://cla.microsoft.com
+[coc]: https://opensource.microsoft.com/codeofconduct/
+[coc_faq]: https://opensource.microsoft.com/codeofconduct/faq/
+[coc_contact]: mailto:opencode@microsoft.com
+[azure_subscription]: https://azure.microsoft.com/free/
+[samples]: https://github.com/Azure/azure-sdk-for-java/tree/main/sdk/spring/azure-spring-data-cosmos/src/samples/java/com/azure/spring/data/cosmos
+[sql_api_query]: https://learn.microsoft.com/azure/cosmos-db/sql-api-sql-query
+[local_emulator]: https://learn.microsoft.com/azure/cosmos-db/local-emulator
+[local_emulator_export_ssl_certificates]: https://learn.microsoft.com/azure/cosmos-db/local-emulator-export-ssl-certificates
+[azure_cosmos_db_partition]: https://learn.microsoft.com/azure/cosmos-db/partition-data
+[sql_queries_in_cosmos]: https://learn.microsoft.com/azure/cosmos-db/tutorial-query-sql-api
+[sql_queries_getting_started]: https://learn.microsoft.com/azure/cosmos-db/sql-query-getting-started
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml
new file mode 100644
index 000000000000..1499d93ee4d3
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/pom.xml
@@ -0,0 +1,262 @@
+
+
+ 4.0.0
+
+ com.azure.cosmos.spark
+ azure-cosmos-spark_3
+ 0.0.1-beta.1
+ ../azure-cosmos-spark_3
+
+ com.azure.cosmos.spark
+ azure-cosmos-spark_4-1_2-13
+ 4.47.0-beta.1
+ jar
+ https://github.com/Azure/azure-sdk-for-java/tree/main/sdk/cosmos/azure-cosmos-spark_4-1_2-13
+ OLTP Spark 4.1 Connector for Azure Cosmos DB SQL API
+ OLTP Spark 4.1 Connector for Azure Cosmos DB SQL API
+
+ scm:git:https://github.com/Azure/azure-sdk-for-java.git/sdk/cosmos/azure-cosmos-spark_4-1_2-13
+
+ https://github.com/Azure/azure-sdk-for-java/sdk/cosmos/azure-cosmos-spark_4-1_2-13
+
+
+ Microsoft Corporation
+ http://microsoft.com
+
+
+
+ The MIT License (MIT)
+ http://opensource.org/licenses/MIT
+ repo
+
+
+
+
+ microsoft
+ Microsoft Corporation
+
+
+
+ false
+ 4.1
+ 2.13
+ 2.13.17
+ 0.9.1
+ 0.8.0
+ 3.2.2
+ 3.2.3
+ 3.2.3
+ 5.0.0
+ true
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-resources-plugin
+ 2.4.3
+
+
+ copy-shared-sources
+ initialize
+
+ copy-resources
+
+
+ ${project.build.directory}/shared-sources
+
+
+ ${basedir}/../azure-cosmos-spark_3/src/main/scala
+
+ **/ChangeFeedInitialOffsetWriter.scala
+ **/CosmosCatalogBase.scala
+
+
+
+
+
+
+ copy-shared-test-sources
+ initialize
+
+ copy-resources
+
+
+ ${project.build.directory}/shared-test-sources
+
+
+ ${basedir}/../azure-cosmos-spark_3/src/test/scala
+
+ **/CosmosCatalogITestBase.scala
+
+
+
+
+
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+ 3.6.1
+
+
+ add-sources
+ generate-sources
+
+ add-source
+
+
+
+ ${project.build.directory}/shared-sources
+ ${basedir}/src/main/scala
+
+
+
+
+ add-test-sources
+ generate-test-sources
+
+ add-test-source
+
+
+
+ ${project.build.directory}/shared-test-sources
+ ${basedir}/src/test/scala
+
+
+
+
+ add-resources
+ generate-resources
+
+ add-resource
+
+
+
+ ${basedir}/../azure-cosmos-spark_3/src/main/resources
+ ${basedir}/src/main/resources
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+ 3.6.1
+
+
+
+
+
+
+ spark-e2e_4-1_2-13
+
+
+ [17,)
+
+ ${basedir}/scalastyle_config.xml
+
+
+ spark-e2e_4-1_2-13
+ true
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 3.5.3
+
+
+ **/*.*
+ **/*Test.*
+ **/*Suite.*
+ **/*Spec.*
+
+ true
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+ 2.1.0
+
+ ${scalatest.argLine}
+ stdOut=true,verbose=true,stdErr=true
+ false
+ FDEF
+ FDEF
+ once
+ true
+ ${project.build.directory}/surefire-reports
+ .
+ SparkTestSuite.txt
+ (ITest|Test|Spec|Suite)
+
+
+
+ test
+
+ test
+
+
+
+
+
+
+
+
+
+ spark-4-1-disable-tests-java-lt-17
+
+ (,17)
+
+
+ true
+
+
+
+ java9-plus
+
+ [9,)
+
+
+ --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false
+
+
+
+
+
+ org.apache.spark
+ spark-sql_2.13
+ 4.1.0
+
+
+ io.netty
+ netty-all
+
+
+ org.slf4j
+ *
+
+
+ provided
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ 2.18.4
+
+
+ com.fasterxml.jackson.module
+ jackson-module-scala_2.13
+ 2.18.4
+
+
+
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml
new file mode 100644
index 000000000000..7a8ad2823fb8
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/scalastyle_config.xml
@@ -0,0 +1,130 @@
+
+ Scalastyle standard configuration
+
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties
new file mode 100644
index 000000000000..ca812989b4f2
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/resources/azure-cosmos-spark.properties
@@ -0,0 +1,2 @@
+name=${project.artifactId}
+version=${project.version}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala
new file mode 100644
index 000000000000..c4687f4102ed
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedInitialOffsetWriter.scala
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.checkpointing.{HDFSMetadataLog, MetadataVersionUtil}
+
+import java.io.{BufferedWriter, InputStream, InputStreamReader, OutputStream, OutputStreamWriter}
+import java.nio.charset.StandardCharsets
+
+private class ChangeFeedInitialOffsetWriter
+(
+ sparkSession: SparkSession,
+ metadataPath: String
+) extends HDFSMetadataLog[String](sparkSession, metadataPath) {
+
+ val VERSION = 1
+
+ override def serialize(offsetJson: String, out: OutputStream): Unit = {
+ val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+ writer.write(s"v$VERSION\n")
+ writer.write(offsetJson)
+ writer.flush()
+ }
+
+ override def deserialize(in: InputStream): String = {
+ val content = readerToString(new InputStreamReader(in, StandardCharsets.UTF_8))
+ // HDFSMetadataLog would never create a partial file.
+ require(content.nonEmpty)
+ val indexOfNewLine = content.indexOf("\n")
+ if (content(0) != 'v' || indexOfNewLine < 0) {
+ throw new IllegalStateException(
+ "Log file was malformed: failed to detect the log file version line.")
+ }
+
+ MetadataVersionUtil.validateVersion(content.substring(0, indexOfNewLine), VERSION)
+ content.substring(indexOfNewLine + 1)
+ }
+
+ private def readerToString(reader: java.io.Reader): String = {
+ val writer = new StringBuilderWriter
+ val buffer = new Array[Char](4096)
+    LazyList.continually(reader.read(buffer)).takeWhile(_ != -1).foreach(writer.write(buffer, 0, _))
+ writer.toString
+ }
+
+ private class StringBuilderWriter extends java.io.Writer {
+ private val stringBuilder = new StringBuilder
+
+ override def write(cbuf: Array[Char], off: Int, len: Int): Unit = {
+ stringBuilder.appendAll(cbuf, off, len)
+ }
+
+ override def flush(): Unit = {}
+
+ override def close(): Unit = {}
+
+ override def toString: String = stringBuilder.toString()
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala
new file mode 100644
index 000000000000..bf4632cf609a
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ChangeFeedMicroBatchStream.scala
@@ -0,0 +1,271 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.changeFeedMetrics.{ChangeFeedMetricsListener, ChangeFeedMetricsTracker}
+import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
+import com.azure.cosmos.implementation.guava25.collect.{HashBiMap, Maps}
+import com.azure.cosmos.spark.CosmosPredicates.{assertNotNull, assertNotNullOrEmpty, assertOnSparkDriver}
+import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, LoggerHelper}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl}
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
+import org.apache.spark.sql.types.StructType
+
+import java.time.Duration
+import java.util.UUID
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicLong
+
+// scalastyle:off underscore.import
+import scala.jdk.CollectionConverters._
+// scalastyle:on underscore.import
+
+// scala style rule flaky - even complaining on partial log messages
+// scalastyle:off multiple.string.literals
+private class ChangeFeedMicroBatchStream
+(
+ val session: SparkSession,
+ val schema: StructType,
+ val config: Map[String, String],
+ val cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
+ val checkpointLocation: String,
+ diagnosticsConfig: DiagnosticsConfig
+) extends MicroBatchStream
+ with SupportsAdmissionControl {
+
+ @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
+
+ private val correlationActivityId = UUID.randomUUID()
+ private val streamId = correlationActivityId.toString
+ log.logTrace(s"Instantiated ${this.getClass.getSimpleName}.$streamId")
+
+ private val defaultParallelism = session.sparkContext.defaultParallelism
+ private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
+ private val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session))
+ private val clientConfiguration = CosmosClientConfiguration.apply(
+ config,
+ readConfig.readConsistencyStrategy,
+ sparkEnvironmentInfo)
+ private val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(config)
+ private val partitioningConfig = CosmosPartitioningConfig.parseCosmosPartitioningConfig(config)
+ private val changeFeedConfig = CosmosChangeFeedConfig.parseCosmosChangeFeedConfig(config)
+ private val clientCacheItem = CosmosClientCache(
+ clientConfiguration,
+ Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
+ s"ChangeFeedMicroBatchStream(streamId $streamId)")
+ private val throughputControlClientCacheItemOpt =
+ ThroughputControlHelper.getThroughputControlClientCacheItem(
+ config, clientCacheItem.context, Some(cosmosClientStateHandles), sparkEnvironmentInfo)
+ private val container =
+ ThroughputControlHelper.getContainer(
+ config,
+ containerConfig,
+ clientCacheItem,
+ throughputControlClientCacheItemOpt)
+
+ private var latestOffsetSnapshot: Option[ChangeFeedOffset] = None
+
+ private val partitionIndex = new AtomicLong(0)
+ private val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]())
+ private val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]()
+
+ if (changeFeedConfig.performanceMonitoringEnabled) {
+ log.logInfo("ChangeFeed performance monitoring is enabled, registering ChangeFeedMetricsListener")
+ session.sparkContext.addSparkListener(new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap))
+ } else {
+ log.logInfo("ChangeFeed performance monitoring is disabled")
+ }
+
+ override def latestOffset(): Offset = {
+ // For Spark data streams implementing SupportsAdmissionControl trait
+ // latestOffset(Offset, ReadLimit) is called instead
+ throw new UnsupportedOperationException(
+ "latestOffset(Offset, ReadLimit) should be called instead of this method")
+ }
+
+ /**
+ * Returns a list of `InputPartition` given the start and end offsets. Each
+ * `InputPartition` represents a data split that can be processed by one Spark task. The
+ * number of input partitions returned here is the same as the number of RDD partitions this scan
+ * outputs.
+ *
+ * If the `Scan` supports filter push down, this stream is likely configured with a filter
+ * and is responsible for creating splits for that filter, which is not a full scan.
+ *
+ * This method will be called multiple times, to launch one Spark job for each micro-batch in this
+ * data stream.
+ */
+ override def planInputPartitions(startOffset: Offset, endOffset: Offset): Array[InputPartition] = {
+ assertNotNull(startOffset, "startOffset")
+ assertNotNull(endOffset, "endOffset")
+ assert(startOffset.isInstanceOf[ChangeFeedOffset], "Argument 'startOffset' is not a change feed offset.")
+ assert(endOffset.isInstanceOf[ChangeFeedOffset], "Argument 'endOffset' is not a change feed offset.")
+
+ log.logDebug(s"--> planInputPartitions.$streamId, startOffset: ${startOffset.json()} - endOffset: ${endOffset.json()}")
+ val start = startOffset.asInstanceOf[ChangeFeedOffset]
+ val end = endOffset.asInstanceOf[ChangeFeedOffset]
+
+ val startChangeFeedState = new String(
+ java.util.Base64.getUrlDecoder.decode(start.changeFeedState), java.nio.charset.StandardCharsets.UTF_8)
+ log.logDebug(s"Start-ChangeFeedState.$streamId: $startChangeFeedState")
+
+ val endChangeFeedState = new String(
+ java.util.Base64.getUrlDecoder.decode(end.changeFeedState), java.nio.charset.StandardCharsets.UTF_8)
+ log.logDebug(s"End-ChangeFeedState.$streamId: $endChangeFeedState")
+
+ assert(end.inputPartitions.isDefined, "Argument 'endOffset.inputPartitions' must not be null or empty.")
+
+ val parsedStartChangeFeedState = SparkBridgeImplementationInternal.parseChangeFeedState(start.changeFeedState)
+ end
+ .inputPartitions
+ .get
+ .map(partition => {
+ val index = partitionIndexMap.asScala.getOrElseUpdate(partition.feedRange, partitionIndex.incrementAndGet())
+ partition
+ .withContinuationState(
+ SparkBridgeImplementationInternal
+ .extractChangeFeedStateForRange(parsedStartChangeFeedState, partition.feedRange),
+ clearEndLsn = false)
+ .withIndex(index)
+ })
+ }
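`planInputPartitions` above decodes the offsets' change-feed state from URL-safe Base64 purely for debug logging. The round-trip uses the standard JDK codec; a minimal sketch (the payload below is a made-up placeholder, not a real change-feed state):

```scala
import java.nio.charset.StandardCharsets
import java.util.Base64

// Sketch: encode/decode a state string with the URL-safe Base64 variant,
// matching the getUrlDecoder call used for logging above.
object ChangeFeedStateCodecSketch {
  def encode(state: String): String =
    Base64.getUrlEncoder.encodeToString(state.getBytes(StandardCharsets.UTF_8))

  def decode(encoded: String): String =
    new String(Base64.getUrlDecoder.decode(encoded), StandardCharsets.UTF_8)
}
```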
+
+ /**
+ * Returns a factory to create a `PartitionReader` for each `InputPartition`.
+ */
+ override def createReaderFactory(): PartitionReaderFactory = {
+ log.logDebug(s"--> createReaderFactory.$streamId")
+ ChangeFeedScanPartitionReaderFactory(
+ config,
+ schema,
+ DiagnosticsContext(correlationActivityId, checkpointLocation),
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session)))
+ }
+
+ /**
+ * Returns the most recent offset available given a read limit. The start offset can be used
+ * to figure out how much new data should be read given the limit. Users should implement this
+ * method instead of latestOffset for a MicroBatchStream or getOffset for Source.
+ *
+ * When this method is called on a `Source`, the source can return `null` if there is no
+ * data to process. In addition, for the very first micro-batch, the `startOffset` will be
+ * null as well.
+ *
+ * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset`
+ * for the very first micro-batch. The source can return `null` if there is no data to process.
+ */
+ // This method does all the heavy lifting: after calculating the latest offset, all
+ // information necessary to plan partitions is available. So we plan partitions here and
+ // serialize them into the returned end offset, avoiding any IO calls during the actual
+ // partitioning.
+ override def latestOffset(startOffset: Offset, readLimit: ReadLimit): Offset = {
+
+ log.logDebug(s"--> latestOffset.$streamId")
+
+ val startChangeFeedOffset = startOffset.asInstanceOf[ChangeFeedOffset]
+ val offset = CosmosPartitionPlanner.getLatestOffset(
+ config,
+ startChangeFeedOffset,
+ readLimit,
+ Duration.ZERO,
+ this.clientConfiguration,
+ this.cosmosClientStateHandles,
+ this.containerConfig,
+ this.partitioningConfig,
+ this.defaultParallelism,
+ this.container,
+ Some(this.partitionMetricsMap)
+ )
+
+ if (offset.changeFeedState != startChangeFeedOffset.changeFeedState) {
+ log.logDebug(s"<-- latestOffset.$streamId - new offset ${offset.json()}")
+ this.latestOffsetSnapshot = Some(offset)
+ offset
+ } else {
+ log.logDebug(s"<-- latestOffset.$streamId - Finished returning null")
+
+ this.latestOffsetSnapshot = None
+
+ // scalastyle:off null
+ // null means no more data to process
+ // null is used here because the DataSource V2 API is defined in Java
+ null
+ // scalastyle:on null
+ }
+ }
+
+ /**
+ * Returns the initial offset for a streaming query to start reading from. Note that the
+ * streaming data source should not assume that it will start reading from its initial offset:
+ * if Spark is restarting an existing query, it will restart from the check-pointed offset rather
+ * than the initial one.
+ */
+ // Maps the configured start position settings to the initial offset/LSNs
+ override def initialOffset(): Offset = {
+ assertOnSparkDriver()
+
+ val metadataLog = new ChangeFeedInitialOffsetWriter(
+ assertNotNull(session, "session"),
+ assertNotNullOrEmpty(checkpointLocation, "checkpointLocation"))
+ val offsetJson = metadataLog.get(0).getOrElse {
+ val newOffsetJson = CosmosPartitionPlanner.createInitialOffset(
+ container, containerConfig, changeFeedConfig, partitioningConfig, Some(streamId))
+ metadataLog.add(0, newOffsetJson)
+ newOffsetJson
+ }
+
+ log.logDebug(s"MicroBatch stream $streamId: Initial offset '$offsetJson'.")
+ ChangeFeedOffset(offsetJson, None)
+ }
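`initialOffset` follows a get-or-create pattern against the checkpoint metadata log: reuse the persisted batch-0 entry if it exists, otherwise compute a fresh offset and persist it before returning. The shape of that logic, with a plain mutable map as a hypothetical stand-in for `ChangeFeedInitialOffsetWriter`:

```scala
import scala.collection.mutable

// Sketch: return the persisted batch-0 entry, or create and store one.
// The `create` thunk is only evaluated when no entry exists yet.
class MetadataLogSketch(create: () => String) {
  private val log = mutable.Map[Long, String]()

  def initialOffset(): String =
    log.getOrElseUpdate(0L, create())
}
```

This is why restarting a query is safe: the expensive `createInitialOffset` call happens once, and every later start reads the checkpointed value.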
+
+ /**
+ * Returns the read limits potentially passed to the data source through options when creating
+ * the data source.
+ */
+ override def getDefaultReadLimit: ReadLimit = {
+ this.changeFeedConfig.toReadLimit
+ }
+
+ /**
+ * Returns the most recent offset available.
+ *
+ * The source can return `null` if there is no data to process or the source does not
+ * support this method.
+ */
+ override def reportLatestOffset(): Offset = {
+ this.latestOffsetSnapshot.orNull
+ }
+
+ /**
+ * Deserialize a JSON string into an Offset of the implementation-defined offset type.
+ *
+ * @throws IllegalArgumentException if the JSON does not encode a valid offset for this reader
+ */
+ override def deserializeOffset(s: String): Offset = {
+ log.logDebug(s"MicroBatch stream $streamId: Deserialized offset '$s'.")
+ ChangeFeedOffset.fromJson(s)
+ }
+
+ /**
+ * Informs the source that Spark has completed processing all data for offsets less than or
+ * equal to `end` and will only request offsets greater than `end` in the future.
+ */
+ override def commit(offset: Offset): Unit = {
+ log.logDebug(s"MicroBatch stream $streamId: Committed offset '${offset.json()}'.")
+ }
+
+ /**
+ * Stop this source and free any resources it has allocated.
+ */
+ override def stop(): Unit = {
+ clientCacheItem.close()
+ if (throughputControlClientCacheItemOpt.isDefined) {
+ throughputControlClientCacheItemOpt.get.close()
+ }
+ log.logDebug(s"MicroBatch stream $streamId: stopped.")
+ }
+}
+// scalastyle:on multiple.string.literals
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala
new file mode 100644
index 000000000000..9d7f645227bf
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosBytesWrittenMetric.scala
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.connector.metric.CustomSumMetric
+
+private[cosmos] class CosmosBytesWrittenMetric extends CustomSumMetric {
+ override def name(): String = CosmosConstants.MetricNames.BytesWritten
+
+ override def description(): String = CosmosConstants.MetricNames.BytesWritten
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala
new file mode 100644
index 000000000000..778c2311e2e0
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalog.scala
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
+
+import java.util
+import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException}
+import org.apache.spark.sql.connector.catalog.{NamespaceChange, SupportsNamespaces}
+
+
+class CosmosCatalog
+ extends CosmosCatalogBase
+ with SupportsNamespaces {
+
+ override def listNamespaces(): Array[Array[String]] = {
+ super.listNamespacesBase()
+ }
+
+ @throws(classOf[NoSuchNamespaceException])
+ override def listNamespaces(namespace: Array[String]): Array[Array[String]] = {
+ super.listNamespacesBase(namespace)
+ }
+
+ @throws(classOf[NoSuchNamespaceException])
+ override def loadNamespaceMetadata(namespace: Array[String]): util.Map[String, String] = {
+ super.loadNamespaceMetadataBase(namespace)
+ }
+
+ @throws(classOf[NamespaceAlreadyExistsException])
+ override def createNamespace(namespace: Array[String],
+ metadata: util.Map[String, String]): Unit = {
+ super.createNamespaceBase(namespace, metadata)
+ }
+
+ @throws(classOf[UnsupportedOperationException])
+ override def alterNamespace(namespace: Array[String],
+ changes: NamespaceChange*): Unit = {
+ super.alterNamespaceBase(namespace, changes)
+ }
+
+ @throws(classOf[NoSuchNamespaceException])
+ @throws(classOf[NonEmptyNamespaceException])
+ override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = {
+ if (!cascade && this.listTables(namespace).nonEmpty) {
+ throw new NonEmptyNamespaceException(namespace)
+ }
+ super.dropNamespaceBase(namespace)
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala
new file mode 100644
index 000000000000..6aed2016bba1
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosCatalogBase.scala
@@ -0,0 +1,727 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.spark.catalog.{CosmosCatalogConflictException, CosmosCatalogException, CosmosCatalogNotFoundException, CosmosThroughputProperties}
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, NamespaceChange, Table, TableCatalog, TableChange}
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.execution.streaming.checkpointing.HDFSMetadataLog
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+import java.util
+import scala.annotation.tailrec
+import scala.collection.mutable.ArrayBuffer
+
+// scalastyle:off underscore.import
+import scala.collection.JavaConverters._
+// scalastyle:on underscore.import
+
+// CosmosCatalog provides a metadata store for the Cosmos DB database and container control plane.
+// This will be required for Hive integration.
+// Relevant interfaces to implement:
+// - SupportsNamespaces (Cosmos databases and containers can be modeled as namespaces)
+// - SupportsCatalogOptions // TODO moderakh
+// - CatalogPlugin - a marker interface to provide a catalog implementation for Spark.
+//   Implementations can provide catalog functions by implementing additional interfaces
+//   for tables, views, and functions.
+// - TableCatalog - catalog methods for working with tables.
+
+// All Hive keywords are case-insensitive, including the names of Hive operators and functions.
+// scalastyle:off multiple.string.literals
+// scalastyle:off number.of.methods
+// scalastyle:off file.size.limit
+class CosmosCatalogBase
+ extends CatalogPlugin
+ with TableCatalog
+ with BasicLoggingTrait {
+
+ private lazy val sparkSession = SparkSession.active
+ private lazy val sparkEnvironmentInfo = CosmosClientConfiguration.getSparkEnvironmentInfo(SparkSession.getActiveSession)
+
+ // mutable but only expected to be changed from within initialize method
+ private var catalogName: String = _
+ //private var client: CosmosAsyncClient = _
+ private var config: Map[String, String] = _
+ private var readConfig: CosmosReadConfig = _
+ private var tableOptions: Map[String, String] = _
+ private var viewRepository: Option[HDFSMetadataLog[String]] = None
+
+ /**
+ * Called to initialize configuration.
+ *
+ * This method is called once, just after the provider is instantiated.
+ *
+ * @param name the name used to identify and load this catalog
+ * @param options a case-insensitive string map of configuration
+ */
+ override def initialize(name: String,
+ options: CaseInsensitiveStringMap): Unit = {
+ this.config = CosmosConfig.getEffectiveConfig(
+ None,
+ None,
+ options.asCaseSensitiveMap().asScala.toMap)
+ this.readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
+
+ tableOptions = toTableConfig(options)
+ this.catalogName = name
+
+ val viewRepositoryConfig = CosmosViewRepositoryConfig.parseCosmosViewRepositoryConfig(config)
+ if (viewRepositoryConfig.metaDataPath.isDefined) {
+ this.viewRepository = Some(new HDFSMetadataLog[String](
+ this.sparkSession,
+ viewRepositoryConfig.metaDataPath.get))
+ }
+ }
+
+ /**
+ * Catalog implementations are registered to a name by adding a configuration option to Spark:
+ * spark.sql.catalog.catalog-name=com.example.YourCatalogClass.
+ * All configuration properties in the Spark configuration that share the catalog name prefix,
+ * spark.sql.catalog.catalog-name.(key)=(value), are passed in the case-insensitive
+ * string map of options during initialization, with the prefix removed.
+ * The name ("catalog-name" in this example) is also passed and is the catalog's name.
+ *
+ * @return catalog name
+ */
+ override def name(): String = catalogName
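The registration contract described above means each catalog instance receives only its own prefixed properties, with the prefix already stripped. That stripping is done by Spark itself; the helper below is only an illustrative sketch of the behavior:

```scala
// Sketch: collect properties under spark.sql.catalog.<name>. and drop the prefix,
// mimicking what Spark passes into initialize(name, options).
object CatalogOptionsSketch {
  def optionsFor(sparkConf: Map[String, String], catalogName: String): Map[String, String] = {
    val prefix = s"spark.sql.catalog.$catalogName."
    sparkConf.collect {
      case (key, value) if key.startsWith(prefix) => key.substring(prefix.length) -> value
    }
  }
}
```

Note that the bare `spark.sql.catalog.<name>` entry (the implementation class itself) has no trailing dot and is therefore not part of the options map.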
+
+ /**
+ * List top-level namespaces from the catalog.
+ *
+ * If an object such as a table, view, or function exists, its parent namespaces must also exist
+ * and must be returned by this discovery method. For example, if table a.t exists, this method
+ * must return ["a"] in the result array.
+ *
+ * @return an array of multi-part namespace names.
+ */
+ def listNamespacesBase(): Array[Array[String]] = {
+ logDebug("catalog:listNamespaces")
+
+ TransientErrorsRetryPolicy.executeWithRetry(() => listNamespacesImpl())
+ }
+
+ private[this] def listNamespacesImpl(): Array[Array[String]] = {
+ logDebug("catalog:listNamespaces")
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).listNamespaces"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0)
+ .get
+ .sparkCatalogClient
+ .readAllDatabases()
+ .map(Array(_))
+ .collectSeq()
+ .block()
+ .toArray
+ })
+ }
+
+ /**
+ * List namespaces in a namespace.
+ *
+ * Cosmos DB supports only a single level of databases, so this always returns an empty
+ * list of namespaces, or throws if the root namespace doesn't exist.
+ */
+ @throws(classOf[NoSuchNamespaceException])
+ def listNamespacesBase(namespace: Array[String]): Array[Array[String]] = {
+ loadNamespaceMetadataBase(namespace) // throws NoSuchNamespaceException if namespace doesn't exist
+ // Cosmos DB only has one single level depth databases
+ Array.empty[Array[String]]
+ }
+
+ /**
+ * Load metadata properties for a namespace.
+ *
+ * @param namespace a multi-part namespace
+ * @return a string map of properties for the given namespace
+ * @throws NoSuchNamespaceException If the namespace does not exist (optional)
+ */
+ @throws(classOf[NoSuchNamespaceException])
+ def loadNamespaceMetadataBase(namespace: Array[String]): util.Map[String, String] = {
+
+ TransientErrorsRetryPolicy.executeWithRetry(() => loadNamespaceMetadataImpl(namespace))
+ }
+
+ private[this] def loadNamespaceMetadataImpl(
+ namespace: Array[String]): util.Map[String, String] = {
+
+ checkNamespace(namespace)
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).loadNamespaceMetadata([${namespace.mkString(", ")}])"
+ ))
+ ))
+ .to(clientCacheItems => {
+ try {
+ clientCacheItems(0)
+ .get
+ .sparkCatalogClient
+ .readDatabaseThroughput(toCosmosDatabaseName(namespace.head))
+ .block()
+ .asJava
+ } catch {
+ case _: CosmosCatalogNotFoundException =>
+ throw new NoSuchNamespaceException(namespace)
+ }
+ })
+ }
+
+ @throws(classOf[NamespaceAlreadyExistsException])
+ def createNamespaceBase(namespace: Array[String],
+ metadata: util.Map[String, String]): Unit = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => createNamespaceImpl(namespace, metadata))
+ }
+
+ @throws(classOf[NamespaceAlreadyExistsException])
+ private[this] def createNamespaceImpl(namespace: Array[String],
+ metadata: util.Map[String, String]): Unit = {
+ checkNamespace(namespace)
+ val databaseName = toCosmosDatabaseName(namespace.head)
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).createNamespace([${namespace.mkString(", ")}])"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ try {
+ cosmosClientCacheItems(0)
+ .get
+ .sparkCatalogClient
+ .createDatabase(databaseName, metadata.asScala.toMap)
+ .block()
+ } catch {
+ case _: CosmosCatalogConflictException =>
+ throw new NamespaceAlreadyExistsException(namespace)
+ }
+ })
+ }
+
+ @throws(classOf[UnsupportedOperationException])
+ def alterNamespaceBase(namespace: Array[String],
+ changes: Seq[NamespaceChange]): Unit = {
+ checkNamespace(namespace)
+
+ if (changes.size > 0) {
+ val invalidChangesCount = changes
+ .count(change => !CosmosThroughputProperties.isThroughputProperty(change))
+ if (invalidChangesCount > 0) {
+ throw new UnsupportedOperationException("ALTER NAMESPACE contains unsupported changes.")
+ }
+
+ val finalThroughputProperty = changes.last.asInstanceOf[NamespaceChange.SetProperty]
+
+ val databaseName = toCosmosDatabaseName(namespace.head)
+
+ alterNamespaceImpl(databaseName, finalThroughputProperty)
+ }
+ }
+
+ //scalastyle:off method.length
+ private def alterNamespaceImpl(databaseName: String, finalThroughputProperty: NamespaceChange.SetProperty): Unit = {
+ logInfo(s"alterNamespace DB:$databaseName")
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).alterNamespace($databaseName)"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0).get
+ .sparkCatalogClient
+ .alterDatabase(databaseName, finalThroughputProperty)
+ .block()
+ })
+ }
+ //scalastyle:on method.length
+
+ /**
+ * Drop a namespace from the catalog, recursively dropping all objects within the namespace.
+ *
+ * @param namespace - a multi-part namespace
+ * @return true if the namespace was dropped
+ */
+ @throws(classOf[NoSuchNamespaceException])
+ def dropNamespaceBase(namespace: Array[String]): Boolean = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => dropNamespaceImpl(namespace))
+ }
+
+ @throws(classOf[NoSuchNamespaceException])
+ private[this] def dropNamespaceImpl(namespace: Array[String]): Boolean = {
+ checkNamespace(namespace)
+ try {
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).dropNamespace([${namespace.mkString(", ")}])"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0)
+ .get
+ .sparkCatalogClient
+ .deleteDatabase(toCosmosDatabaseName(namespace.head))
+ .block()
+ })
+ true
+ } catch {
+ case _: CosmosCatalogNotFoundException =>
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => listTablesImpl(namespace))
+ }
+
+ private[this] def listTablesImpl(namespace: Array[String]): Array[Identifier] = {
+ checkNamespace(namespace)
+ val databaseName = toCosmosDatabaseName(namespace.head)
+
+ try {
+ val cosmosTables =
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).listTables([${namespace.mkString(", ")}])"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0).get
+ .sparkCatalogClient
+ .readAllContainers(databaseName)
+ .map(containerId => getContainerIdentifier(namespace.head, containerId))
+ .collectSeq()
+ .block()
+ .toList
+ })
+
+ val tableIdentifiers = this.tryGetViewDefinitions(databaseName) match {
+ case Some(viewDefinitions) =>
+ cosmosTables ++ viewDefinitions.map(viewDef => getContainerIdentifier(namespace.head, viewDef)).toIterable
+ case None => cosmosTables
+ }
+
+ tableIdentifiers.toArray
+ } catch {
+ case _: CosmosCatalogNotFoundException =>
+ throw new NoSuchNamespaceException(namespace)
+ }
+ }
+
+ override def loadTable(ident: Identifier): Table = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => loadTableImpl(ident))
+ }
+
+ private[this] def loadTableImpl(ident: Identifier): Table = {
+ checkNamespace(ident.namespace())
+ val databaseName = toCosmosDatabaseName(ident.namespace().head)
+ val containerName = toCosmosContainerName(ident.name())
+ logInfo(s"loadTable DB:$databaseName, Container: $containerName")
+
+ this.tryGetContainerMetadata(databaseName, containerName) match {
+ case Some(tableProperties) =>
+ new ItemsTable(
+ sparkSession,
+ Array[Transform](),
+ Some(databaseName),
+ Some(containerName),
+ tableOptions.asJava,
+ None,
+ tableProperties)
+ case None =>
+ this.tryGetViewDefinition(databaseName, containerName) match {
+ case Some(viewDefinition) =>
+ val effectiveOptions = tableOptions ++ viewDefinition.options
+ new ItemsReadOnlyTable(
+ sparkSession,
+ Array[Transform](),
+ None,
+ None,
+ effectiveOptions.asJava,
+ viewDefinition.userProvidedSchema)
+ case None =>
+ throw new NoSuchTableException(ident)
+ }
+ }
+ }
+
+ override def createTable(ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+
+ TransientErrorsRetryPolicy.executeWithRetry(() =>
+ createTableImpl(ident, schema, partitions, properties))
+ }
+
+ private[this] def createTableImpl(ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ checkNamespace(ident.namespace())
+
+ val databaseName = toCosmosDatabaseName(ident.namespace().head)
+ val containerName = toCosmosContainerName(ident.name())
+ val containerProperties = properties.asScala.toMap
+
+ if (CosmosViewRepositoryConfig.isCosmosView(containerProperties)) {
+ createViewTable(ident, databaseName, containerName, schema, partitions, containerProperties)
+ } else {
+ createPhysicalTable(databaseName, containerName, schema, partitions, containerProperties)
+ }
+ }
+
+ @throws(classOf[UnsupportedOperationException])
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+ checkNamespace(ident.namespace())
+
+ if (changes.size > 0) {
+ val invalidChangesCount = changes
+ .count(change => !CosmosThroughputProperties.isThroughputProperty(change))
+ if (invalidChangesCount > 0) {
+ throw new UnsupportedOperationException("ALTER TABLE contains unsupported changes.")
+ }
+
+ val finalThroughputProperty = changes.last.asInstanceOf[TableChange.SetProperty]
+
+ val tableBeforeModification = loadTableImpl(ident)
+ if (!tableBeforeModification.isInstanceOf[ItemsTable]) {
+ throw new UnsupportedOperationException("ALTER TABLE cannot be applied to Cosmos views.")
+ }
+
+ val databaseName = toCosmosDatabaseName(ident.namespace().head)
+ val containerName = toCosmosContainerName(ident.name())
+
+ alterPhysicalTable(databaseName, containerName, finalThroughputProperty)
+ }
+
+ loadTableImpl(ident)
+ }
+
+ override def dropTable(ident: Identifier): Boolean = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => dropTableImpl(ident))
+ }
+
+ private[this] def dropTableImpl(ident: Identifier): Boolean = {
+ checkNamespace(ident.namespace())
+
+ val databaseName = toCosmosDatabaseName(ident.namespace().head)
+ val containerName = toCosmosContainerName(ident.name())
+
+ if (deleteViewTable(databaseName, containerName)) {
+ true
+ } else {
+ this.deletePhysicalTable(databaseName, containerName)
+ }
+ }
+
+ @throws(classOf[UnsupportedOperationException])
+ override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
+ throw new UnsupportedOperationException("renaming table not supported")
+ }
+
+ //scalastyle:off method.length
+ private def createPhysicalTable(databaseName: String,
+ containerName: String,
+ schema: StructType,
+ partitions: Array[Transform],
+ containerProperties: Map[String, String]): Table = {
+ logInfo(s"createPhysicalTable DB:$databaseName, Container: $containerName")
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).createPhysicalTable($databaseName, $containerName)"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0).get
+ .sparkCatalogClient
+ .createContainer(databaseName, containerName, containerProperties)
+ .block()
+ })
+
+ val effectiveOptions = tableOptions ++ containerProperties
+
+ new ItemsTable(
+ sparkSession,
+ partitions,
+ Some(databaseName),
+ Some(containerName),
+ effectiveOptions.asJava,
+ Option.apply(schema))
+ }
+ //scalastyle:on method.length
+
+ //scalastyle:off method.length
+ @tailrec
+ private def createViewTable(ident: Identifier,
+ databaseName: String,
+ viewName: String,
+ schema: StructType,
+ partitions: Array[Transform],
+ containerProperties: Map[String, String]): Table = {
+
+ logInfo(s"createViewTable DB:$databaseName, View: $viewName")
+
+ this.viewRepository match {
+ case Some(viewRepositorySnapshot) =>
+ val userProvidedSchema = if (schema != null && schema.length > 0) {
+ Some(schema)
+ } else {
+ None
+ }
+ val viewDefinition = ViewDefinition(
+ databaseName, viewName, userProvidedSchema, redactAuthInfo(containerProperties))
+ var lastBatchId = 0L
+ val newViewDefinitionsSnapshot = viewRepositorySnapshot.getLatest() match {
+ case Some(viewDefinitionsEnvelopeSnapshot) =>
+ lastBatchId = viewDefinitionsEnvelopeSnapshot._1
+ val alreadyExistingViews = ViewDefinitionEnvelopeSerializer.fromJson(viewDefinitionsEnvelopeSnapshot._2)
+
+ if (alreadyExistingViews.exists(v => v.databaseName.equals(databaseName) &&
+ v.viewName.equals(viewName))) {
+
+ throw new IllegalArgumentException(s"View '$viewName' already exists in database '$databaseName'")
+ }
+
+ alreadyExistingViews ++ Array(viewDefinition)
+ case None => Array(viewDefinition)
+ }
+
+ if (viewRepositorySnapshot.add(
+ lastBatchId + 1,
+ ViewDefinitionEnvelopeSerializer.toJson(newViewDefinitionsSnapshot))) {
+
+ logInfo(s"LatestBatchId: ${viewRepositorySnapshot.getLatestBatchId().getOrElse(-1)}")
+ viewRepositorySnapshot.purge(lastBatchId)
+ logInfo(s"LatestBatchId: ${viewRepositorySnapshot.getLatestBatchId().getOrElse(-1)}")
+ val effectiveOptions = tableOptions ++ viewDefinition.options
+
+ new ItemsReadOnlyTable(
+ sparkSession,
+ partitions,
+ None,
+ None,
+ effectiveOptions.asJava,
+ userProvidedSchema)
+ } else {
+ createViewTable(ident, databaseName, viewName, schema, partitions, containerProperties)
+ }
+ case None =>
+ throw new IllegalArgumentException(
+ s"Catalog configuration for '${CosmosViewRepositoryConfig.MetaDataPathKeyName}' must " +
+ "be set when creating views.")
+ }
+ }
+ //scalastyle:on method.length
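`createViewTable` above (and `deleteViewTable` below) are tail-recursive optimistic-concurrency loops: read the latest snapshot, try to write batch `lastBatchId + 1`, and restart from scratch if another writer committed first. The same pattern in miniature, against an `AtomicReference` (illustrative only, not the connector's metadata log):

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec

// Sketch: append an item, retrying when the snapshot changed underneath us.
object OptimisticAppendSketch {
  private val snapshot = new AtomicReference[List[String]](Nil)

  @tailrec
  def add(item: String): List[String] = {
    val current = snapshot.get()
    val updated = current :+ item
    if (snapshot.compareAndSet(current, updated)) updated
    else add(item) // lost the race: re-read the snapshot and try again
  }
}
```

In the real code the compare-and-set is `HDFSMetadataLog.add`, which returns false when the batch id already exists, and `purge` trims the superseded batch afterward.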
+
+ //scalastyle:off method.length
+ private def alterPhysicalTable(databaseName: String,
+ containerName: String,
+ finalThroughputProperty: TableChange.SetProperty): Unit = {
+ logInfo(s"alterPhysicalTable DB:$databaseName, Container: $containerName")
+
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).alterPhysicalTable($databaseName, $containerName)"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0).get
+ .sparkCatalogClient
+ .alterContainer(databaseName, containerName, finalThroughputProperty)
+ .block()
+ })
+ }
+ //scalastyle:on method.length
+
+ private def deletePhysicalTable(databaseName: String, containerName: String): Boolean = {
+ try {
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).deletePhysicalTable($databaseName, $containerName)"
+ ))
+ ))
+ .to (cosmosClientCacheItems =>
+ cosmosClientCacheItems(0).get
+ .sparkCatalogClient
+ .deleteContainer(databaseName, containerName))
+ .block()
+ true
+ } catch {
+ case _: CosmosCatalogNotFoundException => false
+ }
+ }
+
+ @tailrec
+ private def deleteViewTable(databaseName: String, viewName: String): Boolean = {
+ logInfo(s"deleteViewTable DB:$databaseName, View: $viewName")
+
+ this.viewRepository match {
+ case Some(viewRepositorySnapshot) =>
+ viewRepositorySnapshot.getLatest() match {
+ case Some(viewDefinitionsEnvelopeSnapshot) =>
+ val lastBatchId = viewDefinitionsEnvelopeSnapshot._1
+ val viewDefinitions = ViewDefinitionEnvelopeSerializer.fromJson(viewDefinitionsEnvelopeSnapshot._2)
+
+ viewDefinitions.find(v => v.databaseName.equals(databaseName) &&
+ v.viewName.equals(viewName)) match {
+ case Some(existingView) =>
+ val updatedViewDefinitionsSnapshot: Array[ViewDefinition] =
+ ArrayBuffer(viewDefinitions: _*).filterNot(_ == existingView).toArray
+
+ if (viewRepositorySnapshot.add(
+ lastBatchId + 1,
+ ViewDefinitionEnvelopeSerializer.toJson(updatedViewDefinitionsSnapshot))) {
+
+ viewRepositorySnapshot.purge(lastBatchId)
+ true
+ } else {
+ deleteViewTable(databaseName, viewName)
+ }
+ case None => false
+ }
+ case None => false
+ }
+ case None =>
+ false
+ }
+ }
+
+ //scalastyle:off method.length
+ private def tryGetContainerMetadata
+ (
+ databaseName: String,
+ containerName: String
+ ): Option[util.HashMap[String, String]] = {
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache(
+ CosmosClientConfiguration(config, readConfig.readConsistencyStrategy, sparkEnvironmentInfo),
+ None,
+ s"CosmosCatalog(name $catalogName).tryGetContainerMetadata($databaseName, $containerName)"
+ ))
+ ))
+ .to(cosmosClientCacheItems => {
+ cosmosClientCacheItems(0)
+ .get
+ .sparkCatalogClient
+ .readContainerMetadata(databaseName, containerName)
+ .block()
+ })
+ }
+ //scalastyle:on method.length
+
+ private def tryGetViewDefinition(databaseName: String,
+ containerName: String): Option[ViewDefinition] = {
+
+ this.tryGetViewDefinitions(databaseName) match {
+ case Some(viewDefinitions) =>
+ viewDefinitions.find(v => databaseName.equals(v.databaseName) &&
+ containerName.equals(v.viewName))
+ case None => None
+ }
+ }
+
+ private def tryGetViewDefinitions(databaseName: String): Option[Array[ViewDefinition]] = {
+
+ this.viewRepository match {
+ case Some(viewRepositorySnapshot) =>
+ viewRepositorySnapshot.getLatest() match {
+ case Some(latestMetadataSnapshot) =>
+ val viewDefinitions = ViewDefinitionEnvelopeSerializer.fromJson(latestMetadataSnapshot._2)
+ .filter(v => databaseName.equals(v.databaseName))
+ if (viewDefinitions.length > 0) {
+ Some(viewDefinitions)
+ } else {
+ None
+ }
+ case None => None
+ }
+ case None => None
+ }
+ }
+
+ private def getContainerIdentifier(
+ namespaceName: String,
+ containerId: String): Identifier = {
+ Identifier.of(Array(namespaceName), containerId)
+ }
+
+ private def getContainerIdentifier
+ (
+ namespaceName: String,
+ viewDefinition: ViewDefinition
+ ): Identifier = {
+
+ Identifier.of(Array(namespaceName), viewDefinition.viewName)
+ }
+
+ private def checkNamespace(namespace: Array[String]): Unit = {
+ if (namespace == null || namespace.length != 1) {
+ throw new CosmosCatalogException(
+ s"invalid namespace ${namespace.mkString("Array(", ", ", ")")}." +
+ s" Cosmos DB only supports a single-depth namespace.")
+ }
+ }
+
+ private def toCosmosDatabaseName(namespace: String): String = {
+ namespace
+ }
+
+ private def toCosmosContainerName(tableIdent: String): String = {
+ tableIdent
+ }
+
+ private def toTableConfig(options: CaseInsensitiveStringMap): Map[String, String] = {
+ options.asCaseSensitiveMap().asScala.toMap
+ }
+
+
+ private def redactAuthInfo(cfg: Map[String, String]): Map[String, String] = {
+ cfg.filter((kvp) => !CosmosConfigNames.AccountEndpoint.equalsIgnoreCase(kvp._1) &&
+ !CosmosConfigNames.AccountKey.equalsIgnoreCase(kvp._1) &&
+ !kvp._1.toLowerCase.contains(CosmosConfigNames.AccountEndpoint.toLowerCase()) &&
+ !kvp._1.toLowerCase.contains(CosmosConfigNames.AccountKey.toLowerCase())
+ )
+ }
+}
+// scalastyle:on multiple.string.literals
+// scalastyle:on number.of.methods
+// scalastyle:on file.size.limit
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala
new file mode 100644
index 000000000000..8814c59d0c7d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRecordsWrittenMetric.scala
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.connector.metric.CustomSumMetric
+
+private[cosmos] class CosmosRecordsWrittenMetric extends CustomSumMetric {
+ override def name(): String = CosmosConstants.MetricNames.RecordsWritten
+
+ override def description(): String = CosmosConstants.MetricNames.RecordsWritten
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala
new file mode 100644
index 000000000000..fb4e9db760a0
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosRowConverter.scala
@@ -0,0 +1,127 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode
+import com.fasterxml.jackson.annotation.JsonInclude.Include
+// scalastyle:off underscore.import
+import com.fasterxml.jackson.databind.node._
+import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+import java.time.format.DateTimeFormatter
+import java.time.LocalDateTime
+import scala.collection.concurrent.TrieMap
+
+// scalastyle:off underscore.import
+import org.apache.spark.sql.types._
+// scalastyle:on underscore.import
+
+import scala.util.{Try, Success, Failure}
+
+// scalastyle:off
+private[cosmos] object CosmosRowConverter {
+
+ // TODO: Expose configuration to handle duplicate fields
+ // See: https://github.com/Azure/azure-sdk-for-java/pull/18642#discussion_r558638474
+ private val rowConverterMap = new TrieMap[CosmosSerializationConfig, CosmosRowConverter]
+
+ def get(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = {
+ rowConverterMap.get(serializationConfig) match {
+ case Some(existingRowConverter) => existingRowConverter
+ case None =>
+ val newRowConverterCandidate = createRowConverter(serializationConfig)
+ rowConverterMap.putIfAbsent(serializationConfig, newRowConverterCandidate) match {
+ case Some(existingConcurrentlyCreatedRowConverter) => existingConcurrentlyCreatedRowConverter
+ case None => newRowConverterCandidate
+ }
+ }
+ }
+
+ private def createRowConverter(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = {
+ val objectMapper = new ObjectMapper()
+ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
+ objectMapper.registerModule(new JavaTimeModule)
+ serializationConfig.serializationInclusionMode match {
+ case SerializationInclusionModes.NonNull => objectMapper.setSerializationInclusion(Include.NON_NULL)
+ case SerializationInclusionModes.NonEmpty => objectMapper.setSerializationInclusion(Include.NON_EMPTY)
+ case SerializationInclusionModes.NonDefault => objectMapper.setSerializationInclusion(Include.NON_DEFAULT)
+ case _ => objectMapper.setSerializationInclusion(Include.ALWAYS)
+ }
+
+ new CosmosRowConverter(objectMapper, serializationConfig)
+ }
+}
+
+private[cosmos] class CosmosRowConverter(private val objectMapper: ObjectMapper, private val serializationConfig: CosmosSerializationConfig)
+ extends CosmosRowConverterBase(objectMapper, serializationConfig) {
+
+ override def convertSparkDataTypeToJsonNodeConditionallyForSparkRuntimeSpecificDataType
+ (
+ fieldType: DataType,
+ rowData: Any
+ ): Option[JsonNode] = {
+ fieldType match {
+ case TimestampNTZType if rowData.isInstanceOf[java.time.LocalDateTime] => convertToJsonNodeConditionally(rowData.asInstanceOf[java.time.LocalDateTime].toString)
+ case _ =>
+ throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
+ }
+ }
+
+ override def convertSparkDataTypeToJsonNodeNonNullForSparkRuntimeSpecificDataType(fieldType: DataType, rowData: Any): JsonNode = {
+ fieldType match {
+ case TimestampNTZType if rowData.isInstanceOf[java.time.LocalDateTime] => objectMapper.convertValue(rowData.asInstanceOf[java.time.LocalDateTime].toString, classOf[JsonNode])
+ case _ =>
+ throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
+ }
+ }
+
+ override def convertToSparkDataTypeForSparkRuntimeSpecificDataType
+ (dataType: DataType,
+ value: JsonNode,
+ schemaConversionMode: SchemaConversionMode): Any =
+ (value, dataType) match {
+ case (_, _: TimestampNTZType) => handleConversionErrors(() => toTimestampNTZ(value), schemaConversionMode)
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Unsupported datatype conversion [Value: $value of ${value.getClass} to $dataType]")
+ }
+
+
+ def toTimestampNTZ(value: JsonNode): LocalDateTime = {
+ value match {
+ case isJsonNumber() => LocalDateTime.parse(value.asText())
+ case textNode: TextNode =>
+ parseDateTimeNTZFromString(textNode.asText()) match {
+ case Some(odt) => odt
+ case None =>
+ throw new IllegalArgumentException(
+ s"Value '${textNode.asText()}' cannot be parsed as LocalDateTime (TIMESTAMP_NTZ).")
+ }
+ case _ => LocalDateTime.parse(value.asText())
+ }
+ }
+
+ private def handleConversionErrors[A] = (conversion: () => A,
+ schemaConversionMode: SchemaConversionMode) => {
+ Try(conversion()) match {
+ case Success(convertedValue) => convertedValue
+ case Failure(error) =>
+ if (schemaConversionMode == SchemaConversionModes.Relaxed) {
+ null
+ }
+ else {
+ throw error
+ }
+ }
+ }
+
+ def parseDateTimeNTZFromString(value: String): Option[LocalDateTime] = {
+ try {
+ val odt = LocalDateTime.parse(value, DateTimeFormatter.ISO_DATE_TIME)
+ Some(odt)
+ }
+ catch {
+ case _: Exception => None
+ }
+ }
+
+}
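The `CosmosRowConverter.get` method above uses `TrieMap.putIfAbsent` so that concurrent callers racing to create a converter for the same serialization config all converge on a single cached instance. A minimal sketch of that get-or-create pattern (names here — `ConverterCache`, `compute` — are illustrative, not from the SDK):

```scala
import scala.collection.concurrent.TrieMap

// Sketch of the putIfAbsent-based get-or-create pattern used by
// CosmosRowConverter.get. All names here are illustrative.
object ConverterCache {
  private val map = new TrieMap[String, String]

  def getOrCreate(key: String, compute: String => String): String = {
    map.get(key) match {
      case Some(existing) => existing
      case None =>
        val candidate = compute(key)
        // putIfAbsent returns the previously mapped value if another
        // thread won the race, so every caller observes the same instance.
        map.putIfAbsent(key, candidate) match {
          case Some(existingConcurrentlyCreated) => existingConcurrentlyCreated
          case None => candidate
        }
    }
  }
}
```

The `compute` call may run more than once under contention, but only one result is ever published — acceptable here because creating a converter is idempotent.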
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala
new file mode 100644
index 000000000000..042c6ca5636e
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/CosmosWriter.scala
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.CosmosDiagnosticsContext
+import com.azure.cosmos.implementation.ImplementationBridgeHelpers
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.write.WriterCommitMessage
+import org.apache.spark.sql.execution.metric.CustomMetrics
+import org.apache.spark.sql.types.StructType
+
+import java.util.concurrent.atomic.AtomicLong
+
+private class CosmosWriter(
+ userConfig: Map[String, String],
+ cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
+ diagnosticsConfig: DiagnosticsConfig,
+ inputSchema: StructType,
+ partitionId: Int,
+ taskId: Long,
+ epochId: Option[Long],
+ sparkEnvironmentInfo: String)
+ extends CosmosWriterBase(
+ userConfig,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ inputSchema,
+ partitionId,
+ taskId,
+ epochId,
+ sparkEnvironmentInfo
+ ) with OutputMetricsPublisherTrait {
+
+ private val recordsWritten = new AtomicLong(0)
+ private val bytesWritten = new AtomicLong(0)
+ private val totalRequestCharge = new AtomicLong(0)
+
+ private val recordsWrittenMetric = new CustomTaskMetric {
+ override def name(): String = CosmosConstants.MetricNames.RecordsWritten
+ override def value(): Long = recordsWritten.get()
+ }
+
+ private val bytesWrittenMetric = new CustomTaskMetric {
+ override def name(): String = CosmosConstants.MetricNames.BytesWritten
+
+ override def value(): Long = bytesWritten.get()
+ }
+
+ private val totalRequestChargeMetric = new CustomTaskMetric {
+ override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge
+
+ // Internally we capture the request charge (RUs) with up to 2 fractional digits for more precise rounding
+ override def value(): Long = totalRequestCharge.get() / 100L
+ }
+
+ private val metrics = Array(recordsWrittenMetric, bytesWrittenMetric, totalRequestChargeMetric)
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ metrics
+ }
+
+ override def getOutputMetricsPublisher(): OutputMetricsPublisherTrait = this
+
+ override def trackWriteOperation(recordCount: Long, diagnostics: Option[CosmosDiagnosticsContext]): Unit = {
+ if (recordCount > 0) {
+ recordsWritten.addAndGet(recordCount)
+ }
+
+ diagnostics match {
+ case Some(ctx) =>
+ // Capturing the request charge (RUs) with 2 fractional digits internally
+ totalRequestCharge.addAndGet((ctx.getTotalRequestCharge * 100L).toLong)
+ bytesWritten.addAndGet(
+ if (ImplementationBridgeHelpers
+ .CosmosDiagnosticsContextHelper
+ .getCosmosDiagnosticsContextAccessor
+ .getOperationType(ctx)
+ .isReadOnlyOperation) {
+
+ ctx.getMaxRequestPayloadSizeInBytes + ctx.getMaxResponsePayloadSizeInBytes
+ } else {
+ ctx.getMaxRequestPayloadSizeInBytes
+ }
+ )
+ case None =>
+ }
+ }
+
+ override def commit(): WriterCommitMessage = {
+ val commitMessage = super.commit()
+
+ // TODO @fabianm - this is a workaround - it shouldn't be necessary to do this here
+ // Unfortunately WriteToDataSourceV2Exec.scala is not updating custom metrics after the
+ // call to commit - meaning DataSources which asynchronously write data and flush in commit
+ // won't get accurate metrics because updates between the last call to write and flushing the
+ // writes are lost. See https://issues.apache.org/jira/browse/SPARK-45759
+ // Once the above issue is addressed (probably in Spark 3.4.1 or 3.5) this needs to be changed.
+ //
+ // NOTE: This also means that the request charge metrics cannot be updated in commit - so the
+ // RU metric at the end of a task will be slightly outdated/behind
+ CustomMetrics.updateMetrics(
+ currentMetricsValues(),
+ SparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(CosmosConstants.MetricNames.KnownCustomMetricNames))
+
+ commitMessage
+ }
+}
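`CosmosWriter` accumulates request charges in an `AtomicLong` scaled by 100, preserving two fractional digits without the drift of accumulating doubles, then scales back down when the metric is read. A standalone sketch of that fixed-point technique (the class name here is illustrative):

```scala
import java.util.concurrent.atomic.AtomicLong

// Sketch of the fixed-point accumulation used for request charges:
// scale by 100 on ingest, divide by 100 on read, so two fractional
// digits survive without floating-point accumulation error.
final class ChargeAccumulator {
  private val scaledTotal = new AtomicLong(0)

  def add(requestCharge: Double): Unit =
    // e.g. 2.57 RUs is stored as 257
    scaledTotal.addAndGet((requestCharge * 100L).toLong)

  def totalWholeUnits: Long = scaledTotal.get() / 100L
}
```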
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala
new file mode 100644
index 000000000000..1e193b9e6959
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScan.scala
@@ -0,0 +1,41 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.models.PartitionKeyDefinition
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+
+private[spark] class ItemsScan(session: SparkSession,
+ schema: StructType,
+ config: Map[String, String],
+ readConfig: CosmosReadConfig,
+ analyzedFilters: AnalyzedAggregatedFilters,
+ cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
+ diagnosticsConfig: DiagnosticsConfig,
+ sparkEnvironmentInfo: String,
+ partitionKeyDefinition: PartitionKeyDefinition)
+ extends ItemsScanBase(
+ session,
+ schema,
+ config,
+ readConfig,
+ analyzedFilters,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo,
+ partitionKeyDefinition)
+ with SupportsRuntimeFiltering { // SupportsRuntimeFiltering extends Scan
+ override def filterAttributes(): Array[NamedReference] = {
+ runtimeFilterAttributesCore()
+ }
+
+ override def filter(filters: Array[Filter]): Unit = {
+ runtimeFilterCore(filters)
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala
new file mode 100644
index 000000000000..340a40585eb0
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsScanBuilder.scala
@@ -0,0 +1,137 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.SparkBridgeInternal
+import com.azure.cosmos.models.PartitionKeyDefinition
+import com.azure.cosmos.spark.diagnostics.LoggerHelper
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+// scalastyle:off underscore.import
+import scala.collection.JavaConverters._
+// scalastyle:on underscore.import
+
+private case class ItemsScanBuilder(session: SparkSession,
+ config: CaseInsensitiveStringMap,
+ inputSchema: StructType,
+ cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
+ diagnosticsConfig: DiagnosticsConfig,
+ sparkEnvironmentInfo: String)
+ extends ScanBuilder
+ with SupportsPushDownFilters
+ with SupportsPushDownRequiredColumns {
+
+ @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
+ log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")
+
+ private val configMap = config.asScala.toMap
+ private val readConfig = CosmosReadConfig.parseCosmosReadConfig(configMap)
+ private var processedPredicates : Option[AnalyzedAggregatedFilters] = Option.empty
+
+ private val clientConfiguration = CosmosClientConfiguration.apply(
+ configMap,
+ readConfig.readConsistencyStrategy,
+ CosmosClientConfiguration.getSparkEnvironmentInfo(Some(session))
+ )
+ private val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(configMap)
+ private val description = {
+ s"""Cosmos ItemsScanBuilder: ${containerConfig.database}.${containerConfig.container}""".stripMargin
+ }
+
+ private val partitionKeyDefinition: PartitionKeyDefinition = {
+ TransientErrorsRetryPolicy.executeWithRetry(() => {
+ val calledFrom = s"ItemsScan($description).getPartitionKeyDefinition"
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(CosmosClientCache.apply(
+ clientConfiguration,
+ Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
+ calledFrom
+ )),
+ ThroughputControlHelper.getThroughputControlClientCacheItem(
+ configMap, calledFrom, Some(cosmosClientStateHandles), sparkEnvironmentInfo)
+ ))
+ .to(clientCacheItems => {
+ val container =
+ ThroughputControlHelper.getContainer(
+ configMap,
+ containerConfig,
+ clientCacheItems(0).get,
+ clientCacheItems(1))
+
+ SparkBridgeInternal
+ .getContainerPropertiesFromCollectionCache(container)
+ .getPartitionKeyDefinition()
+ })
+ })
+ }
+
+ private val filterAnalyzer = FilterAnalyzer(readConfig, partitionKeyDefinition)
+
+ /**
+ * Pushes down filters, and returns the filters that need to be evaluated after scanning.
+ * @param filters pushed down filters.
+ * @return the filters that Spark needs to evaluate
+ */
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ this.processedPredicates = Option.apply(filterAnalyzer.analyze(filters))
+
+ // return the filters that Spark needs to evaluate
+ this.processedPredicates.get.filtersNotSupportedByCosmos
+ }
+
+ /**
+ * Returns the filters that are pushed to Cosmos as query predicates
+ * @return filters to be pushed to cosmos db.
+ */
+ override def pushedFilters: Array[Filter] = {
+ if (this.processedPredicates.isDefined) {
+ this.processedPredicates.get.filtersToBePushedDownToCosmos
+ } else {
+ Array[Filter]()
+ }
+ }
+
+ override def build(): Scan = {
+ val effectiveAnalyzedFilters = this.processedPredicates match {
+ case Some(analyzedFilters) => analyzedFilters
+ case None => filterAnalyzer.analyze(Array.empty[Filter])
+ }
+
+ // TODO when inferring schema we should consolidate the schema from pruneColumns
+ new ItemsScan(
+ session,
+ inputSchema,
+ this.configMap,
+ this.readConfig,
+ effectiveAnalyzedFilters,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo,
+ partitionKeyDefinition)
+ }
+
+ /**
+ * Applies column pruning w.r.t. the given requiredSchema.
+ *
+ * Implementation should try its best to prune the unnecessary columns or nested fields, but it's
+ * also OK to do the pruning partially, e.g., a data source may not be able to prune nested
+ * fields, and only prune top-level columns.
+ *
+ * Note that the `Scan` implementation should take care of the column
+ * pruning applied here.
+ */
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ // TODO: we need to decide whether to push down the projection or not
+ // Spark will do column pruning on the returned data.
+ // Pushing down the projection to Cosmos has tradeoffs:
+ // - it increases the RUs consumed by the Cosmos query engine
+ // - it decreases the networking-layer latency
+ }
+}
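The `pushFilters`/`pushedFilters` pair above implements the DataSource V2 contract: the builder partitions the incoming filters into those the source can serve and those Spark must still re-evaluate after the scan. A hypothetical analyzer illustrating that split (the `canPushDown` predicate stands in for the real per-filter capability check in `FilterAnalyzer`):

```scala
// Sketch of the pushFilters contract: partition filters into the set
// pushed to the data source and the remainder Spark evaluates itself.
// All names here are illustrative, not the SDK's actual types.
final case class AnalyzedFilters[F](pushedToSource: List[F], evaluatedBySpark: List[F])

object SimpleFilterAnalyzer {
  def analyze[F](filters: List[F], canPushDown: F => Boolean): AnalyzedFilters[F] = {
    val (pushed, remaining) = filters.partition(canPushDown)
    AnalyzedFilters(pushed, remaining)
  }
}
```

`pushFilters` would return `evaluatedBySpark` (Spark keeps evaluating those), while `pushedFilters` reports `pushedToSource` back to the planner.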
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala
new file mode 100644
index 000000000000..ea759335091b
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/ItemsWriterBuilder.scala
@@ -0,0 +1,185 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.{CosmosAsyncClient, ReadConsistencyStrategy, SparkBridgeInternal}
+import com.azure.cosmos.spark.diagnostics.LoggerHelper
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
+import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.write.streaming.StreamingWrite
+import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+// scalastyle:off underscore.import
+import scala.collection.JavaConverters._
+// scalastyle:on underscore.import
+
+private class ItemsWriterBuilder
+(
+ userConfig: CaseInsensitiveStringMap,
+ inputSchema: StructType,
+ cosmosClientStateHandles: Broadcast[CosmosClientMetadataCachesSnapshots],
+ diagnosticsConfig: DiagnosticsConfig,
+ sparkEnvironmentInfo: String
+)
+ extends WriteBuilder {
+ @transient private lazy val log = LoggerHelper.getLogger(diagnosticsConfig, this.getClass)
+ log.logTrace(s"Instantiated ${this.getClass.getSimpleName}")
+
+ override def build(): Write = {
+ new CosmosWrite
+ }
+
+ override def buildForBatch(): BatchWrite =
+ new ItemsBatchWriter(
+ userConfig.asCaseSensitiveMap().asScala.toMap,
+ inputSchema,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo)
+
+ override def buildForStreaming(): StreamingWrite =
+ new ItemsBatchWriter(
+ userConfig.asCaseSensitiveMap().asScala.toMap,
+ inputSchema,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo)
+
+ private class CosmosWrite extends Write with RequiresDistributionAndOrdering {
+
+ private[this] val supportedCosmosMetrics: Array[CustomMetric] = {
+ Array(
+ new CosmosBytesWrittenMetric(),
+ new CosmosRecordsWrittenMetric(),
+ new TotalRequestChargeMetric()
+ )
+ }
+
+ // Extract userConfig conversion to avoid repeated calls
+ private[this] val userConfigMap = userConfig.asCaseSensitiveMap().asScala.toMap
+
+ private[this] val writeConfig = CosmosWriteConfig.parseWriteConfig(
+ userConfigMap,
+ inputSchema
+ )
+
+ private[this] val containerConfig = CosmosContainerConfig.parseCosmosContainerConfig(
+ userConfigMap
+ )
+
+ override def toBatch(): BatchWrite =
+ new ItemsBatchWriter(
+ userConfigMap,
+ inputSchema,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo)
+
+ override def toStreaming: StreamingWrite =
+ new ItemsBatchWriter(
+ userConfigMap,
+ inputSchema,
+ cosmosClientStateHandles,
+ diagnosticsConfig,
+ sparkEnvironmentInfo)
+
+ override def supportedCustomMetrics(): Array[CustomMetric] = supportedCosmosMetrics
+
+ override def requiredDistribution(): Distribution = {
+ if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) {
+ log.logInfo("Transactional batch mode enabled - configuring data distribution by partition key columns")
+ // For transactional writes, partition by all partition key columns
+ val partitionKeyPaths = getPartitionKeyColumnNames()
+ if (partitionKeyPaths.nonEmpty) {
+ // Use public Expressions.column() factory - returns NamedReference
+ val clustering = partitionKeyPaths.map(path => Expressions.column(path): Expression).toArray
+ Distributions.clustered(clustering)
+ } else {
+ Distributions.unspecified()
+ }
+ } else {
+ Distributions.unspecified()
+ }
+ }
+
+ override def requiredOrdering(): Array[SortOrder] = {
+ if (writeConfig.bulkEnabled && writeConfig.bulkTransactional) {
+ // For transactional writes, order by all partition key columns (ascending)
+ val partitionKeyPaths = getPartitionKeyColumnNames()
+ if (partitionKeyPaths.nonEmpty) {
+ partitionKeyPaths.map { path =>
+ // Use public Expressions.sort() factory for creating SortOrder
+ Expressions.sort(
+ Expressions.column(path),
+ SortDirection.ASCENDING,
+ NullOrdering.NULLS_FIRST
+ )
+ }.toArray
+ } else {
+ Array.empty[SortOrder]
+ }
+ } else {
+ Array.empty[SortOrder]
+ }
+ }
+
+ private def getPartitionKeyColumnNames(): Seq[String] = {
+ try {
+ Loan(
+ List[Option[CosmosClientCacheItem]](
+ Some(createClientForPartitionKeyLookup())
+ ))
+ .to(clientCacheItems => {
+ val container = ThroughputControlHelper.getContainer(
+ userConfigMap,
+ containerConfig,
+ clientCacheItems(0).get,
+ None
+ )
+
+ // Simplified retrieval using SparkBridgeInternal directly
+ val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container)
+ val partitionKeyDefinition = containerProperties.getPartitionKeyDefinition
+
+ extractPartitionKeyPaths(partitionKeyDefinition)
+ })
+ } catch {
+ case ex: Exception =>
+ log.logWarning(s"Failed to get partition key definition for transactional writes: ${ex.getMessage}")
+ Seq.empty[String]
+ }
+ }
+
+ private def createClientForPartitionKeyLookup(): CosmosClientCacheItem = {
+ CosmosClientCache(
+ CosmosClientConfiguration(
+ userConfigMap,
+ ReadConsistencyStrategy.EVENTUAL,
+ sparkEnvironmentInfo
+ ),
+ Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
+ "ItemsWriterBuilder-PKLookup"
+ )
+ }
+
+ private def extractPartitionKeyPaths(partitionKeyDefinition: com.azure.cosmos.models.PartitionKeyDefinition): Seq[String] = {
+ if (partitionKeyDefinition != null && partitionKeyDefinition.getPaths != null) {
+ val paths = partitionKeyDefinition.getPaths.asScala
+ if (paths.isEmpty) {
+ log.logError("Partition key definition has 0 columns - this should not happen for modern containers")
+ }
+ paths.map(path => {
+ // Remove leading '/' from partition key path (e.g., "/pk" -> "pk")
+ if (path.startsWith("/")) path.substring(1) else path
+ }).toSeq
+ } else {
+ log.logError("Partition key definition is null - this should not happen for modern containers")
+ Seq.empty[String]
+ }
+ }
+ }
+}
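`extractPartitionKeyPaths` strips the leading `/` from Cosmos partition key paths to obtain the Spark column names used for clustering and ordering. A tiny standalone sketch of that normalization (the wrapper object is illustrative):

```scala
// Sketch of the path normalization in extractPartitionKeyPaths:
// Cosmos partition key paths carry a leading '/', which is removed
// to map them onto Spark column names. Illustrative helper, not SDK API.
object PartitionKeyPaths {
  def toColumnName(pkPath: String): String =
    if (pkPath.startsWith("/")) pkPath.substring(1) else pkPath
}
```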
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala
new file mode 100644
index 000000000000..427b8757e3e5
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/RowSerializerPool.scala
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Spark serializers are not thread-safe and are expensive to create (dynamic code generation),
+ * so this object pool allows reusing serializers based on the targeted schema.
+ * The main purpose of pooling serializers (vs. creating new ones in each PartitionReader) is
+ * Structured Streaming scenarios, where PartitionReaders for the same schema can be created
+ * every few hundred milliseconds.
+ * A clean-up task purges serializers for schemas that are no longer used.
+ * For each schema there is an object pool with a soft limit to bound the memory footprint.
+ */
+private object RowSerializerPool {
+ private val serializerFactorySingletonInstance =
+ new RowSerializerPoolInstance((schema: StructType) => ExpressionEncoder.apply(schema).createSerializer())
+
+ def getOrCreateSerializer(schema: StructType): ExpressionEncoder.Serializer[Row] = {
+ serializerFactorySingletonInstance.getOrCreateSerializer(schema)
+ }
+
+ def returnSerializerToPool(schema: StructType, serializer: ExpressionEncoder.Serializer[Row]): Boolean = {
+ serializerFactorySingletonInstance.returnSerializerToPool(schema, serializer)
+ }
+}
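The soft limit described in the scaladoc above bounds how many idle serializers are retained per schema: returning an object to a full pool simply fails, and the object is left for garbage collection. A generic sketch of that idea, assuming a simple queue-backed pool (names are illustrative):

```scala
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger

// Sketch of a soft-limited object pool: giveBack declines (returns false)
// once roughly `softLimit` entries are held, bounding memory footprint.
// Illustrative only - not the SDK's RowSerializerPoolInstance.
final class SoftLimitedPool[T](softLimit: Int, create: () => T) {
  private val pool = new ConcurrentLinkedQueue[T]
  private val approxSize = new AtomicInteger(0)

  def borrow(): T = {
    val pooled = pool.poll()
    if (pooled != null) { approxSize.decrementAndGet(); pooled }
    else create() // pool empty - fall back to creating a fresh instance
  }

  def giveBack(item: T): Boolean = {
    if (approxSize.get() >= softLimit) false // soft limit reached, drop it
    else { approxSize.incrementAndGet(); pool.offer(item); true }
  }
}
```

The size check is racy by design (a "soft" limit): occasionally exceeding the bound briefly is cheaper than serializing all callers through a lock.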
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala
new file mode 100644
index 000000000000..45d7bacef995
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala
@@ -0,0 +1,107 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
+import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.util.AccumulatorV2
+
+import java.lang.reflect.Method
+import java.util.Locale
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
+
+class SparkInternalsBridge {
+ // Only used in ChangeFeedMetricsListener, where an instance method is easier for test validation
+ def getInternalCustomTaskMetricsAsSQLMetric(
+ knownCosmosMetricNames: Set[String],
+ taskMetrics: TaskMetrics) : Map[String, SQLMetric] = {
+ SparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics)
+ }
+}
+
+object SparkInternalsBridge extends BasicLoggingTrait {
+ private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED"
+ private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED"
+
+ private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true
+ private val accumulatorsMethod : AtomicReference[Method] = new AtomicReference[Method]()
+
+ private def getSparkReflectionAccessAllowed: Boolean = {
+ val allowedText = System.getProperty(
+ SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY,
+ firstNonNull(
+ emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)),
+ String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED)))
+
+ try {
+ java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT))
+ }
+ catch {
+ case e: Exception =>
+ logError(s"Failed to parse the spark reflection access allowed value '$allowedText'. Using the default $DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED.", e)
+ DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED
+ }
+ }
+
+ private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)
+
+ def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String]) : Map[String, SQLMetric] = {
+ Option.apply(TaskContext.get()) match {
+ case Some(taskCtx) => getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames, taskCtx.taskMetrics())
+ case None => Map.empty[String, SQLMetric]
+ }
+ }
+
+ def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String], taskMetrics: TaskMetrics) : Map[String, SQLMetric] = {
+
+ if (!reflectionAccessAllowed.get) {
+ Map.empty[String, SQLMetric]
+ } else {
+ getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskMetrics)
+ }
+ }
+
+ private def getAccumulators(taskMetrics: TaskMetrics): Option[Seq[AccumulatorV2[_, _]]] = {
+ try {
+ val method = Option(accumulatorsMethod.get) match {
+ case Some(existing) => existing
+ case None =>
+ val newMethod = taskMetrics.getClass.getMethod("accumulators")
+ newMethod.setAccessible(true)
+ accumulatorsMethod.set(newMethod)
+ newMethod
+ }
+
+ val accums = method.invoke(taskMetrics).asInstanceOf[Seq[AccumulatorV2[_, _]]]
+
+ Some(accums)
+ } catch {
+ case e: Exception =>
+ logInfo(s"Could not invoke getAccumulators via reflection - Error ${e.getMessage}", e)
+
+ // reflection failed - disabling it for the future
+ reflectionAccessAllowed.set(false)
+ None
+ }
+ }
+
+ private def getInternalCustomTaskMetricsAsSQLMetricInternal(
+ knownCosmosMetricNames: Set[String],
+ taskMetrics: TaskMetrics): Map[String, SQLMetric] = {
+ getAccumulators(taskMetrics) match {
+ case Some(accumulators) => accumulators
+ .filter(accumulable => accumulable.isInstanceOf[SQLMetric]
+ && accumulable.name.isDefined
+ && knownCosmosMetricNames.contains(accumulable.name.get))
+ .map(accumulable => {
+ val sqlMetric = accumulable.asInstanceOf[SQLMetric]
+ sqlMetric.name.get -> sqlMetric
+ })
+ .toMap[String, SQLMetric]
+ case None => Map.empty[String, SQLMetric]
+ }
+ }
+}
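The `getAccumulators` helper above combines two defensive techniques: it caches the resolved `java.lang.reflect.Method` handle so reflection lookup happens only once, and it flips an `AtomicBoolean` kill switch on the first failure so all future calls short-circuit. A standalone sketch of that pattern (with hypothetical names, resolving `String.length` purely for illustration):

```scala
import java.lang.reflect.Method
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

// Sketch of the cached-reflection pattern: resolve a Method once, reuse the
// cached handle on subsequent calls, and permanently disable reflection on
// the first failure instead of retrying (and failing) on every invocation.
object ReflectionSketch {
  private val cachedMethod = new AtomicReference[Method]()
  private val reflectionAllowed = new AtomicBoolean(true)

  def invokeLength(target: AnyRef): Option[Int] = {
    if (!reflectionAllowed.get) {
      None
    } else {
      try {
        val method = Option(cachedMethod.get) match {
          case Some(existing) => existing
          case None =>
            val resolved = target.getClass.getMethod("length")
            resolved.setAccessible(true)
            cachedMethod.set(resolved)
            resolved
        }
        Some(method.invoke(target).asInstanceOf[Int])
      } catch {
        case _: Exception =>
          // reflection failed - disable it for all future calls
          reflectionAllowed.set(false)
          None
      }
    }
  }
}
```

The kill switch matters in a Spark context because the reflective call runs per task; without it, an incompatible runtime would pay the exception cost on every task completion.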
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala
new file mode 100644
index 000000000000..56d1f0ba2b78
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/main/scala/com/azure/cosmos/spark/TotalRequestChargeMetric.scala
@@ -0,0 +1,11 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.connector.metric.CustomSumMetric
+
+private[cosmos] class TotalRequestChargeMetric extends CustomSumMetric {
+ override def name(): String = CosmosConstants.MetricNames.TotalRequestCharge
+
+ override def description(): String = CosmosConstants.MetricNames.TotalRequestCharge
+}
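`CustomSumMetric` frees the connector from writing its own aggregation: Spark collects the per-task values and the metric class reduces them to a display string. A minimal dependency-free sketch of the summing aggregation that `CustomSumMetric.aggregateTaskMetrics` performs (the object name here is hypothetical):

```scala
// Sketch of what a "sum" custom metric does with the values reported by
// individual tasks: aggregate them by simple summation and render the total.
object SumMetricSketch {
  def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}
```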
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor
new file mode 100644
index 000000000000..0d43a5bfc657
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientBuilderInterceptor
@@ -0,0 +1 @@
+com.azure.cosmos.spark.TestCosmosClientBuilderInterceptor
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor
new file mode 100644
index 000000000000..e2239720776d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.CosmosClientInterceptor
@@ -0,0 +1 @@
+com.azure.cosmos.spark.TestFaultInjectionClientInterceptor
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor
new file mode 100644
index 000000000000..c60cbf2f14e4
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/resources/META-INF/services/com.azure.cosmos.spark.WriteOnRetryCommitInterceptor
@@ -0,0 +1 @@
+com.azure.cosmos.spark.TestWriteOnRetryCommitInterceptor
\ No newline at end of file
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala
new file mode 100644
index 000000000000..6b9de815ea9c
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ChangeFeedMetricsListenerITest.scala
@@ -0,0 +1,157 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+// scalastyle:off magic.number
+// scalastyle:off multiple.string.literals
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.changeFeedMetrics.{ChangeFeedMetricsListener, ChangeFeedMetricsTracker}
+import com.azure.cosmos.implementation.guava25.collect.{HashBiMap, Maps}
+import org.apache.spark.Success
+import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics}
+import org.apache.spark.scheduler.{SparkListenerTaskEnd, TaskInfo}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.mockito.ArgumentMatchers
+import org.mockito.Mockito.{mock, when}
+
+import java.lang.reflect.Field
+import java.util.concurrent.ConcurrentHashMap
+
+class ChangeFeedMetricsListenerITest extends IntegrationSpec with SparkWithJustDropwizardAndNoSlf4jMetrics {
+ "ChangeFeedMetricsListener" should "be able to capture changeFeed performance metrics" in {
+ val taskEnd = SparkListenerTaskEnd(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskType = "ResultTask",
+ reason = Success,
+ taskInfo = mock(classOf[TaskInfo]),
+ taskExecutorMetrics = mock(classOf[ExecutorMetrics]),
+ taskMetrics = mock(classOf[TaskMetrics])
+ )
+
+ val indexMetric = SQLMetrics.createMetric(spark.sparkContext, "index")
+ indexMetric.set(1)
+ val lsnMetric = SQLMetrics.createMetric(spark.sparkContext, "lsn")
+ lsnMetric.set(100)
+ val itemsMetric = SQLMetrics.createMetric(spark.sparkContext, "items")
+ itemsMetric.set(100)
+
+ val metrics = Map[String, SQLMetric](
+ CosmosConstants.MetricNames.ChangeFeedPartitionIndex -> indexMetric,
+ CosmosConstants.MetricNames.ChangeFeedLsnRange -> lsnMetric,
+ CosmosConstants.MetricNames.ChangeFeedItemsCnt -> itemsMetric
+ )
+
+ // create sparkInternalsBridge mock
+ val sparkInternalsBridge = mock(classOf[SparkInternalsBridge])
+ when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(
+ ArgumentMatchers.any[Set[String]],
+ ArgumentMatchers.any[TaskMetrics]
+ )).thenReturn(metrics)
+
+ val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]())
+ partitionIndexMap.put(NormalizedRange("0", "FF"), 1)
+
+ val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]()
+ val changeFeedMetricsListener = new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap)
+
+ // inject the sparkInternalsBridge mock via reflection
+ val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge")
+ sparkInternalsBridgeField.setAccessible(true)
+ sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge)
+
+ // verify that metrics will be properly tracked
+ changeFeedMetricsListener.onTaskEnd(taskEnd)
+ partitionMetricsMap.size() shouldBe 1
+ partitionMetricsMap.containsKey(NormalizedRange("0", "FF")) shouldBe true
+ partitionMetricsMap.get(NormalizedRange("0", "FF")).getWeightedChangeFeedItemsPerLsn.get shouldBe 1
+ }
+
+ it should "ignore metrics for unknown partition index" in {
+ val taskEnd = SparkListenerTaskEnd(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskType = "ResultTask",
+ reason = Success,
+ taskInfo = mock(classOf[TaskInfo]),
+ taskExecutorMetrics = mock(classOf[ExecutorMetrics]),
+ taskMetrics = mock(classOf[TaskMetrics])
+ )
+
+ val indexMetric2 = SQLMetrics.createMetric(spark.sparkContext, "index")
+ indexMetric2.set(10)
+ val lsnMetric2 = SQLMetrics.createMetric(spark.sparkContext, "lsn")
+ lsnMetric2.set(100)
+ val itemsMetric2 = SQLMetrics.createMetric(spark.sparkContext, "items")
+ itemsMetric2.set(100)
+
+ val metrics = Map[String, SQLMetric](
+ CosmosConstants.MetricNames.ChangeFeedPartitionIndex -> indexMetric2,
+ CosmosConstants.MetricNames.ChangeFeedLsnRange -> lsnMetric2,
+ CosmosConstants.MetricNames.ChangeFeedItemsCnt -> itemsMetric2
+ )
+
+ // create sparkInternalsBridge mock
+ val sparkInternalsBridge = mock(classOf[SparkInternalsBridge])
+ when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(
+ ArgumentMatchers.any[Set[String]],
+ ArgumentMatchers.any[TaskMetrics]
+ )).thenReturn(metrics)
+
+ val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]())
+ partitionIndexMap.put(NormalizedRange("0", "FF"), 1)
+
+ val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]()
+ val changeFeedMetricsListener = new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap)
+
+ // inject the sparkInternalsBridge mock via reflection
+ val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge")
+ sparkInternalsBridgeField.setAccessible(true)
+ sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge)
+
+ // because partition index 10 does not exist in the partitionIndexMap, it will be ignored
+ changeFeedMetricsListener.onTaskEnd(taskEnd)
+ partitionMetricsMap shouldBe empty
+ }
+
+ it should "ignore unrelated metrics" in {
+ val taskEnd = SparkListenerTaskEnd(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskType = "ResultTask",
+ reason = Success,
+ taskInfo = mock(classOf[TaskInfo]),
+ taskExecutorMetrics = mock(classOf[ExecutorMetrics]),
+ taskMetrics = mock(classOf[TaskMetrics])
+ )
+
+ val unknownMetric3 = SQLMetrics.createMetric(spark.sparkContext, "unknown")
+ unknownMetric3.set(10)
+
+ val metrics = Map[String, SQLMetric](
+ "unknownMetrics" -> unknownMetric3
+ )
+
+ // create sparkInternalsBridge mock
+ val sparkInternalsBridge = mock(classOf[SparkInternalsBridge])
+ when(sparkInternalsBridge.getInternalCustomTaskMetricsAsSQLMetric(
+ ArgumentMatchers.any[Set[String]],
+ ArgumentMatchers.any[TaskMetrics]
+ )).thenReturn(metrics)
+
+ val partitionIndexMap = Maps.synchronizedBiMap(HashBiMap.create[NormalizedRange, Long]())
+ partitionIndexMap.put(NormalizedRange("0", "FF"), 1)
+
+ val partitionMetricsMap = new ConcurrentHashMap[NormalizedRange, ChangeFeedMetricsTracker]()
+ val changeFeedMetricsListener = new ChangeFeedMetricsListener(partitionIndexMap, partitionMetricsMap)
+
+ // inject the sparkInternalsBridge mock via reflection
+ val sparkInternalsBridgeField: Field = classOf[ChangeFeedMetricsListener].getDeclaredField("sparkInternalsBridge")
+ sparkInternalsBridgeField.setAccessible(true)
+ sparkInternalsBridgeField.set(changeFeedMetricsListener, sparkInternalsBridge)
+
+ // because "unknownMetrics" is not a known change feed metric name, it will be ignored
+ changeFeedMetricsListener.onTaskEnd(taskEnd)
+ partitionMetricsMap shouldBe empty
+ }
+}
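The "ignore unrelated metrics" case above boils down to name-based filtering: only accumulators whose names appear in the known Cosmos metric set are kept. A standalone sketch of that filter (plain `Map` stand-in for the `SQLMetric` accumulators; the metric names shown are hypothetical):

```scala
// Sketch of the name-based filtering the listener tests exercise: keep only
// the entries whose names are in the known Cosmos metric set, drop the rest.
object MetricFilterSketch {
  def filterKnown(metrics: Map[String, Long], knownNames: Set[String]): Map[String, Long] =
    metrics.filter { case (name, _) => knownNames.contains(name) }
}
```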
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala
new file mode 100644
index 000000000000..c9fc02a6482d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITest.scala
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import org.apache.commons.lang3.RandomStringUtils
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
+
+class CosmosCatalogITest
+ extends CosmosCatalogITestBase(skipHive = true) {
+
+ //scalastyle:off magic.number
+
+ // TODO: Spark on Windows has an issue with this test:
+ // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive;
+ // once we move to Linux CI, re-enable the test.
+ it can "drop an empty database" in {
+ assume(!Platform.isWindows)
+
+ for (cascade <- Array(true, false)) {
+ val databaseName = getAutoCleanableDatabaseName
+ spark.catalog.databaseExists(databaseName) shouldEqual false
+
+ createDatabase(spark, databaseName)
+ databaseExists(databaseName) shouldEqual true
+
+ dropDatabase(spark, databaseName, cascade)
+ spark.catalog.databaseExists(databaseName) shouldEqual false
+ }
+ }
+
+ // TODO: Spark on Windows has an issue with this test:
+ // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive;
+ // once we move to Linux CI, re-enable the test.
+ it can "drop a non-empty database with cascade true" in {
+ assume(!Platform.isWindows)
+
+ val databaseName = getAutoCleanableDatabaseName
+ spark.catalog.databaseExists(databaseName) shouldEqual false
+
+ createDatabase(spark, databaseName)
+ databaseExists(databaseName) shouldEqual true
+
+ val containerName = RandomStringUtils.randomAlphabetic(5)
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ dropDatabase(spark, databaseName, true)
+ spark.catalog.databaseExists(databaseName) shouldEqual false
+ }
+
+ // TODO: Spark on Windows has an issue with this test:
+ // java.lang.RuntimeException: java.io.IOException: (null) entry in command string: null chmod 0733 D:\tmp\hive;
+ // once we move to Linux CI, re-enable the test.
+ "drop a non-empty database with cascade false" should "throw NonEmptyNamespaceException" in {
+ assume(!Platform.isWindows)
+
+ try {
+ val databaseName = getAutoCleanableDatabaseName
+ spark.catalog.databaseExists(databaseName) shouldEqual false
+
+ createDatabase(spark, databaseName)
+ databaseExists(databaseName) shouldEqual true
+
+ val containerName = RandomStringUtils.randomAlphabetic(5)
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ dropDatabase(spark, databaseName, false)
+ fail("Expected NonEmptyNamespaceException was not thrown")
+ }
+ catch {
+ case expectedError: NonEmptyNamespaceException => {
+ logInfo(s"Expected NonEmptyNamespaceException: $expectedError")
+ succeed
+ }
+ }
+ }
+
+ it can "list all databases" in {
+ val databaseName1 = getAutoCleanableDatabaseName
+ val databaseName2 = getAutoCleanableDatabaseName
+
+ // creating those databases ahead of time
+ cosmosClient.createDatabase(databaseName1).block()
+ cosmosClient.createDatabase(databaseName2).block()
+
+ val databases = spark.sql("SHOW DATABASES IN testCatalog").collect()
+ databases.size should be >= 2
+ // validate that the returned list contains both databases created above
+ databases
+ .filter(
+ row => row.getAs[String]("namespace").equals(databaseName1)
+ || row.getAs[String]("namespace").equals(databaseName2)) should have size 2
+ }
+
+ private def dropDatabase(spark: SparkSession, databaseName: String, cascade: Boolean) = {
+ if (cascade) {
+ spark.sql(s"DROP DATABASE testCatalog.$databaseName CASCADE;")
+ } else {
+ spark.sql(s"DROP DATABASE testCatalog.$databaseName;")
+ }
+ }
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala
new file mode 100644
index 000000000000..f10fddac7138
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosCatalogITestBase.scala
@@ -0,0 +1,973 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.CosmosException
+import com.azure.cosmos.implementation.{TestConfigurations, Utils}
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import org.apache.commons.lang3.RandomStringUtils
+import org.apache.spark.sql.execution.streaming.checkpointing.HDFSMetadataLog
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+import java.util.UUID
+// scalastyle:off underscore.import
+import scala.collection.JavaConverters._
+// scalastyle:on underscore.import
+
+abstract class CosmosCatalogITestBase(val skipHive: Boolean = false) extends IntegrationSpec with CosmosClient with BasicLoggingTrait {
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off magic.number
+
+ var spark : SparkSession = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ var sparkBuilder = SparkSession.builder()
+ .appName("spark connector sample")
+ .master("local")
+
+ if (!skipHive) {
+ sparkBuilder = sparkBuilder.enableHiveSupport()
+ }
+
+ spark = sparkBuilder.getOrCreate()
+
+ LocalJavaFileSystem.applyToSparkSession(spark)
+
+ spark.conf.set(s"spark.sql.catalog.testCatalog", "com.azure.cosmos.spark.CosmosCatalog")
+ spark.conf.set(s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint", cosmosEndpoint)
+ spark.conf.set(s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey", cosmosMasterKey)
+ spark.conf.set(
+ "spark.sql.catalog.testCatalog.spark.cosmos.views.repositoryPath",
+ s"/viewRepository/${UUID.randomUUID().toString}")
+ spark.conf.set(
+ "spark.sql.catalog.testCatalog.spark.cosmos.read.partitioning.strategy",
+ "Restrictive")
+ }
+
+ override def afterAll(): Unit = {
+ try spark.close()
+ finally super.afterAll()
+ }
+
+ it can "create a database with shared throughput" in {
+ val databaseName = getAutoCleanableDatabaseName
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');")
+
+ cosmosClient.getDatabase(databaseName).read().block()
+ val throughput = cosmosClient.getDatabase(databaseName).readThroughput().block()
+
+ throughput.getProperties.getManualThroughput shouldEqual 1000
+ }
+
+ it can "create a table with customized properties and hierarchical partition keys, without partition kind and version" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', manualThroughput = '1100')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/tenantId", "/userId", "/sessionId"))
+ // scalastyle:off null
+ containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null
+ // scalastyle:on null
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+ }
+
+ it can "create a table with customized properties and hierarchical partition keys, with correct partition kind" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', partitionKeyVersion = 'V2', partitionKeyKind = 'MultiHash', manualThroughput = '1100')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/tenantId", "/userId", "/sessionId"))
+ // scalastyle:off null
+ containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null
+ // scalastyle:on null
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+ }
+
+ it can "create a table with customized properties and hierarchical partition keys, with wrong partition kind" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ try {
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/tenantId,/userId,/sessionId', partitionKeyVersion = 'V1', partitionKeyKind = 'Hash', manualThroughput = '1100')")
+ fail("Expected IllegalArgumentException was not thrown")
+ }
+ catch
+ {
+ case expectedError: IllegalArgumentException =>
+ logInfo(s"Expected IllegalArgumentException: $expectedError")
+ succeed // expected error
+ }
+
+ }
+
+ it can "create a database with shared throughput and alter throughput afterwards" in {
+ val databaseName = getAutoCleanableDatabaseName
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');")
+
+ cosmosClient.getDatabase(databaseName).read().block()
+ var throughput = cosmosClient.getDatabase(databaseName).readThroughput().block()
+
+ throughput.getProperties.getManualThroughput shouldEqual 1000
+
+ spark.sql(s"ALTER DATABASE testCatalog.$databaseName SET DBPROPERTIES ('manualThroughput' = '4000');")
+
+ cosmosClient.getDatabase(databaseName).read().block()
+ throughput = cosmosClient.getDatabase(databaseName).readThroughput().block()
+
+ throughput.getProperties.getManualThroughput shouldEqual 4000
+ }
+
+ it can "create a table with defaults" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null"
+ tblProperties("CosmosPartitionCount") shouldEqual "1"
+ tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
+ tblProperties("DefaultTtlInSeconds") shouldEqual "null"
+ tblProperties("VectorEmbeddingPolicy") shouldEqual "null"
+ tblProperties("IndexingPolicy") shouldEqual
+ "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," +
+ "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}"
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("ProvisionedThroughput").startsWith("Manual|400|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 31
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+ }
+
+ it can "create a table and alter throughput afterwards" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ var throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ var tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("ProvisionedThroughput").startsWith("Manual|400|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 31
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+
+ spark.sql(s"ALTER TABLE testCatalog.$databaseName.$containerName SET TBLPROPERTIES ('manualThroughput' = '4000');")
+
+ // validate throughput
+ throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 4000
+
+ tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("ProvisionedThroughput").startsWith("Manual|4000|") shouldEqual true
+ }
+
+ it can "create a table with shared throughput and Hash V2" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('manualThroughput' = '1000');")
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ // TODO @fabianm Emulator doesn't seem to support analytical store - needs to be tested separately
+ // s"TBLPROPERTIES(partitionKeyVersion = 'V2', analyticalStoreTtlInSeconds = '3000000')")
+ s"TBLPROPERTIES(partitionKeyVersion = 'V2')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ try {
+ // validate that container uses shared database throughput as default
+ cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+
+ fail("Expected CosmosException not thrown")
+ }
+ catch {
+ case expectedError: CosmosException =>
+ expectedError.getStatusCode shouldEqual 400
+ logInfo(s"Expected CosmosException: $expectedError")
+ }
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ // tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "3000000"
+ tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null"
+ tblProperties("CosmosPartitionCount") shouldEqual "1"
+ tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\",\"version\":2}"
+ tblProperties("DefaultTtlInSeconds") shouldEqual "null"
+ tblProperties("VectorEmbeddingPolicy") shouldEqual "null"
+ tblProperties("IndexingPolicy") shouldEqual
+ "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," +
+ "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}"
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ logInfo(s"ProvisionedThroughput: ${tblProperties("ProvisionedThroughput")}")
+ tblProperties("ProvisionedThroughput").startsWith("Shared.Manual|1000|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 39
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+ }
+
+ it can "create a table with defaults but shared autoscale throughput" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName WITH DBPROPERTIES ('autoScaleMaxThroughput' = '16000');")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ try {
+ // validate that container uses shared database throughput as default
+ cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+
+ fail("Expected CosmosException not thrown")
+ }
+ catch {
+ case expectedError: CosmosException =>
+ expectedError.getStatusCode shouldEqual 400
+ logInfo(s"Expected CosmosException: $expectedError")
+ }
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null"
+ tblProperties("CosmosPartitionCount") shouldEqual "2"
+ tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
+ tblProperties("DefaultTtlInSeconds") shouldEqual "null"
+ tblProperties("VectorEmbeddingPolicy") shouldEqual "null"
+ tblProperties("IndexingPolicy") shouldEqual
+ "{\"indexingMode\":\"consistent\",\"automatic\":true,\"includedPaths\":[{\"path\":\"/*\"}]," +
+ "\"excludedPaths\":[{\"path\":\"/\\\"_etag\\\"/?\"}]}"
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ logInfo(s"ProvisionedThroughput: ${tblProperties("ProvisionedThroughput")}")
+ tblProperties("ProvisionedThroughput").startsWith("Shared.AutoScale|1600|16000|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 48
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+ }
+
+ it can "create a table with customized properties" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ // scalastyle:off null
+ containerProperties.getDefaultTimeToLiveInSeconds shouldEqual null
+ // scalastyle:on null
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+ }
+
+ it can "create a table with well known indexing policy 'AllProperties'" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = 'AllProperties')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ containerProperties
+ .getIndexingPolicy
+ .getIncludedPaths
+ .asScala
+ .map(p => p.getPath)
+ .toArray should equal(Array("/*"))
+ containerProperties
+ .getIndexingPolicy
+ .getExcludedPaths
+ .asScala
+ .map(p => p.getPath)
+ .toArray should equal(Array(raw"""/"_etag"/?"""))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+ }
+
+ it can "create a table with well known indexing policy 'OnlySystemProperties'" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = 'ONLYSystemproperties')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ containerProperties
+ .getIndexingPolicy
+ .getIncludedPaths
+ .asScala.map(p => p.getPath)
+ .toArray.length shouldEqual 0
+ containerProperties
+ .getIndexingPolicy
+ .getExcludedPaths
+ .asScala
+ .map(p => p.getPath)
+ .toArray should equal(Array("/*", raw"""/"_etag"/?"""))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+ }
+
+ it can "create a table with custom indexing policy" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ val indexPolicyJson = raw"""{"indexingMode":"consistent","automatic":true,"includedPaths":""" +
+ raw"""[{"path":"\/helloWorld\/?"},{"path":"\/mypk\/?"}],"excludedPaths":[{"path":"\/*"}]}"""
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', indexingPolicy = '$indexPolicyJson')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ containerProperties
+ .getIndexingPolicy
+ .getIncludedPaths
+ .asScala
+ .map(p => p.getPath)
+ .toArray should equal(Array("/helloWorld/?", "/mypk/?"))
+ containerProperties
+ .getIndexingPolicy
+ .getExcludedPaths
+ .asScala
+ .map(p => p.getPath)
+ .toArray should equal(Array("/*", raw"""/"_etag"/?"""))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null"
+ tblProperties("CosmosPartitionCount") shouldEqual "1"
+ tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/mypk\"],\"kind\":\"Hash\"}"
+ tblProperties("DefaultTtlInSeconds") shouldEqual "null"
+ tblProperties("VectorEmbeddingPolicy") shouldEqual "null"
+
+ // indexPolicyJson will be normalized by the backend, so it will not be the same as the input json.
+ // For the purpose of this test we only need to make sure that the custom indexing options
+ // are included; correctness of the json serialization of the indexing policy is tested elsewhere.
+ tblProperties("IndexingPolicy").contains("helloWorld") shouldEqual true
+ tblProperties("IndexingPolicy").contains("mypk") shouldEqual true
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("ProvisionedThroughput").startsWith("Manual|1100|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 32
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+ }
+
+ it can "create a table with TTL -1" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', defaultTtlInSeconds = '-1')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ containerProperties.getDefaultTimeToLiveInSeconds shouldEqual -1
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+ tblProperties("DefaultTtlInSeconds") shouldEqual "-1"
+ }
+
+ it can "create a table with positive TTL" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', defaultTtlInSeconds = '5')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+ containerProperties.getDefaultTimeToLiveInSeconds shouldEqual 5
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+ tblProperties("DefaultTtlInSeconds") shouldEqual "5"
+ }
+
+ it can "create a table with vector embedding policy" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ val vectorEmbeddingPolicyJson =
+ raw"""{"vectorEmbeddings":[{"path":"/vector1","dataType":"float32","distanceFunction":"cosine","dimensions":500}]}"""
+
+ val indexingPolicyJson =
+ raw"""{"indexingMode":"consistent","automatic":true,"includedPaths":[{"path":"\/mypk\/?"}],""" +
+ raw""""excludedPaths":[{"path":"\/*"}],"vectorIndexes":[{"path":"\/vector1","type":"flat"}]}"""
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp " +
+ s"TBLPROPERTIES(partitionKeyPath = '/mypk', manualThroughput = '1100', " +
+ s"indexingPolicy = '$indexingPolicyJson', " +
+ s"vectorEmbeddingPolicy = '$vectorEmbeddingPolicyJson')")
+
+ val containerProperties = cosmosClient.getDatabase(databaseName).getContainer(containerName).read().block().getProperties
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/mypk"))
+
+ // validate vector embedding policy
+ val vectorEmbeddingPolicy = containerProperties.getVectorEmbeddingPolicy
+ vectorEmbeddingPolicy should not be null
+ vectorEmbeddingPolicy.getVectorEmbeddings should have size 1
+ val embedding = vectorEmbeddingPolicy.getVectorEmbeddings.get(0)
+ embedding.getPath shouldEqual "/vector1"
+ embedding.getDataType.toString shouldEqual "float32"
+ embedding.getDistanceFunction.toString shouldEqual "cosine"
+ embedding.getEmbeddingDimensions shouldEqual 500
+
+ // validate vector indexes are in indexing policy
+ val vectorIndexes = containerProperties.getIndexingPolicy.getVectorIndexes
+ vectorIndexes should have size 1
+ vectorIndexes.get(0).getPath shouldEqual "/vector1"
+ vectorIndexes.get(0).getType shouldEqual "flat"
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 1100
+
+ val tblProperties = getTblProperties(spark, databaseName, containerName)
+
+ tblProperties should have size 8
+
+ tblProperties("CosmosPartitionKeyDefinition") shouldEqual "{\"paths\":[\"/mypk\"],\"kind\":\"Hash\"}"
+ tblProperties("DefaultTtlInSeconds") shouldEqual "null"
+ tblProperties("AnalyticalStoreTtlInSeconds") shouldEqual "null"
+
+ // validate vector embedding policy is in table properties (structured check)
+ val vepObjectMapper = Utils.getSimpleObjectMapper
+ val vepNode = vepObjectMapper.readTree(tblProperties("VectorEmbeddingPolicy"))
+ val vepEmbeddings = vepNode.get("vectorEmbeddings")
+ vepEmbeddings.size() shouldEqual 1
+ vepEmbeddings.get(0).get("path").asText() shouldEqual "/vector1"
+ vepEmbeddings.get(0).get("dataType").asText() shouldEqual "float32"
+ vepEmbeddings.get(0).get("distanceFunction").asText() shouldEqual "cosine"
+
+ // validate vector indexes are in indexing policy (structured check)
+ val ipNode = vepObjectMapper.readTree(tblProperties("IndexingPolicy"))
+ val vectorIndexesNode = ipNode.get("vectorIndexes")
+ vectorIndexesNode.size() shouldEqual 1
+ vectorIndexesNode.get(0).get("path").asText() shouldEqual "/vector1"
+ vectorIndexesNode.get(0).get("type").asText() shouldEqual "flat"
+
+ // would look like Manual|RUProvisioned|LastOfferModification
+ // - last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("ProvisionedThroughput").startsWith("Manual|1100|") shouldEqual true
+ tblProperties("ProvisionedThroughput").length shouldEqual 32
+
+ // last modified as iso datetime like 2021-12-07T10:33:44Z
+ tblProperties("LastModified").length shouldEqual 20
+ }
+
+ it can "select from a catalog table with default TBLPROPERTIES" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cleanupDatabaseLater(databaseName)
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName (word STRING, number INT) using cosmos.oltp;")
+
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName)
+ val containerProperties = container.read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ for (state <- Array(true, false)) {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ objectNode.put("name", "Shrodigner's mouse")
+ objectNode.put("type", "mouse")
+ objectNode.put("age", 20)
+ objectNode.put("isAlive", state)
+ objectNode.put("id", UUID.randomUUID().toString)
+ container.createItem(objectNode).block()
+ }
+
+ val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$containerName")
+ val rowsArrayUnfiltered = dfWithInference.collect()
+ rowsArrayUnfiltered should have size 2
+ val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'mouse'").collect()
+ rowsArrayWithInference should have size 1
+
+ val rowWithInference = rowsArrayWithInference(0)
+ rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's mouse"
+ rowWithInference.getAs[String]("type") shouldEqual "mouse"
+ rowWithInference.getAs[Integer]("age") shouldEqual 20
+ rowWithInference.getAs[Boolean]("isAlive") shouldEqual true
+
+ val fieldNames = rowWithInference.schema.fields.map(field => field.name)
+ fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false
+ }
+
+ it can "select from a catalog Cosmos view" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ val viewName = containerName + "view" + RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName)
+ val containerProperties = container.read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ for (state <- Array(true, false)) {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ objectNode.put("name", "Shrodigner's mouse")
+ objectNode.put("type", "mouse")
+ objectNode.put("age", 20)
+ objectNode.put("isAlive", state)
+ objectNode.put("id", UUID.randomUUID().toString)
+ container.createItem(objectNode).block()
+ }
+
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " +
+ s"TBLPROPERTIES(isCosmosView = 'True') " +
+ s"OPTIONS (" +
+ s"spark.cosmos.database = '$databaseName', " +
+ s"spark.cosmos.container = '$containerName', " +
+ "spark.cosmos.read.inferSchema.enabled = 'True', " +
+ "spark.cosmos.read.inferSchema.includeSystemProperties = 'True', " +
+ "spark.cosmos.read.partitioning.strategy = 'Restrictive');")
+ val tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;")
+
+ tables.collect() should have size 2
+
+ tables
+ .where(s"tableName = '$viewName' and namespace = '$databaseName'")
+ .collect() should have size 1
+
+ tables
+ .where(s"tableName = '$containerName' and namespace = '$databaseName'")
+ .collect() should have size 1
+
+ val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$viewName")
+ val rowsArrayUnfiltered = dfWithInference.collect()
+ rowsArrayUnfiltered should have size 2
+
+ val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'mouse'").collect()
+ rowsArrayWithInference should have size 1
+
+ val rowWithInference = rowsArrayWithInference(0)
+ rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's mouse"
+ rowWithInference.getAs[String]("type") shouldEqual "mouse"
+ rowWithInference.getAs[Integer]("age") shouldEqual 20
+ rowWithInference.getAs[Boolean]("isAlive") shouldEqual true
+
+ val fieldNames = rowWithInference.schema.fields.map(field => field.name)
+ fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe true
+ fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe true
+ fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe true
+ fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe true
+ fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe true
+ }
+
+ it can "manage Cosmos view metadata in the catalog" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ val viewNameRaw = containerName +
+ "view" +
+ RandomStringUtils.randomAlphabetic(6).toLowerCase +
+ System.currentTimeMillis()
+ val viewNameWithSchemaInference = containerName +
+ "view" +
+ RandomStringUtils.randomAlphabetic(6).toLowerCase +
+ System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName)
+ val containerProperties = container.read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ for (state <- Array(true, false)) {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ objectNode.put("name", "Shrodigner's snake")
+ objectNode.put("type", "snake")
+ objectNode.put("age", 20)
+ objectNode.put("isAlive", state)
+ objectNode.put("id", UUID.randomUUID().toString)
+ container.createItem(objectNode).block()
+ }
+
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$viewNameRaw using cosmos.oltp " +
+ s"TBLPROPERTIES(isCosmosView = 'True') " +
+ s"OPTIONS (" +
+ s"spark.cosmos.database = '$databaseName', " +
+ s"spark.cosmos.container = '$containerName', " +
+ s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " +
+ s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " +
+ s"spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " +
+ s"spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " +
+ "spark.cosmos.read.inferSchema.enabled = 'False', " +
+ "spark.cosmos.read.partitioning.strategy = 'Restrictive');")
+
+ var tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;")
+ tables.collect() should have size 2
+
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$viewNameWithSchemaInference using cosmos.oltp " +
+ s"TBLPROPERTIES(isCosmosView = 'True') " +
+ s"OPTIONS (" +
+ s"spark.cosmos.database = '$databaseName', " +
+ s"spark.cosmos.container = '$containerName', " +
+ s"spark.sql.catalog.testCatalog.spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " +
+ s"spark.sql.catalog.testCatalog.spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " +
+ s"spark.cosmos.accountKey = '${TestConfigurations.MASTER_KEY}', " +
+ s"spark.cosmos.accountEndpoint = '${TestConfigurations.HOST}', " +
+ "spark.cosmos.read.inferSchema.enabled = 'True', " +
+ "spark.cosmos.read.inferSchema.includeSystemProperties = 'False', " +
+ "spark.cosmos.read.partitioning.strategy = 'Restrictive');")
+
+ tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;")
+ tables.collect() should have size 3
+
+ val filePath = spark.conf.get("spark.sql.catalog.testCatalog.spark.cosmos.views.repositoryPath")
+ val hdfsMetadataLog = new HDFSMetadataLog[String](spark, filePath)
+
+ hdfsMetadataLog.getLatest() match {
+ case None => throw new IllegalStateException("HDFS metadata file should have been written")
+ case Some((batchId, json)) =>
+
+ logInfo(s"BatchId: $batchId, Json: $json")
+
+ // Validate the master key is not stored anywhere
+ json.contains(TestConfigurations.MASTER_KEY) shouldEqual false
+ json.contains(TestConfigurations.SECONDARY_MASTER_KEY) shouldEqual false
+ json.contains(TestConfigurations.HOST) shouldEqual false
+
+ // validate that we can deserialize the persisted json
+ val deserializedViews = ViewDefinitionEnvelopeSerializer.fromJson(json)
+ deserializedViews.length >= 2 shouldBe true
+ deserializedViews
+ .exists(vd => vd.databaseName == databaseName && vd.viewName == viewNameRaw) shouldEqual true
+ deserializedViews
+ .exists(vd => vd.databaseName == databaseName &&
+ vd.viewName == viewNameWithSchemaInference) shouldEqual true
+ }
+
+ tables
+ .where(s"tableName = '$containerName' and namespace = '$databaseName'")
+ .collect() should have size 1
+ tables
+ .where(s"tableName = '$viewNameRaw' and namespace = '$databaseName'")
+ .collect() should have size 1
+ tables
+ .where(s"tableName = '$viewNameWithSchemaInference' and namespace = '$databaseName'")
+ .collect() should have size 1
+
+ val dfRaw = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$viewNameRaw")
+ val rowsArrayUnfilteredRaw = dfRaw.collect()
+ rowsArrayUnfilteredRaw should have size 2
+
+ val fieldNamesRaw = dfRaw.schema.fields.map(field => field.name)
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.IdAttributeName) shouldBe true
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.RawJsonBodyAttributeName) shouldBe true
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe true
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false
+ fieldNamesRaw.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false
+
+ val dfWithInference = spark.sql(s"SELECT * FROM testCatalog.$databaseName.$viewNameWithSchemaInference")
+ val rowsArrayUnfiltered = dfWithInference.collect()
+ rowsArrayUnfiltered should have size 2
+
+ val rowsArrayWithInference = dfWithInference.where("isAlive = 'true' and type = 'snake'").collect()
+ rowsArrayWithInference should have size 1
+
+ val rowWithInference = rowsArrayWithInference(0)
+ rowWithInference.getAs[String]("name") shouldEqual "Shrodigner's snake"
+ rowWithInference.getAs[String]("type") shouldEqual "snake"
+ rowWithInference.getAs[Integer]("age") shouldEqual 20
+ rowWithInference.getAs[Boolean]("isAlive") shouldEqual true
+
+ val fieldNames = rowWithInference.schema.fields.map(field => field.name)
+ fieldNames.contains(CosmosTableSchemaInferrer.SelfAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.TimestampAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.ResourceIdAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.ETagAttributeName) shouldBe false
+ fieldNames.contains(CosmosTableSchemaInferrer.AttachmentsAttributeName) shouldBe false
+
+ spark.sql(s"DROP TABLE testCatalog.$databaseName.$viewNameRaw;")
+ tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;")
+ tables.collect() should have size 2
+
+ spark.sql(s"DROP TABLE testCatalog.$databaseName.$viewNameWithSchemaInference;")
+ tables = spark.sql(s"SHOW TABLES in testCatalog.$databaseName;")
+ tables.collect() should have size 1
+ }
+
+ "creating a view without specifying the isCosmosView table property" should "throw IllegalArgumentException" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ val viewName = containerName +
+ "view" +
+ RandomStringUtils.randomAlphabetic(6).toLowerCase +
+ System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName)
+ val containerProperties = container.read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ for (state <- Array(true, false)) {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ objectNode.put("name", "Shrodigner's snake")
+ objectNode.put("type", "snake")
+ objectNode.put("age", 20)
+ objectNode.put("isAlive", state)
+ objectNode.put("id", UUID.randomUUID().toString)
+ container.createItem(objectNode).block()
+ }
+
+ try {
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " +
+ s"TBLPROPERTIES(isCosmosViewWithTypo = 'True') " +
+ s"OPTIONS (" +
+ s"spark.cosmos.database = '$databaseName', " +
+ s"spark.cosmos.container = '$containerName', " +
+ "spark.cosmos.read.inferSchema.enabled = 'False', " +
+ "spark.cosmos.read.partitioning.strategy = 'Restrictive');")
+
+ fail("Expected IllegalArgumentException not thrown")
+ }
+ catch {
+ case expectedError: IllegalArgumentException =>
+ logInfo(s"Expected IllegalArgumentException: $expectedError")
+ succeed
+ }
+ }
+
+ "creating a view with the isCosmosView table property set to 'False'" should "throw IllegalArgumentException" in {
+ val databaseName = getAutoCleanableDatabaseName
+ val containerName = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ val viewName = containerName +
+ "view" +
+ RandomStringUtils.randomAlphabetic(6).toLowerCase +
+ System.currentTimeMillis()
+
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ spark.sql(s"CREATE TABLE testCatalog.$databaseName.$containerName using cosmos.oltp;")
+
+ val container = cosmosClient.getDatabase(databaseName).getContainer(containerName)
+ val containerProperties = container.read().block().getProperties
+
+ // verify default partition key path is used
+ containerProperties.getPartitionKeyDefinition.getPaths.asScala.toArray should equal(Array("/id"))
+
+ // validate throughput
+ val throughput = cosmosClient.getDatabase(databaseName).getContainer(containerName).readThroughput().block().getProperties
+ throughput.getManualThroughput shouldEqual 400
+
+ for (state <- Array(true, false)) {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ objectNode.put("name", "Shrodigner's snake")
+ objectNode.put("type", "snake")
+ objectNode.put("age", 20)
+ objectNode.put("isAlive", state)
+ objectNode.put("id", UUID.randomUUID().toString)
+ container.createItem(objectNode).block()
+ }
+
+ try {
+ spark.sql(
+ s"CREATE TABLE testCatalog.$databaseName.$viewName using cosmos.oltp " +
+ s"TBLPROPERTIES(isCosmosView = 'False') " +
+ s"OPTIONS (" +
+ s"spark.cosmos.database = '$databaseName', " +
+ s"spark.cosmos.container = '$containerName', " +
+ "spark.cosmos.read.inferSchema.enabled = 'False', " +
+ "spark.cosmos.read.partitioning.strategy = 'Restrictive');")
+
+ fail("Expected IllegalArgumentException not thrown")
+ }
+ catch {
+ case expectedError: IllegalArgumentException =>
+ logInfo(s"Expected IllegalArgumentException: $expectedError")
+ succeed
+ }
+ }
+
+ it can "list all containers in a database" in {
+ val databaseName = getAutoCleanableDatabaseName
+ cosmosClient.createDatabase(databaseName).block()
+
+ // create multiple containers under the same database
+ val containerName1 = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ val containerName2 = RandomStringUtils.randomAlphabetic(6).toLowerCase + System.currentTimeMillis()
+ cosmosClient.getDatabase(databaseName).createContainer(containerName1, "/id").block()
+ cosmosClient.getDatabase(databaseName).createContainer(containerName2, "/id").block()
+
+ val containers = spark.sql(s"SHOW TABLES FROM testCatalog.$databaseName").collect()
+ containers should have size 2
+ containers
+ .filter(
+ row => row.getAs[String]("tableName").equals(containerName1)
+ || row.getAs[String]("tableName").equals(containerName2)) should have size 2
+ }
+
+ private def getTblProperties(spark: SparkSession, databaseName: String, containerName: String) = {
+ val descriptionDf = spark.sql(s"DESCRIBE TABLE EXTENDED testCatalog.$databaseName.$containerName;")
+ val tblPropertiesRowsArray = descriptionDf
+ .where("col_name = 'Table Properties'")
+ .collect()
+
+ for (row <- tblPropertiesRowsArray) {
+ logInfo(row.mkString)
+ }
+ tblPropertiesRowsArray should have size 1
+
+ // Output will look something like this
+ // [key1='value1',key2='value2',...]
+ val tblPropertiesText = tblPropertiesRowsArray(0).getAs[String]("data_type")
+ // parse the text into a key/value map
+
+ val keyValuePairs = tblPropertiesText.substring(1, tblPropertiesText.length - 2).split("',")
+ keyValuePairs
+ .map(kvp => {
+ val columns = kvp.split("='")
+ (columns(0), columns(1))
+ })
+ .toMap
+ }
+
+ def createDatabase(spark: SparkSession, databaseName: String): DataFrame = {
+ spark.sql(s"CREATE DATABASE testCatalog.$databaseName;")
+ }
+
+ //scalastyle:on magic.number
+ //scalastyle:on multiple.string.literals
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala
new file mode 100644
index 000000000000..a5bdc9df94c1
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/CosmosRowConverterTest.scala
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.node.ObjectNode
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.types.TimestampNTZType
+
+import java.sql.{Date, Timestamp}
+import java.time.format.DateTimeFormatter
+import java.time.{LocalDateTime, OffsetDateTime}
+
+// scalastyle:off underscore.import
+import org.apache.spark.sql.types._
+// scalastyle:on underscore.import
+
+class CosmosRowConverterTest extends UnitSpec with BasicLoggingTrait {
+ //scalastyle:off null
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off file.size.limit
+
+ val objectMapper = new ObjectMapper()
+ private[this] val defaultRowConverter =
+ CosmosRowConverter.get(
+ new CosmosSerializationConfig(
+ SerializationInclusionModes.Always,
+ SerializationDateTimeConversionModes.Default
+ )
+ )
+
+ "date and time and TimestampNTZType in spark row" should "translate to ObjectNode" in {
+ val colName1 = "testCol1"
+ val colName2 = "testCol2"
+ val colName3 = "testCol3"
+ val colName4 = "testCol4"
+ val currentMillis = System.currentTimeMillis()
+ val colVal1 = new Date(currentMillis)
+ val timestampNTZType = "2021-07-01T08:43:28.037"
+ val colVal2 = LocalDateTime.parse(timestampNTZType, DateTimeFormatter.ISO_DATE_TIME)
+ val colVal3 = currentMillis.toInt
+
+ val row = new GenericRowWithSchema(
+ Array(colVal1, colVal2, colVal3, colVal3),
+ StructType(Seq(StructField(colName1, DateType),
+ StructField(colName2, TimestampNTZType),
+ StructField(colName3, DateType),
+ StructField(colName4, TimestampType))))
+
+ val objectNode = defaultRowConverter.fromRowToObjectNode(row)
+ objectNode.get(colName1).asLong() shouldEqual currentMillis
+ objectNode.get(colName2).asText() shouldEqual "2021-07-01T08:43:28.037"
+ objectNode.get(colName3).asInt() shouldEqual colVal3
+ objectNode.get(colName4).asInt() shouldEqual colVal3
+ }
+
+ "time and TimestampNTZType in ObjectNode" should "translate to Row" in {
+ val colName1 = "testCol1"
+ val colName2 = "testCol2"
+ val colName3 = "testCol3"
+ val colName4 = "testCol4"
+ val colVal1 = System.currentTimeMillis()
+ val colVal1AsTime = new Timestamp(colVal1)
+ val colVal2 = System.currentTimeMillis()
+ val colVal2AsTime = new Timestamp(colVal2)
+ val colVal3 = "2021-01-20T20:10:15+01:00"
+ val colVal3AsTime = Timestamp.valueOf(OffsetDateTime.parse(colVal3, DateTimeFormatter.ISO_OFFSET_DATE_TIME).toLocalDateTime)
+ val colVal4 = "2021-07-01T08:43:28.037"
+ val colVal4AsTime = LocalDateTime.parse(colVal4, DateTimeFormatter.ISO_DATE_TIME)
+
+ val objectNode: ObjectNode = objectMapper.createObjectNode()
+ objectNode.put(colName1, colVal1)
+ objectNode.put(colName2, colVal2)
+ objectNode.put(colName3, colVal3)
+ objectNode.put(colName4, colVal4)
+ val schema = StructType(Seq(
+ StructField(colName1, TimestampType),
+ StructField(colName2, TimestampType),
+ StructField(colName3, TimestampType),
+ StructField(colName4, TimestampNTZType)))
+ val row = defaultRowConverter.fromObjectNodeToRow(schema, objectNode, SchemaConversionModes.Relaxed)
+ val asTime = row.get(0).asInstanceOf[Timestamp]
+ asTime.compareTo(colVal1AsTime) shouldEqual 0
+ val asTime2 = row.get(1).asInstanceOf[Timestamp]
+ asTime2.compareTo(colVal2AsTime) shouldEqual 0
+ val asTime3 = row.get(2).asInstanceOf[Timestamp]
+ asTime3.compareTo(colVal3AsTime) shouldEqual 0
+ val asTime4 = row.get(3).asInstanceOf[LocalDateTime]
+ asTime4.compareTo(colVal4AsTime) shouldEqual 0
+ }
+
+ //scalastyle:on null
+ //scalastyle:on multiple.string.literals
+ //scalastyle:on file.size.limit
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala
new file mode 100644
index 000000000000..b6433c6d7b2f
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/ItemsScanITest.scala
@@ -0,0 +1,256 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.{CosmosClientMetadataCachesSnapshot, SparkBridgeImplementationInternal, TestConfigurations, Utils}
+import com.azure.cosmos.models.PartitionKey
+import com.fasterxml.jackson.databind.node.ObjectNode
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.connector.expressions.Expressions
+import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+import java.util.UUID
+import scala.collection.mutable.ListBuffer
+
+class ItemsScanITest
+ extends IntegrationSpec
+ with Spark
+ with AutoCleanableCosmosContainersWithPkAsPartitionKey {
+
+ //scalastyle:off multiple.string.literals
+ //scalastyle:off magic.number
+
+ private val idProperty = "id"
+ private val pkProperty = "pk"
+ private val itemIdentityProperty = "_itemIdentity"
+
+ private val analyzedAggregatedFilters =
+ AnalyzedAggregatedFilters(
+ QueryFilterAnalyzer.rootParameterizedQuery,
+ false,
+ Array.empty[Filter],
+ Array.empty[Filter],
+ Option.empty[List[ReadManyFilter]])
+
+ it should "only return the readMany filtering property when runtimeFiltering is enabled and readMany filtering is enabled" in {
+ val clientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots()
+
+ val testCases = Array(
+ // containerName, partitionKey property, expected readMany filtering property
+ (cosmosContainer, idProperty, idProperty),
+ (cosmosContainersWithPkAsPartitionKey, pkProperty, itemIdentityProperty)
+ )
+
+ for (testCase <- testCases) {
+ val partitionKeyDefinition =
+ cosmosClient
+ .getDatabase(cosmosDatabase)
+ .getContainer(testCase._1)
+ .read()
+ .block()
+ .getProperties
+ .getPartitionKeyDefinition
+
+ for (runTimeFilteringEnabled <- Array(true, false)) {
+ for (readManyFilteringEnabled <- Array(true, false)) {
+ logInfo(s"TestCase: containerName ${testCase._1}, partitionKeyProperty ${testCase._2}, " +
+ s"runtimeFilteringEnabled $runTimeFilteringEnabled, readManyFilteringEnabled $readManyFilteringEnabled")
+
+ val config = Map(
+ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST,
+ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> testCase._1,
+ "spark.cosmos.read.inferSchema.enabled" -> "true",
+ "spark.cosmos.applicationName" -> "ItemsScan",
+ "spark.cosmos.read.runtimeFiltering.enabled" -> runTimeFilteringEnabled.toString,
+ "spark.cosmos.read.readManyFiltering.enabled" -> readManyFilteringEnabled.toString
+ )
+ val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
+ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config)
+ val schema = getDefaultSchema(testCase._2)
+
+ val itemScan = new ItemsScan(
+ spark,
+ schema,
+ config,
+ readConfig,
+ analyzedAggregatedFilters,
+ clientMetadataCachesSnapshots,
+ diagnosticsConfig,
+ "",
+ partitionKeyDefinition)
+ val arrayReferences = itemScan.filterAttributes()
+
+ if (runTimeFilteringEnabled && readManyFilteringEnabled) {
+ arrayReferences.size shouldBe 1
+ arrayReferences should contain theSameElementsAs Array(Expressions.column(testCase._3))
+ } else {
+ arrayReferences shouldBe empty
+ }
+ }
+ }
+ }
+ }
+
+ it should "only prune partitions when runtimeFiltering is enabled and readMany filtering is enabled" in {
+ val clientMetadataCachesSnapshots = getCosmosClientMetadataCachesSnapshots()
+
+ val testCases = Array(
+ // containerName, partitionKeyProperty, expected readManyFiltering property
+ (cosmosContainer, idProperty, idProperty),
+ (cosmosContainersWithPkAsPartitionKey, pkProperty, itemIdentityProperty)
+ )
+ for (testCase <- testCases) {
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(testCase._1)
+ val partitionKeyDefinition = container.read().block().getProperties.getPartitionKeyDefinition
+
+ // assert that there is more than one range
+ val feedRanges = container.getFeedRanges.block()
+ feedRanges.size() should be > 1
+
+ // first insert a few items
+ val matchingItemList = ListBuffer[ObjectNode]()
+ for (_ <- 1 to 20) {
+ val objectNode = getNewItem(testCase._2)
+ container.createItem(objectNode).block()
+ matchingItemList += objectNode
+ logInfo(s"ID of test doc: ${objectNode.get(idProperty).asText()}")
+ }
+
+ // choose one of the items created above and build a readMany filter for it
+ val runtimeFilters = getReadManyFilters(Array(matchingItemList(0)), testCase._2, testCase._3)
+
+ for (runTimeFilteringEnabled <- Array(true, false)) {
+ for (readManyFilteringEnabled <- Array(true, false)) {
+ logInfo(s"TestCase: containerName ${testCase._1}, partitionKeyProperty ${testCase._2}, " +
+ s"runtimeFilteringEnabled $runTimeFilteringEnabled, readManyFilteringEnabled $readManyFilteringEnabled")
+
+ val config = Map(
+ "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST,
+ "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> testCase._1,
+ "spark.cosmos.read.inferSchema.enabled" -> "true",
+ "spark.cosmos.applicationName" -> "ItemsScan",
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive",
+ "spark.cosmos.read.runtimeFiltering.enabled" -> runTimeFilteringEnabled.toString,
+ "spark.cosmos.read.readManyFiltering.enabled" -> readManyFilteringEnabled.toString
+ )
+ val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
+ val diagnosticsConfig = DiagnosticsConfig.parseDiagnosticsConfig(config)
+
+ val schema = getDefaultSchema(testCase._2)
+ val itemScan = new ItemsScan(
+ spark,
+ schema,
+ config,
+ readConfig,
+ analyzedAggregatedFilters,
+ clientMetadataCachesSnapshots,
+ diagnosticsConfig,
+ "",
+ partitionKeyDefinition)
+
+ val plannedInputPartitions = itemScan.planInputPartitions()
+ plannedInputPartitions.length shouldBe feedRanges.size() // using restrictive strategy
+
+ itemScan.filter(runtimeFilters)
+ val plannedInputPartitionAfterFiltering = itemScan.planInputPartitions()
+
+ if (runTimeFilteringEnabled && readManyFilteringEnabled) {
+ // partition can be pruned
+ plannedInputPartitionAfterFiltering.length shouldBe 1
+ val filterItemFeedRange =
+ SparkBridgeImplementationInternal.partitionKeyToNormalizedRange(
+ new PartitionKey(getPartitionKeyValue(matchingItemList(0), s"/${testCase._2}")),
+ partitionKeyDefinition)
+
+ val rangesOverlap =
+ SparkBridgeImplementationInternal.doRangesOverlap(
+ filterItemFeedRange,
+ plannedInputPartitionAfterFiltering(0).asInstanceOf[CosmosInputPartition].feedRange)
+
+ rangesOverlap shouldBe true
+ } else {
+ // no partition will be pruned
+ plannedInputPartitionAfterFiltering.length shouldBe plannedInputPartitions.length
+ plannedInputPartitionAfterFiltering should contain theSameElementsAs plannedInputPartitions
+ }
+ }
+ }
+ }
+ }
+
+ private def getCosmosClientMetadataCachesSnapshots(): Broadcast[CosmosClientMetadataCachesSnapshots] = {
+ val cosmosClientMetadataCachesSnapshot = new CosmosClientMetadataCachesSnapshot()
+ cosmosClientMetadataCachesSnapshot.serialize(cosmosClient)
+
+ spark.sparkContext.broadcast(
+ CosmosClientMetadataCachesSnapshots(
+ cosmosClientMetadataCachesSnapshot,
+ Option.empty[CosmosClientMetadataCachesSnapshot]))
+ }
+
+ private def getReadManyFilters(
+ filteringItems: Array[ObjectNode],
+ partitionKeyProperty: String,
+ readManyFilteringProperty: String): Array[Filter] = {
+ val readManyFilterValues =
+ filteringItems
+ .map(filteringItem => getReadManyFilteringValue(filteringItem, partitionKeyProperty, readManyFilteringProperty))
+
+ if (partitionKeyProperty.equalsIgnoreCase(idProperty)) {
+ Array[Filter](In(idProperty, readManyFilterValues.map(_.asInstanceOf[Any])))
+ } else {
+ Array[Filter](In(readManyFilteringProperty, readManyFilterValues.map(_.asInstanceOf[Any])))
+ }
+ }
+
+ private def getReadManyFilteringValue(
+ objectNode: ObjectNode,
+ partitionKeyProperty: String,
+ readManyFilteringProperty: String): String = {
+
+ if (readManyFilteringProperty.equals(itemIdentityProperty)) {
+ CosmosItemIdentityHelper
+ .getCosmosItemIdentityValueString(
+ objectNode.get(idProperty).asText(),
+ List(objectNode.get(partitionKeyProperty).asText()))
+ } else {
+ objectNode.get(idProperty).asText()
+ }
+ }
+
+ private def getNewItem(partitionKeyProperty: String): ObjectNode = {
+ val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
+ val id = UUID.randomUUID().toString
+ objectNode.put(idProperty, id)
+
+ if (!partitionKeyProperty.equalsIgnoreCase(idProperty)) {
+ val pk = UUID.randomUUID().toString
+ objectNode.put(partitionKeyProperty, pk)
+ }
+
+ objectNode
+ }
+
+ private def getDefaultSchema(partitionKeyProperty: String): StructType = {
+ if (!partitionKeyProperty.equalsIgnoreCase(idProperty)) {
+ StructType(Seq(
+ StructField(idProperty, StringType),
+ StructField(pkProperty, StringType),
+ StructField(itemIdentityProperty, StringType)
+ ))
+ } else {
+ StructType(Seq(
+ StructField(idProperty, StringType)
+ ))
+ }
+ }
+
+ //scalastyle:on multiple.string.literals
+ //scalastyle:on magic.number
+}
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/RowSerializerPollTest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/RowSerializerPollTest.scala
new file mode 100644
index 000000000000..2335bedf917d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/RowSerializerPollTest.scala
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+package com.azure.cosmos.spark
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+class RowSerializerPollTest extends RowSerializerPollSpec {
+ //scalastyle:off multiple.string.literals
+
+ "RowSerializer " should "be returned to the pool only a limited number of times" in {
+ val canRun = Platform.canRunTestAccessingDirectByteBuffer
+ assume(canRun._1, canRun._2)
+
+ val schema = StructType(Seq(StructField("column01", IntegerType), StructField("column02", StringType)))
+
+ for (_ <- 1 to 256) {
+ RowSerializerPool.returnSerializerToPool(schema, ExpressionEncoder.apply(schema).createSerializer()) shouldBe true
+ }
+
+ logInfo("First 256 attempts to pool succeeded")
+
+ RowSerializerPool.returnSerializerToPool(schema, ExpressionEncoder.apply(schema).createSerializer()) shouldBe false
+ }
+ //scalastyle:on multiple.string.literals
+}
+
diff --git a/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala
new file mode 100644
index 000000000000..5f9cb1dbdbc8
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos-spark_4-1_2-13/src/test/scala/com/azure/cosmos/spark/SparkE2EQueryITest.scala
@@ -0,0 +1,70 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+package com.azure.cosmos.spark
+
+import com.azure.cosmos.implementation.TestConfigurations
+import com.fasterxml.jackson.databind.node.ObjectNode
+
+import java.util.UUID
+
+class SparkE2EQueryITest
+ extends SparkE2EQueryITestBase {
+
+ "spark query" can "return proper Cosmos specific query plan on explain with nullable properties" in {
+ val cosmosEndpoint = TestConfigurations.HOST
+ val cosmosMasterKey = TestConfigurations.MASTER_KEY
+
+ val id = UUID.randomUUID().toString
+
+ val rawItem =
+ s"""
+ | {
+ | "id" : "$id",
+ | "nestedObject" : {
+ | "prop1" : 5,
+ | "prop2" : "6"
+ | }
+ | }
+ |""".stripMargin
+
+ val objectNode = objectMapper.readValue(rawItem, classOf[ObjectNode])
+
+ val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
+ container.createItem(objectNode).block()
+
+ val cfg = Map("spark.cosmos.accountEndpoint" -> cosmosEndpoint,
+ "spark.cosmos.accountKey" -> cosmosMasterKey,
+ "spark.cosmos.database" -> cosmosDatabase,
+ "spark.cosmos.container" -> cosmosContainer,
+ "spark.cosmos.read.inferSchema.forceNullableProperties" -> "true",
+ "spark.cosmos.read.partitioning.strategy" -> "Restrictive"
+ )
+
+ val df = spark.read.format("cosmos.oltp").options(cfg).load()
+ val rowsArray = df.where("nestedObject.prop2 = '6'").collect()
+ rowsArray should have size 1
+
+ var output = new java.io.ByteArrayOutputStream()
+ Console.withOut(output) {
+ df.explain()
+ }
+ var queryPlan = output.toString.replaceAll("#\\d+", "#x")
+ logInfo(s"Query Plan: $queryPlan")
+ queryPlan.contains("Cosmos Query: SELECT * FROM r") shouldEqual true
+
+ output = new java.io.ByteArrayOutputStream()
+ Console.withOut(output) {
+ df.where("nestedObject.prop2 = '6'").explain()
+ }
+ queryPlan = output.toString.replaceAll("#\\d+", "#x")
+ logInfo(s"Query Plan: $queryPlan")
+ val expected = s"Cosmos Query: SELECT * FROM r WHERE (NOT(IS_NULL(r['nestedObject']['prop2'])) AND IS_DEFINED(r['nestedObject']['prop2'])) " +
+ s"AND r['nestedObject']['prop2']=" +
+ s"@param0${System.getProperty("line.separator")} > param: @param0 = 6"
+ queryPlan.contains(expected) shouldEqual true
+
+ val item = rowsArray(0)
+ item.getAs[String]("id") shouldEqual id
+ }
+}