diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs index 06f9e9b9201..c78bfee6de0 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs @@ -93,7 +93,9 @@ public class SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) nameof(Math.Sign) when arguments is [var arg] && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(sbyte) || arg.Type == typeof(short)) - => TranslateFunction("SIGN", arg, nullTypeMapping: true), + // T-SQL SIGN returns the same type as its input, but Math.Sign always returns int; + // wrap with a CAST to avoid InvalidCastException at materialization time. + => TranslateSign(arg), nameof(double.DegreesToRadians) when arguments is [var arg] && (arg.Type == typeof(double) || arg.Type == typeof(float)) => TranslateFunction("RADIANS", arg), @@ -115,7 +117,27 @@ public class SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) _ => null }; - SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg, bool nullTypeMapping = false) + SqlExpression TranslateSign(SqlExpression arg) + { + arg = sqlExpressionFactory.ApplyDefaultTypeMapping(arg)!; + + // T-SQL SIGN() returns the same type as its argument, but the CLR Math.Sign() always returns int. + // The function node must carry the argument's type mapping so that the CAST to int is not elided + // as a same-type no-op by SqlExpressionSimplifyingExpressionVisitor. + var signFunction = sqlExpressionFactory.Function( + "SIGN", + [arg], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + arg.Type, + arg.TypeMapping); + + return arg.Type == typeof(int) + ? signFunction + : sqlExpressionFactory.Convert(signFunction, typeof(int)); + } + + SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg) { var typeMapping = ExpressionExtensions.InferTypeMapping(arg); return sqlExpressionFactory.Function( @@ -124,7 +146,7 @@ SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg, bool nullable: true, argumentsPropagateNullability: Statics.TrueArrays[1], method.ReturnType, - nullTypeMapping ? null : typeMapping); + typeMapping); } SqlExpression TranslateBinaryFunction(string sqlFunctionName, SqlExpression arg1, SqlExpression arg2) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs index f65838697ca..0da472decac 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs @@ -444,6 +444,45 @@ public override async Task Sign() SELECT VALUE c FROM root c WHERE (SIGN(c["Double"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Double"]) +FROM root c +"""); + } + + public override async Task Sign_decimal() + { + await base.Sign_decimal(); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (SIGN(c["Decimal"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Decimal"]) +FROM root c +"""); + } + + public override async Task Sign_int() + { + await base.Sign_int(); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (SIGN(c["Int"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Int"]) +FROM root c """); } @@ -456,6 +495,11 @@ public override async Task Sign_float() SELECT VALUE c FROM root c WHERE (SIGN(c["Float"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Float"]) +FROM root c """); } diff --git a/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs b/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs index f6b9e993b27..f9734e32594 100644 --- a/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs @@ -186,12 +186,36 @@ public virtual Task Sqrt_float() => AssertQuery(ss => ss.Set().Where(b => b.Float > 0 && MathF.Sqrt(b.Float) > 0)); [ConditionalFact] - public virtual Task Sign() - => AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Double) > 0)); + public virtual async Task Sign() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Double) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Double))); + } + + [ConditionalFact] + public virtual async Task Sign_decimal() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Decimal) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Decimal))); + } [ConditionalFact] - public virtual Task Sign_float() - => AssertQuery(ss => ss.Set().Where(b => MathF.Sign(b.Float) > 0)); + public virtual async Task Sign_int() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Int) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Int))); + } + + [ConditionalFact] + public virtual async Task Sign_float() + { + await AssertQuery(ss => ss.Set().Where(b => MathF.Sign(b.Float) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => MathF.Sign(b.Float))); + } [ConditionalFact] public virtual Task Max() diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs index 3a848d36b84..150ba2c1d5d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs @@ -461,7 +461,46 @@ public override async Task Sign() """ SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] FROM [BasicTypesEntities] AS [b] -WHERE SIGN([b].[Double]) > 0 +WHERE CAST(SIGN([b].[Double]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Double]) AS int) +FROM [BasicTypesEntities] AS [b] +"""); + } + + public override async Task Sign_decimal() + { + await base.Sign_decimal(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE CAST(SIGN([b].[Decimal]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Decimal]) AS int) +FROM [BasicTypesEntities] AS [b] +"""); + } + + public override async Task Sign_int() + { + await base.Sign_int(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE SIGN([b].[Int]) > 0 +""", + // + """ +SELECT SIGN([b].[Int]) +FROM [BasicTypesEntities] AS [b] """); } @@ -473,7 +512,12 @@ public override async Task Sign_float() """ SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] FROM [BasicTypesEntities] AS [b] -WHERE SIGN([b].[Float]) > 0 +WHERE CAST(SIGN([b].[Float]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Float]) AS int) +FROM [BasicTypesEntities] AS [b] """); } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs index 1c96938b8de..0fde43e7ad0 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs @@ -416,9 +416,20 @@ public override async Task Sign() SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" FROM "BasicTypesEntities" AS "b" WHERE sign("b"."Double") > 0.0 +""", + // + """ +SELECT sign("b"."Double") +FROM "BasicTypesEntities" AS "b" """); } + public override async Task Sign_decimal() + => await AssertTranslationFailed(() => base.Sign_decimal()); // SQLite decimal support + + public override async Task Sign_int() + => await AssertTranslationFailed(() => base.Sign_int()); // SQLite int support + public override async Task Sign_float() { await base.Sign_float(); @@ -428,6 +439,11 @@ public override async Task Sign_float() SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" FROM "BasicTypesEntities" AS "b" WHERE sign("b"."Float") > 0 +""", + // + """ +SELECT sign("b"."Float") +FROM "BasicTypesEntities" AS "b" """); }