diff --git a/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs b/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs index f70560be9e..2836156b65 100644 --- a/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs +++ b/src/Microsoft.Data.SqlClient/tests/Common/Fixtures/AzureKeyVaultKeyFixtureBase.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Security.Cryptography; using Azure.Core; using Azure.Security.KeyVault.Keys; @@ -20,20 +21,41 @@ namespace Microsoft.Data.SqlClient.Tests.Common.Fixtures; public abstract class AzureKeyVaultKeyFixtureBase : IDisposable { private readonly KeyClient _keyClient; - private readonly Random _randomGenerator; + private readonly RandomNumberGenerator _randomGenerator; private readonly List _createdKeys = new List(); protected AzureKeyVaultKeyFixtureBase(Uri keyVaultUri, TokenCredential keyVaultToken) { _keyClient = new KeyClient(keyVaultUri, keyVaultToken); - _randomGenerator = new Random(); + _randomGenerator = RandomNumberGenerator.Create(); } protected Uri CreateKey(string name, int keySize) { - CreateRsaKeyOptions createOptions = new CreateRsaKeyOptions(GenerateUniqueName(name)) { KeySize = keySize }; - KeyVaultKey created = _keyClient.CreateRsaKey(createOptions); + const int MaxConflictResolutions = 5; + KeyVaultKey created; + int i = 0; + + while (true) + { + CreateRsaKeyOptions createOptions = new CreateRsaKeyOptions(GenerateUniqueName(name)) { KeySize = keySize }; + + try + { + created = _keyClient.CreateRsaKey(createOptions); + break; + } + // It's possible for a key to already exist with the same name, even in a deleted state. If so, CreateRsaKey + // will throw an exception with HTTP status code 409 (Conflict.) + // We can't assume we possess permissions to purge or to recover the key, so regenerate the name and try again. + // Only make MaxConflictResolutions attempts, to avoid possible infinite loops. + catch (Azure.RequestFailedException conflictException) + when (conflictException.Status == 409 && i < MaxConflictResolutions) + { + i++; + } + } _createdKeys.Add(created); return created.Id; @@ -43,7 +65,7 @@ private string GenerateUniqueName(string name) { byte[] rndBytes = new byte[16]; - _randomGenerator.NextBytes(rndBytes); + _randomGenerator.GetBytes(rndBytes); return name + "-" + BitConverter.ToString(rndBytes); } @@ -66,5 +88,7 @@ protected virtual void Dispose(bool disposing) continue; } } + + _randomGenerator.Dispose(); } } diff --git a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs index 11523b3f83..ec67807720 100644 --- a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs @@ -204,7 +204,8 @@ public void Dispose() #region Switch Value Getters and Setters - // These properties get or set the like-named underlying switch field value. + // These properties get the like-named underlying switch *property* value and set the underlying + // switch *field* value. This allows tests to verify the default switch values. // // They all throw if the value cannot be retrieved or set. @@ -214,7 +215,7 @@ public void Dispose() /// public bool? DisableTnirByDefault { - get => GetSwitchValue("s_disableTnirByDefault"); + get => GetSwitchPropertyValue(nameof(DisableTnirByDefault)); set => SetSwitchValue("s_disableTnirByDefault", value); } #endif @@ -224,7 +225,7 @@ public bool? DisableTnirByDefault /// public bool? EnableMultiSubnetFailoverByDefault { - get => GetSwitchValue("s_enableMultiSubnetFailoverByDefault"); + get => GetSwitchPropertyValue(nameof(EnableMultiSubnetFailoverByDefault)); set => SetSwitchValue("s_enableMultiSubnetFailoverByDefault", value); } @@ -234,7 +235,7 @@ public bool? EnableMultiSubnetFailoverByDefault /// public bool? GlobalizationInvariantMode { - get => GetSwitchValue("s_globalizationInvariantMode"); + get => GetSwitchPropertyValue(nameof(GlobalizationInvariantMode)); set => SetSwitchValue("s_globalizationInvariantMode", value); } #endif @@ -244,7 +245,7 @@ public bool? GlobalizationInvariantMode /// public bool? IgnoreServerProvidedFailoverPartner { - get => GetSwitchValue("s_ignoreServerProvidedFailoverPartner"); + get => GetSwitchPropertyValue(nameof(IgnoreServerProvidedFailoverPartner)); set => SetSwitchValue("s_ignoreServerProvidedFailoverPartner", value); } @@ -253,7 +254,7 @@ public bool? IgnoreServerProvidedFailoverPartner /// public bool? UseLegacyFailoverAlternationOnLoginSqlErrors { - get => GetSwitchValue("s_useLegacyFailoverAlternationOnLoginSqlErrors"); + get => GetSwitchPropertyValue(nameof(UseLegacyFailoverAlternationOnLoginSqlErrors)); set => SetSwitchValue("s_useLegacyFailoverAlternationOnLoginSqlErrors", value); } @@ -262,7 +263,7 @@ public bool? UseLegacyFailoverAlternationOnLoginSqlErrors /// public bool? LegacyRowVersionNullBehavior { - get => GetSwitchValue("s_legacyRowVersionNullBehavior"); + get => GetSwitchPropertyValue(nameof(LegacyRowVersionNullBehavior)); set => SetSwitchValue("s_legacyRowVersionNullBehavior", value); } @@ -271,7 +272,7 @@ public bool? LegacyRowVersionNullBehavior /// public bool? LegacyVarTimeZeroScaleBehaviour { - get => GetSwitchValue("s_legacyVarTimeZeroScaleBehaviour"); + get => GetSwitchPropertyValue(nameof(LegacyVarTimeZeroScaleBehaviour)); set => SetSwitchValue("s_legacyVarTimeZeroScaleBehaviour", value); } @@ -280,7 +281,7 @@ public bool? LegacyVarTimeZeroScaleBehaviour /// public bool? MakeReadAsyncBlocking { - get => GetSwitchValue("s_makeReadAsyncBlocking"); + get => GetSwitchPropertyValue(nameof(MakeReadAsyncBlocking)); set => SetSwitchValue("s_makeReadAsyncBlocking", value); } @@ -289,7 +290,7 @@ public bool? MakeReadAsyncBlocking /// public bool? SuppressInsecureTlsWarning { - get => GetSwitchValue("s_suppressInsecureTlsWarning"); + get => GetSwitchPropertyValue(nameof(SuppressInsecureTlsWarning)); set => SetSwitchValue("s_suppressInsecureTlsWarning", value); } @@ -298,7 +299,7 @@ public bool? SuppressInsecureTlsWarning /// public bool? TruncateScaledDecimal { - get => GetSwitchValue("s_truncateScaledDecimal"); + get => GetSwitchPropertyValue(nameof(TruncateScaledDecimal)); set => SetSwitchValue("s_truncateScaledDecimal", value); } @@ -307,7 +308,7 @@ public bool? TruncateScaledDecimal /// public bool? UseCompatibilityAsyncBehaviour { - get => GetSwitchValue("s_useCompatibilityAsyncBehaviour"); + get => GetSwitchPropertyValue(nameof(UseCompatibilityAsyncBehaviour)); set => SetSwitchValue("s_useCompatibilityAsyncBehaviour", value); } @@ -316,7 +317,7 @@ public bool? UseCompatibilityAsyncBehaviour /// public bool? UseCompatibilityProcessSni { - get => GetSwitchValue("s_useCompatibilityProcessSni"); + get => GetSwitchPropertyValue(nameof(UseCompatibilityProcessSni)); set => SetSwitchValue("s_useCompatibilityProcessSni", value); } @@ -325,7 +326,7 @@ public bool? UseCompatibilityProcessSni /// public bool? UseConnectionPoolV2 { - get => GetSwitchValue("s_useConnectionPoolV2"); + get => GetSwitchPropertyValue(nameof(UseConnectionPoolV2)); set => SetSwitchValue("s_useConnectionPoolV2", value); } @@ -335,7 +336,7 @@ public bool? UseConnectionPoolV2 /// public bool? UseManagedNetworking { - get => GetSwitchValue("s_useManagedNetworking"); + get => GetSwitchPropertyValue(nameof(UseManagedNetworking)); set => SetSwitchValue("s_useManagedNetworking", value); } #endif @@ -345,7 +346,7 @@ public bool? UseManagedNetworking /// public bool? UseMinimumLoginTimeout { - get => GetSwitchValue("s_useMinimumLoginTimeout"); + get => GetSwitchPropertyValue(nameof(UseMinimumLoginTimeout)); set => SetSwitchValue("s_useMinimumLoginTimeout", value); } @@ -358,19 +359,7 @@ public bool? UseMinimumLoginTimeout /// private static bool? GetSwitchValue(string fieldName) { - var assembly = Assembly.GetAssembly(typeof(SqlConnection)); - if (assembly is null) - { - throw new InvalidOperationException( - "Could not get assembly for Microsoft.Data.SqlClient"); - } - - var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - if (type is null) - { - throw new InvalidOperationException( - "Could not get type LocalAppContextSwitches"); - } + var type = GetLocalAppContextSwitchesType(); var field = type.GetField( fieldName, @@ -405,19 +394,7 @@ public bool? UseMinimumLoginTimeout /// private static void SetSwitchValue(string fieldName, bool? value) { - var assembly = Assembly.GetAssembly(typeof(SqlConnection)); - if (assembly is null) - { - throw new InvalidOperationException( - "Could not get assembly for Microsoft.Data.SqlClient"); - } - - var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); - if (type is null) - { - throw new InvalidOperationException( - "Could not get type LocalAppContextSwitches"); - } + var type = GetLocalAppContextSwitchesType(); var field = type.GetField( fieldName, @@ -442,5 +419,49 @@ private static void SetSwitchValue(string fieldName, bool? value) field.SetValue(null, Enum.ToObject(field.FieldType, byteValue)); } + /// + /// Use reflection to get a switch property value from LocalAppContextSwitches. + /// + /// + /// Each property in LocalAppContextSwitchHelper corresponds to a like-named property in + /// LocalAppContextSwitches, which may return a different value when the AppContext switch + /// has not been set. + /// + private static bool GetSwitchPropertyValue(string propertyName) + { + var type = GetLocalAppContextSwitchesType(); + var property = type.GetProperty( + propertyName, + BindingFlags.Static | BindingFlags.Public); + + if (property == null) + { + throw new InvalidOperationException( + $"Property '{propertyName}' not found in LocalAppContextSwitches"); + } + + object? value = property.GetValue(null); + + return value is bool boolValue + ? boolValue + : throw new InvalidOperationException($"Property '{propertyName}' is not of type bool."); + } + + private static Type GetLocalAppContextSwitchesType() + { + var assembly = Assembly.GetAssembly(typeof(SqlConnection)); + if (assembly is null) + { + throw new InvalidOperationException("Could not get assembly for Microsoft.Data.SqlClient"); + } + + var type = assembly.GetType("Microsoft.Data.SqlClient.LocalAppContextSwitches"); + if (type is null) + { + throw new InvalidOperationException("Could not get type LocalAppContextSwitches"); + } + return type; + } + #endregion } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs index 991d55cdeb..7c843ad889 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/Batch/BatchTests.cs @@ -7,6 +7,7 @@ using System.Data; using System.Data.Common; using System.Threading.Tasks; +using Microsoft.Data.SqlClient.Tests.Common.Fixtures.DatabaseObjects; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -377,9 +378,13 @@ public static void ExceptionWithoutBatchContainsNoBatch() [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public static void ParameterInOutAndReturn() { - string create = - @" -CREATE PROCEDURE TestInAndOutParams + SqlParameter input = CreateParameter("@Input", SqlDbType.Int, 2); + SqlParameter inputOutput = CreateParameter("@InOut", SqlDbType.Int, 4, ParameterDirection.InputOutput); + SqlParameter output = CreateParameter("@Output", SqlDbType.Int, DBNull.Value, ParameterDirection.Output); + SqlParameter returned = CreateParameter("@RETURN_VALUE", SqlDbType.Int, DBNull.Value, ParameterDirection.ReturnValue); + + using (SqlConnection conn = new(DataTestUtility.TCPConnectionString)) + using (StoredProcedure spTestInAndOutParams = new(conn, "TestInAndOutParams", @" @Input int, @InOut int OUTPUT, @Output int = default OUTPUT @@ -388,26 +393,14 @@ CREATE PROCEDURE TestInAndOutParams SET NOCOUNT ON; SELECT @InOut = 2 * @InOut, @Output = 2 * @Input RETURN @Input -END"; - string drop = "DROP PROCEDURE TestInAndOutParams"; - - SqlParameter input = CreateParameter("@Input", SqlDbType.Int, 2); - SqlParameter inputOutput = CreateParameter("@InOut", SqlDbType.Int, 4, ParameterDirection.InputOutput); - SqlParameter output = CreateParameter("@Output", SqlDbType.Int, DBNull.Value, ParameterDirection.Output); - SqlParameter returned = CreateParameter("@RETURN_VALUE", SqlDbType.Int, DBNull.Value, ParameterDirection.ReturnValue); - try +END")) { - TryExecuteNonQueryCommand(drop); - ExecuteNonQueryCommand(create); - - using (SqlConnection conn = new SqlConnection(DataTestUtility.TCPConnectionString)) using (SqlBatch batch = new SqlBatch(conn)) { - conn.Open(); batch.Commands.Add(new SqlBatchCommand("SELECT @@VERSION")); batch.Commands.Add( new SqlBatchCommand( - "TestInAndOutParams", + spTestInAndOutParams.Name, CommandType.StoredProcedure, new[] { input, inputOutput, output, returned } ) @@ -416,10 +409,6 @@ RETURN @Input batch.ExecuteNonQuery(); } } - finally - { - TryExecuteNonQueryCommand(drop); - } Assert.Equal(8, Convert.ToInt32(inputOutput.Value)); Assert.Equal(4, Convert.ToInt32(output.Value)); @@ -656,28 +645,5 @@ private static SqlParameter CreateParameter(string name, SqlDbType type, T va parameter.Value = value; return parameter; } - - private static void ExecuteNonQueryCommand(string command) - { - using (SqlConnection conn = new SqlConnection(DataTestUtility.TCPConnectionString)) - using (SqlCommand cmd = conn.CreateCommand()) - { - conn.Open(); - cmd.CommandText = command; - cmd.ExecuteNonQuery(); - } - } - private static bool TryExecuteNonQueryCommand(string command) - { - try - { - ExecuteNonQueryCommand(command); - return true; - } - catch - { - } - return false; - } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs index 552fd4cea5..f7f10c15e3 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlCommand/SqlCommandCancelTest.cs @@ -372,7 +372,9 @@ private static void ExecuteCommandCancelExpected(object state) string errorMessage = SystemDataResourceManager.Instance.SQL_OperationCancelled; string errorMessageSevereFailure = SystemDataResourceManager.Instance.SQL_SevereError; - DataTestUtility.ExpectFailure(() => + // This could fail with either a SqlException or an InvalidOperationException depending on timing, + // so we will accept either but require the message to match expected cancellation messages + DataTestUtility.ExpectFailure(() => { threadsReady.SignalAndWait(); using (SqlDataReader r = command.ExecuteReader()) @@ -384,7 +386,9 @@ private static void ExecuteCommandCancelExpected(object state) } } while (r.NextResult()); } - }, new string[] { errorMessage, errorMessageSevereFailure }); + }, + new string[] { errorMessage, errorMessageSevereFailure }, + customExceptionVerifier: (ex) => ex is SqlException or InvalidOperationException); } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs index a468d8fd37..a2fb1146b4 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/LocalAppContextSwitchesTest.cs @@ -1,8 +1,9 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; +using Microsoft.Data.SqlClient.Tests.Common; using Xunit; namespace Microsoft.Data.SqlClient.UnitTests; @@ -18,26 +19,28 @@ public class LocalAppContextSwitchesTest [Fact] public void TestDefaultAppContextSwitchValues() { - Assert.False(LocalAppContextSwitches.LegacyRowVersionNullBehavior); - Assert.False(LocalAppContextSwitches.SuppressInsecureTlsWarning); - Assert.False(LocalAppContextSwitches.MakeReadAsyncBlocking); - Assert.True(LocalAppContextSwitches.UseMinimumLoginTimeout); - Assert.True(LocalAppContextSwitches.LegacyVarTimeZeroScaleBehaviour); - Assert.True(LocalAppContextSwitches.UseCompatibilityProcessSni); - Assert.True(LocalAppContextSwitches.UseCompatibilityAsyncBehaviour); - Assert.False(LocalAppContextSwitches.UseConnectionPoolV2); - Assert.False(LocalAppContextSwitches.TruncateScaledDecimal); - Assert.False(LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner); - Assert.False(LocalAppContextSwitches.UseLegacyFailoverAlternationOnLoginSqlErrors); - Assert.False(LocalAppContextSwitches.EnableMultiSubnetFailoverByDefault); + using LocalAppContextSwitchesHelper appContextSwitchesHelper = new(); + + Assert.False(appContextSwitchesHelper.LegacyRowVersionNullBehavior); + Assert.False(appContextSwitchesHelper.SuppressInsecureTlsWarning); + Assert.False(appContextSwitchesHelper.MakeReadAsyncBlocking); + Assert.True(appContextSwitchesHelper.UseMinimumLoginTimeout); + Assert.True(appContextSwitchesHelper.LegacyVarTimeZeroScaleBehaviour); + Assert.True(appContextSwitchesHelper.UseCompatibilityProcessSni); + Assert.True(appContextSwitchesHelper.UseCompatibilityAsyncBehaviour); + Assert.False(appContextSwitchesHelper.UseConnectionPoolV2); + Assert.False(appContextSwitchesHelper.TruncateScaledDecimal); + Assert.False(appContextSwitchesHelper.IgnoreServerProvidedFailoverPartner); + Assert.False(appContextSwitchesHelper.UseLegacyFailoverAlternationOnLoginSqlErrors); + Assert.False(appContextSwitchesHelper.EnableMultiSubnetFailoverByDefault); #if NET - Assert.False(LocalAppContextSwitches.GlobalizationInvariantMode); + Assert.False(appContextSwitchesHelper.GlobalizationInvariantMode); #endif #if NET && _WINDOWS - Assert.False(LocalAppContextSwitches.UseManagedNetworking); + Assert.False(appContextSwitchesHelper.UseManagedNetworking); #endif #if NETFRAMEWORK - Assert.False(LocalAppContextSwitches.DisableTnirByDefault); + Assert.False(appContextSwitchesHelper.DisableTnirByDefault); #endif } }