Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -472,29 +472,51 @@ private string CreateInitialQuery()
}
else if (!string.IsNullOrEmpty(CatalogName))
{
CatalogName = SqlServerEscapeHelper.EscapeIdentifier(CatalogName);
CatalogName = SqlServerEscapeHelper.EscapeStringAsLiteral(SqlServerEscapeHelper.EscapeIdentifier(CatalogName));
}

string objectName = ADP.BuildMultiPartName(parts);
string escapedObjectName = SqlServerEscapeHelper.EscapeStringAsLiteral(objectName);
// Specify the column names explicitly. This is to ensure that we can map to hidden columns (e.g. columns in temporal tables.)
// If the target table doesn't exist, OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent SELECT *
// query will then continue to fail with "Invalid object name" rather than with an unusual error because the query being executed
// is NULL.
// Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to exclude them explicitly.
// Specify the column names explicitly. This is to ensure that we can map to hidden
// columns (e.g. columns in temporal tables.) If the target table doesn't exist,
// OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent
// SELECT * query will then continue to fail with "Invalid object name" rather than with
// an unusual error because the query being executed is NULL.
//
// Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to
// exclude them explicitly. The graph_type values excluded below are internal graph
// columns that cannot be selected directly:
//
// 1 = GRAPH_ID
// 3 = GRAPH_FROM_ID
// 4 = GRAPH_FROM_OBJ_ID
// 6 = GRAPH_TO_ID
// 7 = GRAPH_TO_OBJ_ID
//
// See: https://learn.microsoft.com/sql/relational-databases/graphs/sql-graph-architecture#syscolumns
//
// The column-name query is built as dynamic SQL and executed via sp_executesql so
// that it is not compiled (and rejected) on SQL Server versions that lack the
// graph_type column (e.g. SQL 2016). CatalogName and escapedObjectName are
// interpolated directly into the SQL string because SQL Server does not allow
// identifiers (database/schema/table names) to be passed as parameters. Both
// values are escaped via SqlServerEscapeHelper before interpolation.
return $"""
SELECT @@TRANCOUNT;

DECLARE @Object_ID INT = OBJECT_ID('{escapedObjectName}');
DECLARE @Column_Name_Query NVARCHAR(MAX);
DECLARE @Column_Names NVARCHAR(MAX) = NULL;
IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type')
BEGIN
SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC;
SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC;';
END
ELSE
BEGIN
SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') ORDER BY [column_id] ASC;
SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID ORDER BY [column_id] ASC;';
END

EXEC sp_executesql @Column_Name_Query, N'@Object_ID INT, @Column_Names NVARCHAR(MAX) OUTPUT', @Object_ID = @Object_ID, @Column_Names = @Column_Names OUTPUT;
SELECT @Column_Names = COALESCE(@Column_Names, '*');

SET FMTONLY ON;
Expand Down Expand Up @@ -624,7 +646,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i

bool matched = false;
bool rejected = false;

// Look for a local match for the remote column.
for (int j = 0; j < _localColumnMappings.Count; ++j)
{
Expand All @@ -644,7 +666,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i

// Remove it from our unmatched set.
unmatchedColumns.Remove(localColumn.DestinationColumn);

// Check for column types that we refuse to bulk load, even
// though we found a match.
//
Expand Down Expand Up @@ -1437,7 +1459,7 @@ private void RunParserReliably(BulkCopySimpleResultSet bulkCopyHandler = null)
try
{
// @TODO: CER Exception Handling was removed here (see GH#3581)
_parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj);
_parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj);
}
finally
{
Expand Down Expand Up @@ -1760,7 +1782,7 @@ public void WriteToServer(DbDataReader reader)
try
{
statistics = SqlStatistics.StartTimer(Statistics);

ResetWriteToServerGlobalVariables();
_rowSource = reader;
_dbDataReaderRowSource = reader;
Expand Down Expand Up @@ -1796,13 +1818,13 @@ public void WriteToServer(IDataReader reader)
try
{
statistics = SqlStatistics.StartTimer(Statistics);

ResetWriteToServerGlobalVariables();
_rowSource = reader;
_sqlDataReaderRowSource = _rowSource as SqlDataReader;
_dbDataReaderRowSource = _rowSource as DbDataReader;
_rowSourceType = ValueSourceType.IDataReader;

WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
}
finally
Expand Down Expand Up @@ -1918,7 +1940,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok
try
{
statistics = SqlStatistics.StartTimer(Statistics);

ResetWriteToServerGlobalVariables();
if (rows.Length == 0)
{
Expand All @@ -1935,9 +1957,9 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok
_rowSourceType = ValueSourceType.RowArray;
_rowEnumerator = rows.GetEnumerator();
_isAsyncBulkCopy = true;

// It returns Task since _isAsyncBulkCopy = true;
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
}
finally
{
Expand All @@ -1964,19 +1986,19 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati
{
throw SQL.BulkLoadPendingOperation();
}

SqlStatistics statistics = Statistics;
try
{
statistics = SqlStatistics.StartTimer(Statistics);

ResetWriteToServerGlobalVariables();
_rowSource = reader;
_sqlDataReaderRowSource = reader as SqlDataReader;
_dbDataReaderRowSource = reader;
_rowSourceType = ValueSourceType.DbDataReader;
_isAsyncBulkCopy = true;

// It returns Task since _isAsyncBulkCopy = true;
return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken);
}
Expand Down Expand Up @@ -2016,7 +2038,7 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio
_dbDataReaderRowSource = _rowSource as DbDataReader;
_rowSourceType = ValueSourceType.IDataReader;
_isAsyncBulkCopy = true;

// It returns Task since _isAsyncBulkCopy = true;
return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken);
}
Expand Down Expand Up @@ -2056,15 +2078,15 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat
try
{
statistics = SqlStatistics.StartTimer(Statistics);

ResetWriteToServerGlobalVariables();
_rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted;
_rowSource = table;
_dataTableSource = table;
_rowSourceType = ValueSourceType.DataTable;
_rowEnumerator = table.Rows.GetEnumerator();
_isAsyncBulkCopy = true;

// It returns Task since _isAsyncBulkCopy = true;
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
}
Expand Down Expand Up @@ -2114,7 +2136,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok

bool finishedSynchronously = true;
_isBulkCopyingInProgress = true;

CreateOrValidateConnection(nameof(WriteToServer));

SqlConnectionInternal internalConnection = _connection.GetOpenTdsConnection();
Expand Down Expand Up @@ -3065,11 +3087,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio

// No need to cancel timer since SqlBulkCopy creates specific task source for reconnection.
AsyncHelper.SetTimeoutExceptionWithState(
completion: cancellableReconnectTS,
completion: cancellableReconnectTS,
timeout: BulkCopyTimeout,
state: _destinationTableName,
onFailure: static state =>
SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()),
onFailure: static state =>
SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()),
cancellationToken: CancellationToken.None
);

Expand Down Expand Up @@ -3242,7 +3264,7 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken)
}
return resultTask;
}

private void ResetWriteToServerGlobalVariables()
{
_dataTableSource = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static class DataTestUtility
internal static readonly string KerberosDomainPassword = null;

// SQL server Version
private static string s_sQLServerVersion = string.Empty;
private static string s_sqlServerVersion;

//SQL Server EngineEdition
private static string s_sqlServerEngineEdition;
Expand Down Expand Up @@ -125,9 +125,9 @@ public static string SQLServerVersion
{
if (!string.IsNullOrEmpty(TCPConnectionString))
{
s_sQLServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion);
s_sqlServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion);
}
return s_sQLServerVersion;
return s_sqlServerVersion;
}
}

Expand Down Expand Up @@ -491,7 +491,14 @@ public static bool AreConnStringsSetup()

public static bool IsSQL2019() => string.Equals("15", SQLServerVersion.Trim());

public static bool IsSQL2016() => string.Equals("14", s_sQLServerVersion.Trim());
public static bool IsSQL2017() => string.Equals("14", SQLServerVersion.Trim());

public static bool IsSQL2016() => string.Equals("13", SQLServerVersion.Trim());

// "At least" version checks for use as ConditionalFact/ConditionalTheory conditions.
public static bool IsAtLeastSQL2017() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 14;

public static bool IsAtLeastSQL2019() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 15;

public static bool IsSQLAliasSetup()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class CertificateTest : IDisposable

// InstanceName will get replaced with an instance name in the connection string
private static string InstanceName = "MSSQLSERVER";

// s_instanceNamePrefix will get replaced with MSSQL$ is there is an instance name in connection string
private static string InstanceNamePrefix = "";

Expand All @@ -51,10 +51,14 @@ private static string ForceEncryptionRegistryPath
{
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
}
if (DataTestUtility.IsSQL2016())
if (DataTestUtility.IsSQL2017())
{
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
}
if (DataTestUtility.IsSQL2016())
{
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL13.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
}
return string.Empty;
}
}
Expand Down Expand Up @@ -196,7 +200,9 @@ private static void CreateValidCertificate(string script)
RedirectStandardError = true,
RedirectStandardOutput = true,
UseShellExecute = false,
Arguments = $"{script} -Prefix {InstanceNamePrefix} -Instance {InstanceName}",
Arguments = string.IsNullOrEmpty(InstanceNamePrefix)
? $"{script} -Instance \"{InstanceName}\""
: $"{script} -Prefix \"{InstanceNamePrefix}\" -Instance \"{InstanceName}\"",
CreateNoWindow = false,
Verb = "runas"
}
Expand Down Expand Up @@ -224,7 +230,12 @@ private static void CreateValidCertificate(string script)
proc.Kill();
// allow async output to process
proc.WaitForExit(2000);
throw new Exception($"Could not generate certificate.Error out put: {output}");
throw new Exception($"Could not generate certificate. Error output: {output}");
}

if (proc.ExitCode != 0)
{
throw new Exception($"Certificate generation script failed with exit code {proc.ExitCode}. Output: {output}");
}
}
else
Expand Down Expand Up @@ -252,6 +263,11 @@ private static string GetLocalIpAddress()

private void RemoveCertificate()
{
if (string.IsNullOrEmpty(_thumbprint))
{
return;
}

using X509Store certStore = new(StoreName.Root, StoreLocation.LocalMachine);
certStore.Open(OpenFlags.ReadWrite);
X509Certificate2Collection certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, _thumbprint, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ private static string ForceEncryptionRegistryPath
{
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib";
}
if (DataTestUtility.IsSQL2016())
if (DataTestUtility.IsSQL2017())
{
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ public static void Test(string srcConstr, string dstConstr, string dstTable)
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersReceived"], "Unexpected BuffersReceived value.");
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersSent"], "Unexpected BuffersSent value.");
DataTestUtility.AssertEqualsWithDescription((long)0, stats["IduCount"], "Unexpected IduCount value.");
DataTestUtility.AssertEqualsWithDescription((long)6, stats["SelectCount"], "Unexpected SelectCount value.");
DataTestUtility.AssertEqualsWithDescription((long)8, stats["SelectCount"], "Unexpected SelectCount value.");
DataTestUtility.AssertEqualsWithDescription((long)3, stats["ServerRoundtrips"], "Unexpected ServerRoundtrips value.");
DataTestUtility.AssertEqualsWithDescription((long)9, stats["SelectRows"], "Unexpected SelectRows value.");
DataTestUtility.AssertEqualsWithDescription((long)11, stats["SelectRows"], "Unexpected SelectRows value.");
DataTestUtility.AssertEqualsWithDescription((long)2, stats["SumResultSets"], "Unexpected SumResultSets value.");
DataTestUtility.AssertEqualsWithDescription((long)0, stats["Transactions"], "Unexpected Transactions value.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SqlBulkCopyTests
{
public class SqlGraphTables
{
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.IsAtLeastSQL2017))]
public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds()
{
string connectionString = DataTestUtility.TCPConnectionString;
Expand Down
Loading
Loading