diff --git a/samples/BasicSample/Program.cs b/samples/BasicSample/Program.cs
index d724d99..d9716c6 100644
--- a/samples/BasicSample/Program.cs
+++ b/samples/BasicSample/Program.cs
@@ -21,7 +21,7 @@ public class User
public string FullName { get; set; }
private string _FullName => FirstName + " " + LastName;
- [Projectable(UseMemberBody = nameof(_TotalSpent))]
+ [Projectable(UseMemberBody = nameof(_TotalSpent), OnlyOnInclude = true)]
public double TotalSpent { get; set; }
private double _TotalSpent => Orders.Sum(x => x.PriceSum);
@@ -154,10 +154,11 @@ public static void Main(string[] args)
}
{
- var result = dbContext.Users.FirstOrDefault();
+ Console.WriteLine($"Unloaded total: {dbContext.Users.First().TotalSpent}");
+ var result = dbContext.Users.Include(x => x.TotalSpent).FirstOrDefault();
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
- result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
+ result = dbContext.Users.Include(x => x.TotalSpent).FirstOrDefault(x => x.TotalSpent > 1);
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
var spent = dbContext.Users.Sum(x => x.TotalSpent);
diff --git a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
index 4c3082b..02835a2 100644
--- a/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
+++ b/src/EntityFrameworkCore.Projectables.Abstractions/ProjectableAttribute.cs
@@ -1,8 +1,4 @@
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
namespace EntityFrameworkCore.Projectables
{
@@ -23,5 +19,12 @@ public sealed class ProjectableAttribute : Attribute
/// or null to get it from the current member.
///
public string? UseMemberBody { get; set; }
+
+ ///
+ /// true will allow you to request for this property by
+ /// explicitly calling .Include(x => x.Property) on the query,
+ /// false will always consider this query to be included.
+ ///
+ public bool OnlyOnInclude { get; set; }
}
}
diff --git a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
index 3eedc6d..43ea0b2 100644
--- a/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
+++ b/src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs
@@ -5,6 +5,7 @@
using System.Linq.Expressions;
using System.Reflection;
using EntityFrameworkCore.Projectables.Extensions;
+using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
@@ -16,6 +17,7 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new();
readonly Dictionary _projectableMemberCache = new();
private bool _disableRootRewrite;
+ private List _includedProjections = new();
private IEntityType? _entityType;
private readonly MethodInfo _select;
@@ -60,6 +62,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
public Expression? Replace(Expression? node)
{
_disableRootRewrite = false;
+ _includedProjections.Clear();
var ret = Visit(node);
if (_disableRootRewrite)
@@ -138,6 +141,28 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
protected override Expression VisitMethodCall(MethodCallExpression node)
{
+ if (node.Method.Name == nameof(EntityFrameworkQueryableExtensions.Include))
+ {
+ var include = node.Arguments[1] switch {
+ ConstantExpression { Value: string str } => str,
+ LambdaExpression { Body: MemberExpression member } => member.Member.Name,
+ UnaryExpression { Operand: LambdaExpression { Body: MemberExpression member } } => member.Member.Name,
+ _ => null
+ };
+ // Only rewrite the include if it includes a projectable property (or if we don't know what's happening).
+ var ret = Visit(node.Arguments[0]);
+ // The visit here is needed because we need the _entityType defined on the query root for the condition below.
+ if (
+ include != null
+ && _entityType?.ClrType
+ ?.GetProperty(include)
+ ?.GetCustomAttribute() != null)
+ {
+ _includedProjections.Add(include);
+ return ret;
+ }
+ }
+
// Replace MethodGroup arguments with their reflected expressions.
// Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
node = node.Update(node.Object, node.Arguments.Select(arg => arg switch {
@@ -212,13 +237,13 @@ PropertyInfo property when nodeExpression is not null
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
- return base.Visit(
+ return Visit(
updatedBody
);
}
else
{
- return base.Visit(
+ return Visit(
reflectedExpression.Body
);
}
@@ -243,7 +268,14 @@ protected override Expression VisitExtension(Expression node)
private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
{
var projectableProperties = entityType.ClrType.GetProperties()
- .Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
+ .Where(x => {
+ var attr = x.GetCustomAttribute();
+ if (attr == null)
+ return false;
+ if (attr.OnlyOnInclude)
+ return _includedProjections.Contains(x.Name);
+ return true;
+ })
.Where(x => x.CanWrite)
.ToList();
@@ -291,7 +323,7 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
_expressionArgumentReplacer.ParameterArgumentMapping.Add(lambda.Parameters[0], para);
var updatedBody = _expressionArgumentReplacer.Visit(lambda.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
- return base.Visit(updatedBody);
+ return Visit(updatedBody);
}
}
}