diff --git a/Directory.Packages.props b/Directory.Packages.props index d25f1f4..6e3e005 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -7,6 +7,7 @@ + diff --git a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs index 4c3082b..ab808fd 100644 --- a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs +++ b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs @@ -13,6 +13,10 @@ namespace EntityFrameworkCore.Projectables [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = true, AllowMultiple = false)] public sealed class ProjectableAttribute : Attribute { + public ProjectableAttribute() + { + + } /// /// Get or set how null-conditional operators are handeled /// @@ -23,5 +27,10 @@ public sealed class ProjectableAttribute : Attribute /// or null to get it from the current member. /// public string? UseMemberBody { get; set; } + + /// + /// Parameters values for UseMemberBody. + /// + public object[]? UseMemberBodyArguments { get; set; } } } diff --git a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs index 2a85c9c..c5f9c5b 100644 --- a/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs +++ b/src/EntityFrameworkCore.Projectables/Infrastructure/Internal/CustomQueryCompiler.cs @@ -1,11 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Linq.Expressions; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using System.Transactions; +using System.Linq.Expressions; using EntityFrameworkCore.Projectables.Services; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.Internal; diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs index c9d8ad4..33a7438 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs @@ -48,7 +48,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La reflectedExpression = projectableAttribute is not null ? _resolver.FindGeneratedExpression(memberInfo) - : null; + : (LambdaExpression?)null; _projectableMemberCache.Add(memberInfo, reflectedExpression); } @@ -209,19 +209,30 @@ PropertyInfo property when nodeExpression is not null if (nodeExpression is not null) { _expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression); + if (reflectedExpression.Parameters.Count > 1) + { + var projectableAttribute = nodeMember.GetCustomAttribute(false)!; + foreach (var parameterWithIndex in reflectedExpression.Parameters.Skip(1).Select((Parameter, Index) => new { Parameter, Index })) + { + var value = projectableAttribute!.UseMemberBodyArguments![parameterWithIndex.Index]; + _expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterWithIndex.Parameter, Expression.Constant(value)); + } + } + var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body); _expressionArgumentReplacer.ParameterArgumentMapping.Clear(); - return base.Visit( + return Visit( updatedBody ); } else { - return base.Visit( + return Visit( reflectedExpression.Body ); } + } return base.VisitMember(node); diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs index b9062dd..0b36ce4 100644 --- a/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs +++ b/src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionResolver.cs @@ -13,13 +13,16 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo var projectableAttribute = projectableMemberInfo.GetCustomAttribute() ?? throw new InvalidOperationException("Expected member to have a Projectable attribute. None found"); - var expression = GetExpressionFromGeneratedType(projectableMemberInfo); + var expression = projectableAttribute.UseMemberBody is null? GetExpressionFromGeneratedType(projectableMemberInfo): null; + if (expression is null && projectableAttribute.UseMemberBody is not null && projectableAttribute.UseMemberBodyArguments is null) + { + expression = GetExpressionFromGeneratedType(projectableMemberInfo, true, projectableAttribute.UseMemberBody); + } if (expression is null && projectableAttribute.UseMemberBody is not null) { - expression = GetExpressionFromMemberBody(projectableMemberInfo, projectableAttribute.UseMemberBody); + expression = GetExpressionFromMemberBody(projectableMemberInfo, projectableAttribute.UseMemberBody, projectableAttribute.UseMemberBodyArguments); } - if (expression is null) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); @@ -33,17 +36,32 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo return expression; - static LambdaExpression? GetExpressionFromMemberBody(MemberInfo projectableMemberInfo, string memberName) + static LambdaExpression? GetExpressionFromMemberBody(MemberInfo projectableMemberInfo, string memberName, object[]? memberParameters) { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); - var exprProperty = declaringType.GetProperty(memberName, BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + + var exprProperty = declaringType.GetProperty(memberName, BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic) + ?? declaringType.BaseType?.GetProperty(memberName, BindingFlags.Static | BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); var lambda = exprProperty?.GetValue(null) as LambdaExpression; if (lambda is not null) { - if (projectableMemberInfo is PropertyInfo property && - lambda.Parameters.Count == 1 && - lambda.Parameters[0].Type == declaringType && lambda.ReturnType == property.PropertyType) + + if (projectableMemberInfo is PropertyInfo property && (lambda.Parameters.Count == + (1 + (memberParameters?.Length ?? 0))) + && (lambda.Parameters[0].Type == declaringType || + lambda.Parameters[0].Type == + declaringType.BaseType) + && lambda.ReturnType == property.PropertyType + && (memberParameters?.Any() != true + || lambda.Parameters.Skip(1) + .Select((Parameter, Index) => + new { Parameter, Index }) + .All(p => p.Parameter.Type == + memberParameters[p.Index] + .GetType()))) + + { return lambda; } @@ -55,16 +73,16 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo return lambda; } } - return null; } - static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo) + static LambdaExpression? GetExpressionFromGeneratedType(MemberInfo projectableMemberInfo, bool useLocalType = false, string methodName = "Expression") { var declaringType = projectableMemberInfo.DeclaringType ?? throw new InvalidOperationException("Expected a valid type here"); var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name); - - var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName); + var expressionFactoryType = !useLocalType + ? declaringType.Assembly.GetType(generatedContainingTypeName) + : declaringType; if (expressionFactoryType is not null) { @@ -73,7 +91,7 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo expressionFactoryType = expressionFactoryType.MakeGenericType(declaringType.GenericTypeArguments); } - var expressionFactoryMethod = expressionFactoryType.GetMethod("Expression", BindingFlags.Static | BindingFlags.NonPublic); + var expressionFactoryMethod = expressionFactoryType.GetMethod(methodName, BindingFlags.Static | BindingFlags.NonPublic); var methodGenericArguments = projectableMemberInfo switch { MethodInfo methodInfo => methodInfo.GetGenericArguments(), @@ -90,7 +108,6 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo return expressionFactoryMethod.Invoke(null, null) as LambdaExpression ?? throw new InvalidOperationException("Expected lambda"); } } - return null; } } diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj index bb5b0c5..c47188e 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/EntityFrameworkCore.Projectables.FunctionalTests.csproj @@ -9,6 +9,7 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs index 723ceee..0c3849b 100644 --- a/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionsMethods/ExtensionMethodTests.cs @@ -1,12 +1,7 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; +using System.Linq; using System.Threading.Tasks; using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; using Microsoft.EntityFrameworkCore; -using Microsoft.EntityFrameworkCore.Scaffolding.Metadata; -using ScenarioTests; using VerifyXunit; using Xunit; diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleBodyParamDbContext.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleBodyParamDbContext.cs new file mode 100644 index 0000000..32a08ca --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/Helpers/SampleBodyParamDbContext.cs @@ -0,0 +1,75 @@ +using System.Collections.Generic; +using EntityFrameworkCore.Projectables.Infrastructure; +using Microsoft.EntityFrameworkCore; +using Order = EntityFrameworkCore.Projectables.FunctionalTests.MemberBodyParameterValueTests.Order; +using OrderItem = EntityFrameworkCore.Projectables.FunctionalTests.MemberBodyParameterValueTests.OrderItem; + +namespace EntityFrameworkCore.Projectables.FunctionalTests.Helpers +{ + public class SampleBodyParamDbContext : DbContext + { + readonly CompatibilityMode _compatibilityMode; + + public SampleBodyParamDbContext(CompatibilityMode compatibilityMode = CompatibilityMode.Full) + { + _compatibilityMode = compatibilityMode; + + var _orders = new List() { + new() { + Id = 1, + + }, + new() { + Id = 2, + + } + }; + + var _orders_items = new List() { + new() { + Id = 1, + OrderId = 1, + Name = "Order_1" + }, + new() { + Id = 2, + OrderId = 1, + Name = "Order_2" + }, + new() { + Id = 3, + OrderId = 2, + Name = "Order_3" + }, + new() { + Id = 4, + OrderId = 2, + Name = "Order_4" + }, + }; + + Order!.AddRange(_orders); + OrderItem!.AddRange(_orders_items); + SaveChanges(); + } + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) + { + optionsBuilder.UseInMemoryDatabase("TestDb"); + optionsBuilder.UseProjectables(options => { + options.CompatibilityMode(_compatibilityMode); // Needed by our ComplexModelTests + }); + } + + protected override void OnModelCreating(ModelBuilder modelBuilder) + { + modelBuilder.Entity() + .HasKey(x => x.Id); + modelBuilder.Entity().HasKey(x => x.Id); + + } + + public DbSet? Order { get; set; } + public DbSet? OrderItem { get; set; } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.FunctionalTests/MemberBodyParameterValueTests.cs b/tests/EntityFrameworkCore.Projectables.FunctionalTests/MemberBodyParameterValueTests.cs new file mode 100644 index 0000000..5982767 --- /dev/null +++ b/tests/EntityFrameworkCore.Projectables.FunctionalTests/MemberBodyParameterValueTests.cs @@ -0,0 +1,66 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.ComponentModel.DataAnnotations.Schema; +using System.Linq; +using System.Linq.Expressions; +using EntityFrameworkCore.Projectables.FunctionalTests.Helpers; +using Microsoft.EntityFrameworkCore; +using Xunit; + +namespace EntityFrameworkCore.Projectables.FunctionalTests +{ + public class MemberBodyParameterValueTests + { + public class Order + { + [Key, DatabaseGenerated(DatabaseGeneratedOption.None)] + public int Id { get; set; } + public List OrderItem { get; set; } = new List(); + + [Projectable(UseMemberBody = nameof(GetOrderItemNameExpression), UseMemberBodyArguments = new object[]{ 1 } )] + public string FirstOrderPropName => GetOrderItemName(1); + + + [Projectable(UseMemberBody = nameof(GetOrderItemNameInnerExpression))] + public string GetOrderItemName(int id) + => OrderItem.Where(av => av.Id == id) + .Select(av => av.Name) + .FirstOrDefault() ?? throw new ArgumentException("Argument matching identifier not found"); + + private static Expression> GetOrderItemNameInnerExpression() + => (@this, id) => @this.OrderItem + .Where(av => av.Id == id) + .Select(av => av.Name) + .FirstOrDefault() ?? string.Empty; + + public static Expression> GetOrderItemNameExpression + => (order, id) => order.GetOrderItemName(id); + } + + public class OrderItem + { + [Key, DatabaseGenerated(DatabaseGeneratedOption.None)] + public int Id { get; set; } + public int OrderId { get; set; } + public string Name { get; set; } = string.Empty; + } + + [Fact] + public void UseBodyParameterValue() + { + //Arrange + using var dbContext = new SampleBodyParamDbContext(); + + // Act + var query = dbContext + .Set() + .Include(a => a.OrderItem) + .FirstOrDefault(d => d.FirstOrderPropName == "Order_1"); + + // Assert + Assert.NotNull(query); + Assert.True(query!.FirstOrderPropName == "Order_1"); + } + } +} diff --git a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs index a9ed5dd..133baeb 100644 --- a/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs +++ b/tests/EntityFrameworkCore.Projectables.Tests/Services/ProjectableExpressionReplacerTests.cs @@ -61,7 +61,7 @@ public void VisitMember_SimpleProperty() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -77,7 +77,7 @@ public void VisitMember_SimpleMethod() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -93,7 +93,7 @@ public void VisitMember_SimpleMethodWithArguments() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -109,7 +109,7 @@ public void VisitMember_SimpleStatefullProperty() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -125,7 +125,7 @@ public void VisitMember_SimpleStatefullMethod() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -141,7 +141,7 @@ public void VisitMember_SimpleStaticMethod() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); } @@ -157,7 +157,7 @@ public void VisitMember_SimpleStaticMethodWithArguments() ); var subject = new ProjectableExpressionReplacer(resolver); - var actual = subject.Replace(input); + var actual = subject.Visit(input); Assert.Equal(expected.ToString(), actual.ToString()); }