diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 710d81045f..75cde1c56f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -472,29 +472,51 @@ private string CreateInitialQuery() } else if (!string.IsNullOrEmpty(CatalogName)) { - CatalogName = SqlServerEscapeHelper.EscapeIdentifier(CatalogName); + CatalogName = SqlServerEscapeHelper.EscapeStringAsLiteral(SqlServerEscapeHelper.EscapeIdentifier(CatalogName)); } string objectName = ADP.BuildMultiPartName(parts); string escapedObjectName = SqlServerEscapeHelper.EscapeStringAsLiteral(objectName); - // Specify the column names explicitly. This is to ensure that we can map to hidden columns (e.g. columns in temporal tables.) - // If the target table doesn't exist, OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent SELECT * - // query will then continue to fail with "Invalid object name" rather than with an unusual error because the query being executed - // is NULL. - // Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to exclude them explicitly. + // Specify the column names explicitly. This is to ensure that we can map to hidden + // columns (e.g. columns in temporal tables.) If the target table doesn't exist, + // OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent + // SELECT * query will then continue to fail with "Invalid object name" rather than with + // an unusual error because the query being executed is NULL. + // + // Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to + // exclude them explicitly. The graph_type values excluded below are internal graph + // columns that cannot be selected directly: + // + // 1 = GRAPH_ID + // 3 = GRAPH_FROM_ID + // 4 = GRAPH_FROM_OBJ_ID + // 6 = GRAPH_TO_ID + // 7 = GRAPH_TO_OBJ_ID + // + // See: https://learn.microsoft.com/sql/relational-databases/graphs/sql-graph-architecture#syscolumns + // + // The column-name query is built as dynamic SQL and executed via sp_executesql so + // that it is not compiled (and rejected) on SQL Server versions that lack the + // graph_type column (e.g. SQL 2016). CatalogName and escapedObjectName are + // interpolated directly into the SQL string because SQL Server does not allow + // identifiers (database/schema/table names) to be passed as parameters. Both + // values are escaped via SqlServerEscapeHelper before interpolation. return $""" SELECT @@TRANCOUNT; +DECLARE @Object_ID INT = OBJECT_ID('{escapedObjectName}'); +DECLARE @Column_Name_Query NVARCHAR(MAX); DECLARE @Column_Names NVARCHAR(MAX) = NULL; IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type') BEGIN - SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC; + SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC;'; END ELSE BEGIN - SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') ORDER BY [column_id] ASC; + SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID ORDER BY [column_id] ASC;'; END +EXEC sp_executesql @Column_Name_Query, N'@Object_ID INT, @Column_Names NVARCHAR(MAX) OUTPUT', @Object_ID = @Object_ID, @Column_Names = @Column_Names OUTPUT; SELECT @Column_Names = COALESCE(@Column_Names, '*'); SET FMTONLY ON; @@ -624,7 +646,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i bool matched = false; bool rejected = false; - + // Look for a local match for the remote column. for (int j = 0; j < _localColumnMappings.Count; ++j) { @@ -644,7 +666,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i // Remove it from our unmatched set. unmatchedColumns.Remove(localColumn.DestinationColumn); - + // Check for column types that we refuse to bulk load, even // though we found a match. // @@ -1437,7 +1459,7 @@ private void RunParserReliably(BulkCopySimpleResultSet bulkCopyHandler = null) try { // @TODO: CER Exception Handling was removed here (see GH#3581) - _parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj); + _parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj); } finally { @@ -1760,7 +1782,7 @@ public void WriteToServer(DbDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); _rowSource = reader; _dbDataReaderRowSource = reader; @@ -1796,13 +1818,13 @@ public void WriteToServer(IDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = _rowSource as SqlDataReader; _dbDataReaderRowSource = _rowSource as DbDataReader; _rowSourceType = ValueSourceType.IDataReader; - + WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally @@ -1918,7 +1940,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); if (rows.Length == 0) { @@ -1935,9 +1957,9 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); _isAsyncBulkCopy = true; - + // It returns Task since _isAsyncBulkCopy = true; - return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); + return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); } finally { @@ -1964,19 +1986,19 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati { throw SQL.BulkLoadPendingOperation(); } - + SqlStatistics statistics = Statistics; try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = reader as SqlDataReader; _dbDataReaderRowSource = reader; _rowSourceType = ValueSourceType.DbDataReader; _isAsyncBulkCopy = true; - + // It returns Task since _isAsyncBulkCopy = true; return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); } @@ -2016,7 +2038,7 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio _dbDataReaderRowSource = _rowSource as DbDataReader; _rowSourceType = ValueSourceType.IDataReader; _isAsyncBulkCopy = true; - + // It returns Task since _isAsyncBulkCopy = true; return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); } @@ -2056,7 +2078,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); _rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted; _rowSource = table; @@ -2064,7 +2086,7 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); _isAsyncBulkCopy = true; - + // It returns Task since _isAsyncBulkCopy = true; return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken); } @@ -2114,7 +2136,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok bool finishedSynchronously = true; _isBulkCopyingInProgress = true; - + CreateOrValidateConnection(nameof(WriteToServer)); SqlConnectionInternal internalConnection = _connection.GetOpenTdsConnection(); @@ -3065,11 +3087,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutExceptionWithState( - completion: cancellableReconnectTS, + completion: cancellableReconnectTS, timeout: BulkCopyTimeout, state: _destinationTableName, - onFailure: static state => - SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), + onFailure: static state => + SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), cancellationToken: CancellationToken.None ); @@ -3242,7 +3264,7 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) } return resultTask; } - + private void ResetWriteToServerGlobalVariables() { _dataTableSource = null; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 5bceb81b58..2e652841c9 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -86,7 +86,7 @@ public static class DataTestUtility internal static readonly string KerberosDomainPassword = null; // SQL server Version - private static string s_sQLServerVersion = string.Empty; + private static string s_sqlServerVersion; //SQL Server EngineEdition private static string s_sqlServerEngineEdition; @@ -125,9 +125,9 @@ public static string SQLServerVersion { if (!string.IsNullOrEmpty(TCPConnectionString)) { - s_sQLServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion); + s_sqlServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion); } - return s_sQLServerVersion; + return s_sqlServerVersion; } } @@ -491,7 +491,14 @@ public static bool AreConnStringsSetup() public static bool IsSQL2019() => string.Equals("15", SQLServerVersion.Trim()); - public static bool IsSQL2016() => string.Equals("14", s_sQLServerVersion.Trim()); + public static bool IsSQL2017() => string.Equals("14", SQLServerVersion.Trim()); + + public static bool IsSQL2016() => string.Equals("13", SQLServerVersion.Trim()); + + // "At least" version checks for use as ConditionalFact/ConditionalTheory conditions. + public static bool IsAtLeastSQL2017() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 14; + + public static bool IsAtLeastSQL2019() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 15; public static bool IsSQLAliasSetup() { diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs index fbc7cd222a..32780dcfe6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs @@ -32,7 +32,7 @@ public class CertificateTest : IDisposable // InstanceName will get replaced with an instance name in the connection string private static string InstanceName = "MSSQLSERVER"; - + // s_instanceNamePrefix will get replaced with MSSQL$ is there is an instance name in connection string private static string InstanceNamePrefix = ""; @@ -51,10 +51,14 @@ private static string ForceEncryptionRegistryPath { return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{InstanceName}\MSSQLSERVER\SuperSocketNetLib"; } - if (DataTestUtility.IsSQL2016()) + if (DataTestUtility.IsSQL2017()) { return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{InstanceName}\MSSQLSERVER\SuperSocketNetLib"; } + if (DataTestUtility.IsSQL2016()) + { + return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL13.{InstanceName}\MSSQLSERVER\SuperSocketNetLib"; + } return string.Empty; } } @@ -196,7 +200,9 @@ private static void CreateValidCertificate(string script) RedirectStandardError = true, RedirectStandardOutput = true, UseShellExecute = false, - Arguments = $"{script} -Prefix {InstanceNamePrefix} -Instance {InstanceName}", + Arguments = string.IsNullOrEmpty(InstanceNamePrefix) + ? $"{script} -Instance \"{InstanceName}\"" + : $"{script} -Prefix \"{InstanceNamePrefix}\" -Instance \"{InstanceName}\"", CreateNoWindow = false, Verb = "runas" } @@ -224,7 +230,12 @@ private static void CreateValidCertificate(string script) proc.Kill(); // allow async output to process proc.WaitForExit(2000); - throw new Exception($"Could not generate certificate.Error out put: {output}"); + throw new Exception($"Could not generate certificate. Error output: {output}"); + } + + if (proc.ExitCode != 0) + { + throw new Exception($"Certificate generation script failed with exit code {proc.ExitCode}. Output: {output}"); } } else @@ -252,6 +263,11 @@ private static string GetLocalIpAddress() private void RemoveCertificate() { + if (string.IsNullOrEmpty(_thumbprint)) + { + return; + } + using X509Store certStore = new(StoreName.Root, StoreLocation.LocalMachine); certStore.Open(OpenFlags.ReadWrite); X509Certificate2Collection certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, _thumbprint, false); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs index 98ce8efafa..bb1bb8b5f2 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs @@ -89,7 +89,7 @@ private static string ForceEncryptionRegistryPath { return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib"; } - if (DataTestUtility.IsSQL2016()) + if (DataTestUtility.IsSQL2017()) { return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib"; } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs index 5ba727be5d..f55e512a35 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs @@ -61,9 +61,9 @@ public static void Test(string srcConstr, string dstConstr, string dstTable) DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersReceived"], "Unexpected BuffersReceived value."); DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersSent"], "Unexpected BuffersSent value."); DataTestUtility.AssertEqualsWithDescription((long)0, stats["IduCount"], "Unexpected IduCount value."); - DataTestUtility.AssertEqualsWithDescription((long)6, stats["SelectCount"], "Unexpected SelectCount value."); + DataTestUtility.AssertEqualsWithDescription((long)8, stats["SelectCount"], "Unexpected SelectCount value."); DataTestUtility.AssertEqualsWithDescription((long)3, stats["ServerRoundtrips"], "Unexpected ServerRoundtrips value."); - DataTestUtility.AssertEqualsWithDescription((long)9, stats["SelectRows"], "Unexpected SelectRows value."); + DataTestUtility.AssertEqualsWithDescription((long)11, stats["SelectRows"], "Unexpected SelectRows value."); DataTestUtility.AssertEqualsWithDescription((long)2, stats["SumResultSets"], "Unexpected SumResultSets value."); DataTestUtility.AssertEqualsWithDescription((long)0, stats["Transactions"], "Unexpected Transactions value."); } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs index d83693080f..837ce96e76 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs @@ -11,7 +11,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SqlBulkCopyTests { public class SqlGraphTables { - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse))] + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.IsAtLeastSQL2017))] public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds() { string connectionString = DataTestUtility.TCPConnectionString; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/TestBulkCopyWithUTF8.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/TestBulkCopyWithUTF8.cs index 5b7112476d..da5287e831 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/TestBulkCopyWithUTF8.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/TestBulkCopyWithUTF8.cs @@ -26,6 +26,13 @@ public sealed class TestBulkCopyWithUtf8 : IDisposable /// public TestBulkCopyWithUtf8() { + // xUnit instantiates the class even when ConditionalTheory conditions cause the + // test to be skipped, so we must guard setup that requires UTF-8 collations. + if (!DataTestUtility.IsAtLeastSQL2019()) + { + return; + } + using SqlConnection sourceConnection = new SqlConnection(GetConnectionString(true)); sourceConnection.Open(); SetupTables(sourceConnection, s_sourceTable, s_destinationTable, s_insertQuery); @@ -36,6 +43,12 @@ public TestBulkCopyWithUtf8() /// public void Dispose() { + // Guard matches the constructor: no tables were created on older SQL versions. + if (!DataTestUtility.IsAtLeastSQL2019()) + { + return; + } + using SqlConnection connection = new SqlConnection(GetConnectionString(true)); connection.Open(); DataTestUtility.DropTable(connection, s_sourceTable); @@ -56,7 +69,7 @@ private string GetConnectionString(bool enableMars) /// /// Creates source and destination tables with a varchar(max) column with a collation setting - /// that stores the data in UTF8 encoding and inserts the data in the source table. + /// that stores the data in UTF8 encoding and inserts the data in the source table. /// private void SetupTables(SqlConnection connection, string sourceTable, string destinationTable, string insertQuery) { @@ -75,7 +88,8 @@ private void SetupTables(SqlConnection connection, string sourceTable, string de [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer), - nameof(DataTestUtility.IsNotAzureSynapse))] + nameof(DataTestUtility.IsNotAzureSynapse), + nameof(DataTestUtility.IsAtLeastSQL2019))] [InlineData(true, true)] [InlineData(false, true)] [InlineData(true, false)] @@ -139,7 +153,8 @@ public void BulkCopy_Utf8Data_ShouldMatchSource(bool isMarsEnabled, bool enableS [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer), - nameof(DataTestUtility.IsNotAzureSynapse))] + nameof(DataTestUtility.IsNotAzureSynapse), + nameof(DataTestUtility.IsAtLeastSQL2019))] [InlineData(true, true)] [InlineData(false, true)] [InlineData(true, false)]