From 4a3a38a6ab5ca78f35c2e85600152fdda499e082 Mon Sep 17 00:00:00 2001 From: bezu Date: Tue, 15 Jul 2025 23:00:59 +0300 Subject: [PATCH] Add support for declaration patterns in switch expressions by converting them to GetType() comparisons --- .../ExpressionSyntaxRewriter.cs | 61 ++++++++++++++++++- .../VariableReplacementRewriter.cs | 41 +++++++++++++ ...xpressionWithConstantPattern.verified.txt} | 0 ...itchExpressionWithTypePattern.verified.txt | 15 +++++ .../ProjectionExpressionGeneratorTests.cs | 48 ++++++++++++++- 5 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs rename tests/EntityFrameworkCore.Projectables.Generator.Tests/{ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt => ProjectionExpressionGeneratorTests.SwitchExpressionWithConstantPattern.verified.txt} (100%) create mode 100644 tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt diff --git a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs index 5c845fa..6340d90 100644 --- a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs +++ b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs @@ -190,7 +190,46 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition continue; } - throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values"); + if (arm.Pattern is DeclarationPatternSyntax declaration) + { + var getTypeExpression = SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + (ExpressionSyntax)Visit(node.GoverningExpression), + SyntaxFactory.IdentifierName("GetType") + ); + + var getTypeCall = SyntaxFactory.InvocationExpression(getTypeExpression); + var typeofExpression = SyntaxFactory.TypeOfExpression(declaration.Type); + var equalsExpression = SyntaxFactory.BinaryExpression( + SyntaxKind.EqualsExpression, + getTypeCall, + typeofExpression + ); + + ExpressionSyntax condition = equalsExpression; + if (arm.WhenClause != null) + { + condition = SyntaxFactory.BinaryExpression( + SyntaxKind.LogicalAndExpression, + equalsExpression, + (ExpressionSyntax)Visit(arm.WhenClause.Condition) + ); + } + + var modifiedArmExpression = ReplaceVariableWithCast(armExpression, declaration, node.GoverningExpression); + currentExpression = SyntaxFactory.ConditionalExpression( + condition, + modifiedArmExpression, + currentExpression + ); + + continue; + } + + throw new InvalidOperationException( + $"Switch expressions rewriting supports only constant values and declaration patterns (Type var). " + + $"Unsupported pattern: {arm.Pattern.GetType().Name}" + ); } return currentExpression; @@ -346,5 +385,25 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition return base.VisitNullableType(node); } + + private ExpressionSyntax ReplaceVariableWithCast(ExpressionSyntax expression, DeclarationPatternSyntax declaration, ExpressionSyntax governingExpression) + { + if (declaration.Designation is SingleVariableDesignationSyntax variableDesignation) + { + var variableName = variableDesignation.Identifier.ValueText; + + var castExpression = SyntaxFactory.ParenthesizedExpression( + SyntaxFactory.CastExpression( + declaration.Type, + (ExpressionSyntax)Visit(governingExpression) + ) + ); + + var rewriter = new VariableReplacementRewriter(variableName, castExpression); + return (ExpressionSyntax)rewriter.Visit(expression); + } + + return expression; + } } } diff --git a/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs new file mode 100644 index 0000000..578aa74 --- /dev/null +++ b/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs @@ -0,0 +1,41 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +namespace EntityFrameworkCore.Projectables.Generator; + +public class VariableReplacementRewriter : CSharpSyntaxRewriter +{ + private readonly string _variableName; + private readonly ExpressionSyntax _replacement; + + public VariableReplacementRewriter(string variableName, ExpressionSyntax replacement) + { + _variableName = variableName; + _replacement = replacement; + } + + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) + { + if (node.Identifier.ValueText == _variableName) + { + return _replacement; + } + + return base.VisitIdentifierName(node); + } + + public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node) + { + if (node.Expression is IdentifierNameSyntax identifier && + identifier.Identifier.ValueText == _variableName) + { + return SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + _replacement, + node.Name + ); + } + + return base.VisitMemberAccessExpression(node); + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithConstantPattern.verified.txt similarity index 100% rename from tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt rename to tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithConstantPattern.verified.txt diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt new file mode 100644 index 0000000..dcb1ec1 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt @@ -0,0 +1,15 @@ +// +#nullable disable +using EntityFrameworkCore.Projectables; + +namespace EntityFrameworkCore.Projectables.Generated +{ + [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + static class _ItemMapper_ToData + { + static global::System.Linq.Expressions.Expression> Expression() + { + return (global::Item item) => item.GetType() == typeof(GroupItem) ? new global::GroupData(((GroupItem)item).Id, ((GroupItem)item).Name, ((GroupItem)item).Description) : item.GetType() == typeof(DocumentItem) ? new global::DocumentData(((DocumentItem)item).Id, ((DocumentItem)item).Name, ((DocumentItem)item).Priority) : null !; + } + } +} \ No newline at end of file diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs index a982c79..2a74263 100644 --- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs @@ -1785,7 +1785,7 @@ class Foo { } [Fact] - public Task SwitchExpression() + public Task SwitchExpressionWithConstantPattern() { var compilation = CreateCompilation(@" using EntityFrameworkCore.Projectables; @@ -1811,6 +1811,52 @@ class Foo { return Verifier.Verify(result.GeneratedTrees[0].ToString()); } + [Fact] + public Task SwitchExpressionWithTypePattern() + { + var compilation = CreateCompilation(@" +using EntityFrameworkCore.Projectables; + +public abstract class Item +{ + public int Id { get; set; } + public string Name { get; set; } +} + +public class GroupItem : Item +{ + public string Description { get; set; } +} + +public class DocumentItem : Item +{ + public int Priority { get; set; } +} + +public abstract record ItemData(int Id, string Name); +public record GroupData(int Id, string Name, string Description) : ItemData(Id, Name); +public record DocumentData(int Id, string Name, int Priority) : ItemData(Id, Name); + +public static class ItemMapper +{ + [Projectable] + public static ItemData ToData(this Item item) => + item switch + { + GroupItem groupItem => new GroupData(groupItem.Id, groupItem.Name, groupItem.Description), + DocumentItem documentItem => new DocumentData(documentItem.Id, documentItem.Name, documentItem.Priority), + _ => null! + }; +} +"); + + var result = RunGenerator(compilation); + + Assert.Empty(result.Diagnostics); + + return Verifier.Verify(result.GeneratedTrees[0].ToString()); + } + [Fact] public Task GenericTypes() {