diff --git a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
index 5c845fa..6340d90 100644
--- a/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
+++ b/src/EntityFrameworkCore.Projectables.Generator/ExpressionSyntaxRewriter.cs
@@ -190,7 +190,46 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition
continue;
}
- throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values");
+ if (arm.Pattern is DeclarationPatternSyntax declaration)
+ {
+ var getTypeExpression = SyntaxFactory.MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ (ExpressionSyntax)Visit(node.GoverningExpression),
+ SyntaxFactory.IdentifierName("GetType")
+ );
+
+ var getTypeCall = SyntaxFactory.InvocationExpression(getTypeExpression);
+ var typeofExpression = SyntaxFactory.TypeOfExpression(declaration.Type);
+ var equalsExpression = SyntaxFactory.BinaryExpression(
+ SyntaxKind.EqualsExpression,
+ getTypeCall,
+ typeofExpression
+ );
+
+ ExpressionSyntax condition = equalsExpression;
+ if (arm.WhenClause != null)
+ {
+ condition = SyntaxFactory.BinaryExpression(
+ SyntaxKind.LogicalAndExpression,
+ equalsExpression,
+ (ExpressionSyntax)Visit(arm.WhenClause.Condition)
+ );
+ }
+
+ var modifiedArmExpression = ReplaceVariableWithCast(armExpression, declaration, node.GoverningExpression);
+ currentExpression = SyntaxFactory.ConditionalExpression(
+ condition,
+ modifiedArmExpression,
+ currentExpression
+ );
+
+ continue;
+ }
+
+ throw new InvalidOperationException(
+ $"Switch expressions rewriting supports only constant values and declaration patterns (Type var). " +
+ $"Unsupported pattern: {arm.Pattern.GetType().Name}"
+ );
}
return currentExpression;
@@ -346,5 +385,25 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition
return base.VisitNullableType(node);
}
+
+ private ExpressionSyntax ReplaceVariableWithCast(ExpressionSyntax expression, DeclarationPatternSyntax declaration, ExpressionSyntax governingExpression)
+ {
+ if (declaration.Designation is SingleVariableDesignationSyntax variableDesignation)
+ {
+ var variableName = variableDesignation.Identifier.ValueText;
+
+ var castExpression = SyntaxFactory.ParenthesizedExpression(
+ SyntaxFactory.CastExpression(
+ declaration.Type,
+ (ExpressionSyntax)Visit(governingExpression)
+ )
+ );
+
+ var rewriter = new VariableReplacementRewriter(variableName, castExpression);
+ return (ExpressionSyntax)rewriter.Visit(expression);
+ }
+
+ return expression;
+ }
}
}
diff --git a/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs b/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs
new file mode 100644
index 0000000..578aa74
--- /dev/null
+++ b/src/EntityFrameworkCore.Projectables.Generator/VariableReplacementRewriter.cs
@@ -0,0 +1,41 @@
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+namespace EntityFrameworkCore.Projectables.Generator;
+
+public class VariableReplacementRewriter : CSharpSyntaxRewriter
+{
+ private readonly string _variableName;
+ private readonly ExpressionSyntax _replacement;
+
+ public VariableReplacementRewriter(string variableName, ExpressionSyntax replacement)
+ {
+ _variableName = variableName;
+ _replacement = replacement;
+ }
+
+ public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
+ {
+ if (node.Identifier.ValueText == _variableName)
+ {
+ return _replacement;
+ }
+
+ return base.VisitIdentifierName(node);
+ }
+
+ public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
+ {
+ if (node.Expression is IdentifierNameSyntax identifier &&
+ identifier.Identifier.ValueText == _variableName)
+ {
+ return SyntaxFactory.MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ _replacement,
+ node.Name
+ );
+ }
+
+ return base.VisitMemberAccessExpression(node);
+ }
+}
diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithConstantPattern.verified.txt
similarity index 100%
rename from tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpression.verified.txt
rename to tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithConstantPattern.verified.txt
diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt
new file mode 100644
index 0000000..dcb1ec1
--- /dev/null
+++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.SwitchExpressionWithTypePattern.verified.txt
@@ -0,0 +1,15 @@
+//
+#nullable disable
+using EntityFrameworkCore.Projectables;
+
+namespace EntityFrameworkCore.Projectables.Generated
+{
+ [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
+ static class _ItemMapper_ToData
+ {
+ static global::System.Linq.Expressions.Expression> Expression()
+ {
+ return (global::Item item) => item.GetType() == typeof(GroupItem) ? new global::GroupData(((GroupItem)item).Id, ((GroupItem)item).Name, ((GroupItem)item).Description) : item.GetType() == typeof(DocumentItem) ? new global::DocumentData(((DocumentItem)item).Id, ((DocumentItem)item).Name, ((DocumentItem)item).Priority) : null !;
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
index a982c79..2a74263 100644
--- a/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
+++ b/tests/EntityFrameworkCore.Projectables.Generator.Tests/ProjectionExpressionGeneratorTests.cs
@@ -1785,7 +1785,7 @@ class Foo {
}
[Fact]
- public Task SwitchExpression()
+ public Task SwitchExpressionWithConstantPattern()
{
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;
@@ -1811,6 +1811,52 @@ class Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}
+ [Fact]
+ public Task SwitchExpressionWithTypePattern()
+ {
+ var compilation = CreateCompilation(@"
+using EntityFrameworkCore.Projectables;
+
+public abstract class Item
+{
+ public int Id { get; set; }
+ public string Name { get; set; }
+}
+
+public class GroupItem : Item
+{
+ public string Description { get; set; }
+}
+
+public class DocumentItem : Item
+{
+ public int Priority { get; set; }
+}
+
+public abstract record ItemData(int Id, string Name);
+public record GroupData(int Id, string Name, string Description) : ItemData(Id, Name);
+public record DocumentData(int Id, string Name, int Priority) : ItemData(Id, Name);
+
+public static class ItemMapper
+{
+ [Projectable]
+ public static ItemData ToData(this Item item) =>
+ item switch
+ {
+ GroupItem groupItem => new GroupData(groupItem.Id, groupItem.Name, groupItem.Description),
+ DocumentItem documentItem => new DocumentData(documentItem.Id, documentItem.Name, documentItem.Priority),
+ _ => null!
+ };
+}
+");
+
+ var result = RunGenerator(compilation);
+
+ Assert.Empty(result.Diagnostics);
+
+ return Verifier.Verify(result.GeneratedTrees[0].ToString());
+ }
+
[Fact]
public Task GenericTypes()
{