Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,46 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition
continue;
}

throw new InvalidOperationException("Switch expressions rewriting is only supported with constant values");
if (arm.Pattern is DeclarationPatternSyntax declaration)
{
var getTypeExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
(ExpressionSyntax)Visit(node.GoverningExpression),
SyntaxFactory.IdentifierName("GetType")
);

var getTypeCall = SyntaxFactory.InvocationExpression(getTypeExpression);
var typeofExpression = SyntaxFactory.TypeOfExpression(declaration.Type);
var equalsExpression = SyntaxFactory.BinaryExpression(
SyntaxKind.EqualsExpression,
getTypeCall,
typeofExpression
);

ExpressionSyntax condition = equalsExpression;
if (arm.WhenClause != null)
{
condition = SyntaxFactory.BinaryExpression(
SyntaxKind.LogicalAndExpression,
equalsExpression,
(ExpressionSyntax)Visit(arm.WhenClause.Condition)
);
}

var modifiedArmExpression = ReplaceVariableWithCast(armExpression, declaration, node.GoverningExpression);
currentExpression = SyntaxFactory.ConditionalExpression(
condition,
modifiedArmExpression,
currentExpression
);

continue;
}

throw new InvalidOperationException(
$"Switch expressions rewriting supports only constant values and declaration patterns (Type var). " +
$"Unsupported pattern: {arm.Pattern.GetType().Name}"
);
}

return currentExpression;
Expand Down Expand Up @@ -346,5 +385,25 @@ public ExpressionSyntaxRewriter(INamedTypeSymbol targetTypeSymbol, NullCondition

return base.VisitNullableType(node);
}

private ExpressionSyntax ReplaceVariableWithCast(ExpressionSyntax expression, DeclarationPatternSyntax declaration, ExpressionSyntax governingExpression)
{
if (declaration.Designation is SingleVariableDesignationSyntax variableDesignation)
{
var variableName = variableDesignation.Identifier.ValueText;

var castExpression = SyntaxFactory.ParenthesizedExpression(
SyntaxFactory.CastExpression(
declaration.Type,
(ExpressionSyntax)Visit(governingExpression)
)
);

var rewriter = new VariableReplacementRewriter(variableName, castExpression);
return (ExpressionSyntax)rewriter.Visit(expression);
}

return expression;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace EntityFrameworkCore.Projectables.Generator;

public class VariableReplacementRewriter : CSharpSyntaxRewriter
{
private readonly string _variableName;
private readonly ExpressionSyntax _replacement;

public VariableReplacementRewriter(string variableName, ExpressionSyntax replacement)
{
_variableName = variableName;
_replacement = replacement;
}

public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
if (node.Identifier.ValueText == _variableName)
{
return _replacement;
}

return base.VisitIdentifierName(node);
}

public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
{
if (node.Expression is IdentifierNameSyntax identifier &&
identifier.Identifier.ValueText == _variableName)
{
return SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
_replacement,
node.Name
);
}

return base.VisitMemberAccessExpression(node);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// <auto-generated/>
#nullable disable
using EntityFrameworkCore.Projectables;

namespace EntityFrameworkCore.Projectables.Generated
{
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
static class _ItemMapper_ToData
{
static global::System.Linq.Expressions.Expression<global::System.Func<global::Item, global::ItemData>> Expression()
{
return (global::Item item) => item.GetType() == typeof(GroupItem) ? new global::GroupData(((GroupItem)item).Id, ((GroupItem)item).Name, ((GroupItem)item).Description) : item.GetType() == typeof(DocumentItem) ? new global::DocumentData(((DocumentItem)item).Id, ((DocumentItem)item).Name, ((DocumentItem)item).Priority) : null !;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1785,7 +1785,7 @@ class Foo {
}

[Fact]
public Task SwitchExpression()
public Task SwitchExpressionWithConstantPattern()
{
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;
Expand All @@ -1811,6 +1811,52 @@ class Foo {
return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[Fact]
public Task SwitchExpressionWithTypePattern()
{
var compilation = CreateCompilation(@"
using EntityFrameworkCore.Projectables;

public abstract class Item
{
public int Id { get; set; }
public string Name { get; set; }
}

public class GroupItem : Item
{
public string Description { get; set; }
}

public class DocumentItem : Item
{
public int Priority { get; set; }
}

public abstract record ItemData(int Id, string Name);
public record GroupData(int Id, string Name, string Description) : ItemData(Id, Name);
public record DocumentData(int Id, string Name, int Priority) : ItemData(Id, Name);

public static class ItemMapper
{
[Projectable]
public static ItemData ToData(this Item item) =>
item switch
{
GroupItem groupItem => new GroupData(groupItem.Id, groupItem.Name, groupItem.Description),
DocumentItem documentItem => new DocumentData(documentItem.Id, documentItem.Name, documentItem.Priority),
_ => null!
};
}
");

var result = RunGenerator(compilation);

Assert.Empty(result.Diagnostics);

return Verifier.Verify(result.GeneratedTrees[0].ToString());
}

[Fact]
public Task GenericTypes()
{
Expand Down
Loading