diff --git a/src/EFCore/ChangeTracking/ChangeTracker.cs b/src/EFCore/ChangeTracking/ChangeTracker.cs index 76f5400f052..452fa83d1b4 100644 --- a/src/EFCore/ChangeTracking/ChangeTracker.cs +++ b/src/EFCore/ChangeTracking/ChangeTracker.cs @@ -215,6 +215,44 @@ public virtual IEnumerable> Entries() .Select(e => new EntityEntry(e)); } + /// + /// Returns tracked entities that are in a given state from a fast cache. + /// + /// Entities in EntityState.Added state + /// Entities in Modified.Added state + /// Entities in Modified.Deleted state + /// Entities in Modified.Unchanged state + /// An entry for each entity that matched the search criteria. + public virtual IEnumerable GetEntriesForState( + bool added = false, + bool modified = false, + bool deleted = false, + bool unchanged = false) + { + return StateManager.GetEntriesForState(added, modified, deleted, unchanged) + .Select(e => new EntityEntry(e)); + } + + /// + /// Returns tracked entities that are in a given state from a fast cache. + /// + /// Entities in EntityState.Added state + /// Entities in Modified.Added state + /// Entities in Modified.Deleted state + /// Entities in Modified.Unchanged state + /// An entry for each entity that matched the search criteria. + public virtual IEnumerable> GetEntriesForState( + bool added = false, + bool modified = false, + bool deleted = false, + bool unchanged = false) + where TEntity : class + { + return StateManager.GetEntriesForState(added, modified, deleted, unchanged) + .Where(e => e.Entity is TEntity) + .Select(e => new EntityEntry(e)); + } + private void TryDetectChanges() { if (AutoDetectChangesEnabled) diff --git a/src/EFCore/ChangeTracking/EntityEntry.cs b/src/EFCore/ChangeTracking/EntityEntry.cs index c719da3ec06..8f05ad2e989 100644 --- a/src/EFCore/ChangeTracking/EntityEntry.cs +++ b/src/EFCore/ChangeTracking/EntityEntry.cs @@ -680,6 +680,65 @@ public virtual void Reload() public virtual async Task ReloadAsync(CancellationToken cancellationToken = default) => Reload(await GetDatabaseValuesAsync(cancellationToken).ConfigureAwait(false)); + /// + /// Reloads the entity from the database using the specified . + /// + /// + /// + /// The behavior of this method depends on the specified: + /// + /// + /// : Overwrites both current and original property values with values from the database. + /// The entity will be in the state after calling this method. + /// + /// + /// : Updates original property values with values from the database, + /// but preserves any local modifications to current values. Modified properties remain modified with their current values. + /// + /// + /// If the entity does not exist in the database, the entity will be . + /// Calling Reload on an entity that does not exist in the database is a no-op. + /// + /// + /// See Accessing tracked entities in EF Core for more information and + /// examples. + /// + /// + /// The merge option controlling how database values are applied to the entity. + public virtual void Reload(MergeOption mergeOption) + => Reload(GetDatabaseValues(), mergeOption); + + /// + /// Reloads the entity from the database using the specified . + /// + /// + /// + /// The behavior of this method depends on the specified: + /// + /// + /// : Overwrites both current and original property values with values from the database. + /// The entity will be in the state after calling this method. + /// + /// + /// : Updates original property values with values from the database, + /// but preserves any local modifications to current values. Modified properties remain modified with their current values. + /// + /// + /// If the entity does not exist in the database, the entity will be . + /// Calling Reload on an entity that does not exist in the database is a no-op. + /// + /// + /// See Accessing tracked entities in EF Core for more information and + /// examples. + /// + /// + /// The merge option controlling how database values are applied to the entity. + /// A to observe while waiting for the task to complete. + /// A task that represents the asynchronous operation. + /// If the is canceled. + public virtual async Task ReloadAsync(MergeOption mergeOption, CancellationToken cancellationToken = default) + => Reload(await GetDatabaseValuesAsync(cancellationToken).ConfigureAwait(false), mergeOption); + private void Reload(PropertyValues? storeValues) { if (storeValues == null) @@ -697,6 +756,30 @@ private void Reload(PropertyValues? storeValues) State = EntityState.Unchanged; } } + private void Reload(PropertyValues? storeValues, MergeOption mergeOption) + { + if (storeValues == null) + { + if (State != EntityState.Added) + { + State = EntityState.Deleted; + State = EntityState.Detached; + } + } + else + { + foreach (var property in Metadata.GetProperties()) + { + var value = storeValues[property]; + InternalEntry.ReloadValue(property, value, mergeOption, updateEntityState: false); + } + + if (mergeOption == MergeOption.OverwriteChanges) + { + State = EntityState.Unchanged; + } + } + } [field: AllowNull, MaybeNull] private IEntityFinder Finder diff --git a/src/EFCore/ChangeTracking/Internal/InternalEntryBase.cs b/src/EFCore/ChangeTracking/Internal/InternalEntryBase.cs index 99dd4e0010e..23ab03cbac5 100644 --- a/src/EFCore/ChangeTracking/Internal/InternalEntryBase.cs +++ b/src/EFCore/ChangeTracking/Internal/InternalEntryBase.cs @@ -998,6 +998,38 @@ private void DetectChanges(IComplexProperty complexProperty) } } + /// + /// Refreshes the property value with the value from the database + /// + /// Property + /// New value from database + /// MergeOption + /// Sets the EntityState to Unchanged if MergeOption.OverwriteChanges else calls ChangeDetector to determine changes + public virtual void ReloadValue(IPropertyBase propertyBase, object? value, MergeOption mergeOption, bool updateEntityState) + { + var property = (IProperty)propertyBase; + EnsureOriginalValues(); + bool isModified = IsModified(property); + _originalValues.SetValue(property, value, -1); + if (mergeOption == MergeOption.OverwriteChanges || !isModified) + { + SetProperty(propertyBase, value, isMaterialization: true, setModified: false); + } + + if (updateEntityState) + { + if (mergeOption == MergeOption.OverwriteChanges) + { + SetEntityState(EntityState.Unchanged); + } + else if (StateManager is StateManager stateManager + && stateManager.ChangeDetector is ChangeDetector changeDetector) + { + changeDetector.DetectValueChange(this, property); + } + } + } + private void ReorderOriginalComplexCollectionEntries(IComplexProperty complexProperty, IList? newOriginalCollection) { Check.DebugAssert(HasOriginalValuesSnapshot, "This should only be called when original values are present"); diff --git a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs index 40aff14c560..d18c400fdf9 100644 --- a/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs +++ b/src/EFCore/Extensions/EntityFrameworkQueryableExtensions.cs @@ -3099,6 +3099,87 @@ public static IQueryable AsTracking( #endregion + #region Refreshing + + internal static readonly MethodInfo RefreshMethodInfo + = typeof(EntityFrameworkQueryableExtensions).GetMethod( + nameof(Refresh), [typeof(IQueryable<>).MakeGenericType(Type.MakeGenericMethodParameter(0)), typeof(MergeOption)])!; + + + /// + /// Specifies that the current Entity Framework LINQ query should refresh already loaded objects with the specified merge option. + /// + /// The type of entity being queried. + /// The source query. + /// The MergeOption + /// A new query annotated with the given tag. + public static IQueryable Refresh( + this IQueryable source, + [NotParameterized] MergeOption mergeOption) + { + if (HasNonTrackingOrIgnoreAutoIncludes(source.Expression)) + { + throw new InvalidOperationException(CoreStrings.RefreshNonTrackingQuery); + } + + if (HasMultipleMergeOptions(source.Expression)) + { + throw new InvalidOperationException(CoreStrings.RefreshMultipleMergeOptions); + } + return + source.Provider is EntityQueryProvider + ? source.Provider.CreateQuery( + Expression.Call( + instance: null, + method: RefreshMethodInfo.MakeGenericMethod(typeof(T)), + arg0: source.Expression, + arg1: Expression.Constant(mergeOption))) + : source; + } + + private static bool HasNonTrackingOrIgnoreAutoIncludes(Expression expression) + { + Expression? current = expression; + while (current is MethodCallExpression call) + { + var method = call.Method; + if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)) + { + var name = method.Name; + if (name == nameof(AsNoTracking) + || name == nameof(AsNoTrackingWithIdentityResolution) + || name == nameof(IgnoreAutoIncludes)) + { + return true; + } + } + + current = call.Arguments.Count > 0 ? call.Arguments[0] : null; + } + + return false; + } + + private static bool HasMultipleMergeOptions(Expression expression) + { + Expression? current = expression; + while (current is MethodCallExpression call) + { + var method = call.Method; + if (method.DeclaringType == typeof(EntityFrameworkQueryableExtensions) + && method.Name == nameof(Refresh)) + { + return true; + } + + current = call.Arguments.Count > 0 ? call.Arguments[0] : null; + } + + return false; + } + + #endregion + #region Tagging /// diff --git a/src/EFCore/MergeOption.cs b/src/EFCore/MergeOption.cs new file mode 100644 index 00000000000..a8eecb143af --- /dev/null +++ b/src/EFCore/MergeOption.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore; + +/// +/// The different ways that new objects loaded from the database can be merged with existing objects already in memory. +/// +public enum MergeOption +{ + /// + /// Will only append new (top level-unique) rows. This is the default behavior. + /// + AppendOnly = 0, + + /// + /// The incoming values for this row will be written to both the current value and + /// the original value versions of the data for each column. + /// + OverwriteChanges = 1, + + /// + /// The incoming values for this row will be written to the original value version + /// of each column. The current version of the data in each column will not be changed. + /// + PreserveChanges = 2 +} diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 193253511d8..7e3f9ca482d 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -3489,6 +3489,19 @@ public static string WrongStateManager(object? entityType) GetString("WrongStateManager", nameof(entityType)), entityType); + /// + /// Unable to refresh when is not tracked query + /// + public static string RefreshNonTrackingQuery + => GetString("RefreshNonTrackingQuery"); + + /// + /// Merge option changed on same query + /// + public static string RefreshMultipleMergeOptions + => GetString("RefreshMultipleMergeOptions"); + + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name)!; diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index 38331c9aad2..766afbc75ca 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1577,6 +1577,12 @@ The navigation '{1_entityType}.{0_navigation}' cannot have 'IsLoaded' set to false because the referenced entity is non-null and is therefore loaded. + + The Refresh method can only be called once per query. Multiple merge options were specified. + + + The Refresh method requires a tracking query. Call AsTracking() before calling Refresh(), or use a context with tracking enabled by default. + The principal and dependent ends of the relationship cannot be changed once foreign key or principal key properties have been specified. Remove the conflicting configuration. diff --git a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs index 3589fcc6195..5563e2467c1 100644 --- a/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/QueryableMethodNormalizingExpressionVisitor.cs @@ -360,6 +360,14 @@ when methodCallExpression.Arguments is _queryCompilationContext.IgnoreAutoIncludes = true; return visitedExpression; } + + case nameof(EntityFrameworkQueryableExtensions.Refresh): + { + var visitedExpression = Visit(methodCallExpression.Arguments[0]); + _queryCompilationContext.RefreshMergeOption = methodCallExpression.Arguments[1].GetConstantValue(); + return visitedExpression; + } + } } diff --git a/src/EFCore/Query/QueryCompilationContext.cs b/src/EFCore/Query/QueryCompilationContext.cs index af8d5e9f4ce..fd5700c8bf5 100644 --- a/src/EFCore/Query/QueryCompilationContext.cs +++ b/src/EFCore/Query/QueryCompilationContext.cs @@ -136,6 +136,21 @@ public QueryCompilationContext(QueryCompilationContextDependencies dependencies, /// public virtual bool IgnoreAutoIncludes { get; internal set; } + /// + /// + /// A value indicating how already loaded objects should be merged and refreshed with the results of this query. + /// + /// + /// This property is typically used by database providers (and other extensions). It is generally + /// not used in application code. + /// + /// + /// + /// See Implementation of database providers and extensions + /// and How EF Core queries work for more information and examples. + /// + public virtual MergeOption RefreshMergeOption { get; internal set; } + /// /// The set of tags applied to this query. /// diff --git a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs index 7e071275b9e..5be4a441a7f 100644 --- a/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs @@ -59,7 +59,8 @@ protected ShapedQueryCompilingExpressionVisitor( dependencies.EntityMaterializerSource, dependencies.LiftableConstantFactory, queryCompilationContext.QueryTrackingBehavior, - queryCompilationContext.SupportsPrecompiledQuery); + queryCompilationContext.SupportsPrecompiledQuery, + queryCompilationContext.RefreshMergeOption); _constantVerifyingExpressionVisitor = new ConstantVerifyingExpressionVisitor(dependencies.TypeMappingSource); _materializationConditionConstantLifter = new MaterializationConditionConstantLifter(dependencies.LiftableConstantFactory); @@ -377,7 +378,8 @@ private sealed class StructuralTypeMaterializerInjector( IStructuralTypeMaterializerSource materializerSource, ILiftableConstantFactory liftableConstantFactory, QueryTrackingBehavior queryTrackingBehavior, - bool supportsPrecompiledQuery) + bool supportsPrecompiledQuery, + MergeOption mergeOption) : ExpressionVisitor { private static readonly ConstructorInfo MaterializationContextConstructor @@ -410,6 +412,8 @@ private static readonly MethodInfo CreateNullKeyValueInNoTrackingQueryMethod private readonly bool _queryStateManager = queryTrackingBehavior is QueryTrackingBehavior.TrackAll or QueryTrackingBehavior.NoTrackingWithIdentityResolution; + private readonly MergeOption _mergeOption = mergeOption; + private readonly ISet _visitedEntityTypes = new HashSet(); private readonly MaterializationConditionConstantLifter _materializationConditionConstantLifter = new(liftableConstantFactory); private int _currentEntityIndex; @@ -523,7 +527,15 @@ private Expression ProcessStructuralTypeShaper(StructuralTypeShaperExpression sh Assign( instanceVariable, Convert( MakeMemberAccess(entryVariable, EntityMemberInfo), - clrType))), + clrType)), + // Update the existing entity with new property values from the database + // if the merge option is not AppendOnly + _mergeOption != MergeOption.AppendOnly + ? UpdateExistingEntityWithDatabaseValues( + entryVariable, + concreteEntityTypeVariable, + materializationContextVariable, + shaper) : Empty()), MaterializeEntity( shaper, materializationContextVariable, concreteEntityTypeVariable, instanceVariable, entryVariable)))); @@ -766,5 +778,71 @@ private BlockExpression CreateFullMaterializeExpression( return Block(blockExpressions); } + + /// + /// Creates an expression to update an existing tracked entity with values from the database, + /// similar to the EntityEntry.Reload() method. + /// + /// The variable representing the existing InternalEntityEntry. + /// The variable representing the concrete entity type. + /// The materialization context variable. + /// The structural type shaper expression. + /// An expression that updates the existing entity with database values. + private Expression UpdateExistingEntityWithDatabaseValues( + ParameterExpression entryVariable, + ParameterExpression concreteEntityTypeVariable, + ParameterExpression materializationContextVariable, + StructuralTypeShaperExpression shaper) + { + var updateExpressions = new List(); + var typeBase = shaper.StructuralType; + + if (typeBase is not IEntityType entityType) + { + // For complex types, we don't update existing instances + return Empty(); + } + + var valueBufferExpression = Call(materializationContextVariable, MaterializationContext.GetValueBufferMethod); + + // Get all properties to update (exclude key properties which should not change) + var propertiesToUpdate = entityType.GetProperties() + .Where(p => !p.IsPrimaryKey()) + .ToList(); + + var setReloadValueMethod = typeof(InternalEntityEntry) + .GetMethod(nameof(InternalEntityEntry.ReloadValue), new[] { typeof(IPropertyBase), typeof(object), typeof(MergeOption), typeof(bool) })!; + + // Update original values similar to EntityEntry.Reload() + // This ensures that the original values snapshot reflects the database state + var dbProperties = propertiesToUpdate; + int count = dbProperties.Count(); + int i = 0; + foreach (var property in dbProperties) + { + i++; + var newValue = valueBufferExpression.CreateValueBufferReadValueExpression( + property.ClrType, + property.GetIndex(), + property); + + var setOriginalValueExpression = Call( + entryVariable, + setReloadValueMethod, + Constant(property), + property.ClrType.IsValueType && property.IsNullable + ? (Expression)Convert(newValue, typeof(object)) + : Convert(newValue, typeof(object)), + Constant(_mergeOption), + Constant(i == count)); + + updateExpressions.Add(setOriginalValueExpression); + } + + return updateExpressions.Count > 0 + ? (Expression)Block(updateExpressions) + : Empty(); + } + } } diff --git a/test/EFCore.Specification.Tests/MergeOptionTestBase.cs b/test/EFCore.Specification.Tests/MergeOptionTestBase.cs new file mode 100644 index 00000000000..2414727d47b --- /dev/null +++ b/test/EFCore.Specification.Tests/MergeOptionTestBase.cs @@ -0,0 +1,886 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.TestUtilities; + +namespace Microsoft.EntityFrameworkCore.Query; + +#nullable disable + +public abstract partial class MergeOptionTestBase(TFixture fixture) : IClassFixture + where TFixture : MergeOptionTestBase.MergeOptionFixtureBase +{ + protected TFixture Fixture { get; } = fixture; + + protected DbContext CreateContext() => Fixture.CreateContext(); + + protected abstract void UseTransaction(DbContext context, Action testAction); + + protected abstract Task UseTransactionAsync(DbContext context, Func testAction); + + protected virtual void ClearLog() + { + } + + protected virtual void RecordLog() + { + } + + [ConditionalFact] + public virtual void Can_use_Refresh_with_OverwriteChanges() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + + product.Name = "Modified locally"; + Assert.Equal("Modified locally", product.Name); + + var newName = "Changed in database"; + UpdateProductNameInDatabase(ctx, product.Id, newName); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Same(product, refreshed); + Assert.Equal(newName, refreshed.Name); + Assert.Equal(newName, ctx.Entry(product).Property(p => p.Name).OriginalValue); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual async Task Can_use_Refresh_with_OverwriteChanges_async() + { + using var context = CreateContext(); + + await UseTransactionAsync(context, async ctx => + { + var product = await ctx.Set().FirstAsync(); + + product.Name = "Modified locally"; + Assert.Equal("Modified locally", product.Name); + + var newName = "Changed in database"; + await UpdateProductNameInDatabaseAsync(ctx, product.Id, newName); + + var refreshed = await ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .FirstAsync(); + + Assert.Same(product, refreshed); + Assert.Equal(newName, refreshed.Name); + Assert.Equal(newName, ctx.Entry(product).Property(p => p.Name).OriginalValue); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_PreserveChanges_keeps_local_modifications() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + + product.Price = 999.99m; + Assert.Equal(999.99m, product.Price); + ctx.Entry(product).Property(p => p.Price).IsModified = true; + + var newPrice = 123.45m; + UpdateProductPriceInDatabase(ctx, product.Id, newPrice); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.PreserveChanges) + .First(); + + Assert.Same(product, refreshed); + Assert.Equal(999.99m, refreshed.Price); + Assert.Equal(newPrice, ctx.Entry(product).Property(p => p.Price).OriginalValue); + Assert.Equal(EntityState.Modified, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_throws_on_non_tracking_query() + { + using var context = CreateContext(); + + Assert.Throws(() => + context.Set() + .AsNoTracking() + .Refresh(MergeOption.OverwriteChanges) + .ToList()); + } + + [ConditionalFact] + public virtual void Refresh_throws_on_multiple_merge_options() + { + using var context = CreateContext(); + + Assert.Throws(() => + context.Set() + .Refresh(MergeOption.OverwriteChanges) + .Refresh(MergeOption.PreserveChanges) + .ToList()); + } + + [ConditionalFact] + public virtual void Refresh_works_with_ToList() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var products = ctx.Set().ToList(); + var firstProduct = products.First(); + + firstProduct.Name = "Modified"; + + UpdateProductNameInDatabase(ctx, firstProduct.Id, "Updated"); + + var refreshed = ctx.Set() + .Refresh(MergeOption.OverwriteChanges) + .ToList(); + + Assert.Equal("Updated", firstProduct.Name); + }); + } + + [ConditionalFact] + public virtual void Refresh_works_with_FirstOrDefault() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + product.Name = "Modified"; + + UpdateProductNameInDatabase(ctx, product.Id, "Updated"); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .FirstOrDefault(); + + Assert.Same(product, refreshed); + Assert.Equal("Updated", refreshed.Name); + }); + } + + [ConditionalFact] + public virtual void Refresh_works_with_Include() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var order = ctx.Set() + .Include(o => o.OrderDetails) + .First(); + + order.CustomerName = "Modified"; + + UpdateOrderCustomerNameInDatabase(ctx, order.Id, "Updated"); + + var refreshed = ctx.Set() + .Include(o => o.OrderDetails) + .Where(o => o.Id == order.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Same(order, refreshed); + Assert.Equal("Updated", refreshed.CustomerName); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_modified_property() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + product.Price = 100m; + product.Quantity = 5; + + ctx.SaveChanges(); + + UpdateProductInDatabase(ctx, product.Id, 200m, 10); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Equal(200m, refreshed.Price); + Assert.Equal(10, refreshed.Quantity); + }); + } + + [ConditionalFact] + public virtual void EntityEntry_Reload_with_MergeOption_OverwriteChanges() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + product.Name = "Modified"; + + UpdateProductNameInDatabase(ctx, product.Id, "Updated"); + + ctx.Entry(product).Reload(MergeOption.OverwriteChanges); + + Assert.Equal("Updated", product.Name); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual async Task EntityEntry_ReloadAsync_with_MergeOption_OverwriteChanges() + { + using var context = CreateContext(); + + await UseTransactionAsync(context, async ctx => + { + var product = await ctx.Set().FirstAsync(); + product.Name = "Modified"; + + await UpdateProductNameInDatabaseAsync(ctx, product.Id, "Updated"); + + await ctx.Entry(product).ReloadAsync(MergeOption.OverwriteChanges); + + Assert.Equal("Updated", product.Name); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void EntityEntry_Reload_with_MergeOption_PreserveChanges() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + + product.Price = 999.99m; + + UpdateProductPriceInDatabase(ctx, product.Id, 123.45m); + + ctx.Entry(product).Reload(MergeOption.PreserveChanges); + + Assert.Equal(999.99m, product.Price); + Assert.Equal(123.45m, ctx.Entry(product).Property(p => p.Price).OriginalValue); + Assert.Equal(EntityState.Modified, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_many_to_many_relationship() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var student = ctx.Set().Include(s => s.Courses).First(); + var originalCourseCount = student.Courses.Count; + + var courseToAdd = ctx.Set().First(c => !student.Courses.Contains(c)); + + AddStudentCourseInDatabase(ctx, student.Id, courseToAdd.Id); + + var coll = ctx.Entry(student).Collection(s => s.Courses); + coll.IsLoaded = false; + coll.Load(); + + Assert.Equal(originalCourseCount + 1, student.Courses.Count); + Assert.Contains(student.Courses, c => c.Id == courseToAdd.Id); + }); + } + + [ConditionalFact] + public virtual async Task Refresh_many_to_many_relationship_async() + { + using var context = CreateContext(); + + await UseTransactionAsync(context, async ctx => + { + var student = await ctx.Set().Include(s => s.Courses).FirstAsync(); + var originalCourseCount = student.Courses.Count(); + + var courseToAdd = await ctx.Set().FirstAsync(c => !student.Courses.Contains(c)); + + await AddStudentCourseInDatabaseAsync(ctx, student.Id, courseToAdd.Id); + + var coll = ctx.Entry(student).Collection(s => s.Courses); + coll.IsLoaded = false; + await coll.LoadAsync(); + + Assert.Equal(originalCourseCount + 1, student.Courses.Count); + Assert.Contains(student.Courses, c => c.Id == courseToAdd.Id); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_shadow_property() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var book = ctx.Set().First(); + var originalPublisher = ctx.Entry(book).Property("Publisher").CurrentValue; + + var newPublisher = "Updated Publisher"; + UpdateBookPublisherInDatabase(ctx, book.Id, newPublisher); + + var refreshed = ctx.Set() + .Where(b => b.Id == book.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Same(book, refreshed); + Assert.Equal(newPublisher, ctx.Entry(book).Property("Publisher").CurrentValue); + }); + } + + [ConditionalFact] + public virtual void Refresh_respects_global_query_filter() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var activeCategories = ctx.Set().ToList(); + + Assert.All(activeCategories, c => Assert.True(c.IsActive)); + Assert.DoesNotContain(activeCategories, c => c.Id == 2); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_primitive_collection() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + var originalTags = product.Tags.ToList(); + + var newTags = new List { "newTag1", "newTag2", "newTag3" }; + UpdateProductTagsInDatabase(ctx, product.Id, newTags); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Equal(3, refreshed.Tags.Count); + Assert.Contains("newTag1", refreshed.Tags); + Assert.Contains("newTag2", refreshed.Tags); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_enum_value_converter() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + product.Status = ProductStatus.Active; + + UpdateProductStatusInDatabase(ctx, product.Id, ProductStatus.Discontinued); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Equal(ProductStatus.Discontinued, refreshed.Status); + }); + } + + [ConditionalFact] + public virtual void Refresh_entity_in_different_states() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var unchangedProduct = ctx.Set().OrderBy(p => p.Id).First(); + var modifiedProduct = ctx.Set().OrderBy(p => p.Id).Skip(1).First(); + modifiedProduct.Name = "Modified Name"; + + var newProduct = new Product { Id = 999, Name = "New Product", Price = 99.99m, Quantity = 10, Status = ProductStatus.Active, Tags = [] }; + ctx.Add(newProduct); + + Assert.Equal(EntityState.Unchanged, ctx.Entry(unchangedProduct).State); + Assert.Equal(EntityState.Modified, ctx.Entry(modifiedProduct).State); + Assert.Equal(EntityState.Added, ctx.Entry(newProduct).State); + + UpdateProductNameInDatabase(ctx, unchangedProduct.Id, "DB Updated"); + + var refreshed = ctx.Set() + .Refresh(MergeOption.OverwriteChanges) + .ToList(); + + Assert.Equal("DB Updated", unchangedProduct.Name); + Assert.Equal(EntityState.Unchanged, ctx.Entry(unchangedProduct).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_ThenInclude() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var student = ctx.Set() + .Include(s => s.Courses) + .ThenInclude(c => c.Students) + .First(); + + student.Name = "Modified Name"; + + UpdateStudentNameInDatabase(ctx, student.Id, "Updated Name"); + + var refreshed = ctx.Set() + .Include(s => s.Courses) + .ThenInclude(c => c.Students) + .Where(s => s.Id == student.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Same(student, refreshed); + Assert.Equal("Updated Name", refreshed.Name); + Assert.NotEmpty(refreshed.Courses); + }); + } + + [ConditionalFact] + public virtual void Refresh_PreserveChanges_with_unchanged_entity() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + var originalName = product.Name; + + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + + UpdateProductNameInDatabase(ctx, product.Id, "DB Modified"); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.PreserveChanges) + .First(); + + Assert.Equal("DB Modified", refreshed.Name); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_PreserveChanges_modified_property_not_overwritten() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + var originalPrice = product.Price; + + product.Price = 999.99m; + product.Name = "Modified Name"; + ctx.Entry(product).Property(p => p.Price).IsModified = true; + ctx.Entry(product).Property(p => p.Name).IsModified = true; + + var newPrice = 123.45m; + var newName = "DB Name"; + UpdateProductInDatabase(ctx, product.Id, newPrice, product.Quantity); + UpdateProductNameInDatabase(ctx, product.Id, newName); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.PreserveChanges) + .First(); + + Assert.Equal(999.99m, refreshed.Price); + Assert.Equal("Modified Name", refreshed.Name); + Assert.Equal(newPrice, ctx.Entry(product).Property(p => p.Price).OriginalValue); + Assert.Equal(newName, ctx.Entry(product).Property(p => p.Name).OriginalValue); + Assert.Equal(EntityState.Modified, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_unchanged_with_mismatched_original_value() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + var currentName = product.Name; + + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + + ctx.Entry(product).Property(p => p.Name).OriginalValue = "Different Original"; + ctx.Entry(product).Property(p => p.Name).IsModified = false; + + UpdateProductNameInDatabase(ctx, product.Id, "DB Updated Name"); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.PreserveChanges) + .First(); + + Assert.Equal("DB Updated Name", refreshed.Name); + Assert.Equal("DB Updated Name", ctx.Entry(product).Property(p => p.Name).OriginalValue); + Assert.Equal(EntityState.Unchanged, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_modified_with_matching_original_value() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var product = ctx.Set().First(); + var dbValue = product.Price; + + product.Price = 500.00m; + Assert.Equal(EntityState.Modified, ctx.Entry(product).State); + + var originalValueInDb = ctx.Entry(product).Property(p => p.Price).OriginalValue; + Assert.Equal(dbValue, originalValueInDb); + + var refreshed = ctx.Set() + .Where(p => p.Id == product.Id) + .Refresh(MergeOption.PreserveChanges) + .First(); + + Assert.Equal(500.00m, refreshed.Price); + Assert.Equal(dbValue, ctx.Entry(product).Property(p => p.Price).OriginalValue); + Assert.Equal(EntityState.Modified, ctx.Entry(product).State); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_owned_entity() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var order = ctx.Set().First(); + var originalShippingCity = order.ShippingAddress.City; + + order.ShippingAddress.City = "Modified City"; + + UpdateOrderShippingCityInDatabase(ctx, order.Id, "DB City"); + + var refreshed = ctx.Set() + .Where(o => o.Id == order.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Equal("DB City", refreshed.ShippingAddress.City); + }); + } + + [ConditionalFact] + public virtual void Refresh_with_tph_inheritance() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var premiumProduct = ctx.Set().First(); + var originalRewardPoints = premiumProduct.RewardPoints; + + premiumProduct.RewardPoints = 9999; + + UpdatePremiumProductRewardPointsInDatabase(ctx, premiumProduct.Id, 5000); + + var refreshed = ctx.Set() + .Where(p => p.Id == premiumProduct.Id) + .Refresh(MergeOption.OverwriteChanges) + .First(); + + Assert.Equal(5000, refreshed.RewardPoints); + Assert.Equal(EntityState.Unchanged, ctx.Entry(refreshed).State); + }); + } + + [ConditionalFact] + public virtual async Task Refresh_with_streaming_query() + { + using var context = CreateContext(); + + await UseTransactionAsync(context, async ctx => + { + var count = 0; + await foreach (var product in ctx.Set() + .Refresh(MergeOption.AppendOnly) + .AsAsyncEnumerable()) + { + Assert.NotNull(product.Name); + count++; + if (count >= 2) + break; + } + + Assert.True(count >= 2); + }); + } + + [ConditionalFact] + public virtual void Refresh_same_entity_projected_multiple_times() + { + using var context = CreateContext(); + + UseTransaction(context, ctx => + { + var result = ctx.Set() + .Select(p => new { First = p, Second = p }) + .Refresh(MergeOption.AppendOnly) + .First(); + + Assert.Same(result.First, result.Second); + }); + } + + protected abstract void AddStudentCourseInDatabase(DbContext context, int studentId, int courseId); + protected abstract Task AddStudentCourseInDatabaseAsync(DbContext context, int studentId, int courseId); + protected abstract void UpdateBookPublisherInDatabase(DbContext context, int bookId, string newPublisher); + protected abstract void UpdateProductTagsInDatabase(DbContext context, int productId, List newTags); + protected abstract void UpdateProductStatusInDatabase(DbContext context, int productId, ProductStatus newStatus); + protected abstract void UpdateStudentNameInDatabase(DbContext context, int studentId, string newName); + protected abstract void UpdateOrderShippingCityInDatabase(DbContext context, int orderId, string newCity); + protected abstract void UpdatePremiumProductRewardPointsInDatabase(DbContext context, int productId, int newRewardPoints); + + protected abstract void UpdateProductNameInDatabase(DbContext context, int id, string newName); + protected abstract Task UpdateProductNameInDatabaseAsync(DbContext context, int id, string newName); + protected abstract void UpdateProductPriceInDatabase(DbContext context, int id, decimal newPrice); + protected abstract void UpdateOrderCustomerNameInDatabase(DbContext context, int id, string newName); + protected abstract void UpdateProductInDatabase(DbContext context, int id, decimal newPrice, int newQuantity); + + public abstract class MergeOptionFixtureBase : SharedStoreFixtureBase + { + protected override string StoreName => "MergeOptionTest"; + + protected override bool RecreateStore + => true; + + protected override Type ContextType { get; } = typeof(MergeOptionContext); + + protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context) + { + modelBuilder.Entity(b => + { + b.Property(p => p.Id).ValueGeneratedNever(); + b.Property(p => p.Name).IsRequired(); + }); + + modelBuilder.Entity(b => + { + b.Property(o => o.Id).ValueGeneratedNever(); + b.HasMany(o => o.OrderDetails).WithOne(od => od.Order).HasForeignKey(od => od.OrderId); + b.OwnsOne(o => o.ShippingAddress, a => + { + a.Property(addr => addr.City).IsRequired(); + }); + }); + + modelBuilder.Entity(b => + { + b.Property(od => od.Id).ValueGeneratedNever(); + }); + + modelBuilder.Entity(b => + { + b.Property(s => s.Id).ValueGeneratedNever(); + b.HasMany(s => s.Courses).WithMany(c => c.Students); + }); + + modelBuilder.Entity(b => + { + b.Property(c => c.Id).ValueGeneratedNever(); + }); + + modelBuilder.Entity(b => + { + b.Property(bk => bk.Id).ValueGeneratedNever(); + b.Property("CreatedDate"); + b.Property("Publisher").HasMaxLength(100); + }); + + modelBuilder.Entity(b => + { + b.Property(c => c.Id).ValueGeneratedNever(); + b.HasQueryFilter(c => c.IsActive); + }); + + modelBuilder.Entity(b => + { + b.HasBaseType(); + }); + } + + protected override async Task SeedAsync(PoolableDbContext context) + { + await context.Database.EnsureCreatedResilientlyAsync(); + + var product1 = new Product { Id = 1, Name = "Product 1", Price = 10.99m, Quantity = 100, Status = ProductStatus.Active, Tags = ["tag1", "tag2"] }; + var product2 = new Product { Id = 2, Name = "Product 2", Price = 20.99m, Quantity = 50, Status = ProductStatus.Active, Tags = ["tag2", "tag3"] }; + var product3 = new Product { Id = 3, Name = "Product 3", Price = 30.99m, Quantity = 25, Status = ProductStatus.Inactive, Tags = ["tag3"] }; + var premiumProduct1 = new PremiumProduct { Id = 4, Name = "Premium Product 1", Price = 99.99m, Quantity = 10, Status = ProductStatus.Active, Tags = ["premium"], RewardPoints = 1000 }; + + var order1 = new Order + { + Id = 1, + CustomerName = "Customer 1", + ShippingAddress = new Address { Street = "123 Main St", City = "City1", PostalCode = "12345" }, + OrderDetails = new List + { + new() { Id = 1, ProductId = 1, Quantity = 2 }, + new() { Id = 2, ProductId = 2, Quantity = 1 } + } + }; + + var order2 = new Order + { + Id = 2, + CustomerName = "Customer 2", + ShippingAddress = new Address { Street = "456 Oak Ave", City = "City2", PostalCode = "67890" }, + OrderDetails = new List + { + new() { Id = 3, ProductId = 3, Quantity = 3 } + } + }; + + var course1 = new Course { Id = 1, Name = "Math", Description = "Mathematics" }; + var course2 = new Course { Id = 2, Name = "Science", Description = "Natural Sciences" }; + var course3 = new Course { Id = 3, Name = "History", Description = "World History" }; + + var student1 = new Student { Id = 1, Name = "John", Email = "john@test.com", Courses = [course1, course2] }; + var student2 = new Student { Id = 2, Name = "Jane", Email = "jane@test.com", Courses = [course2, course3] }; + + var book1 = new Book { Id = 1, Title = "Book 1" }; + var book2 = new Book { Id = 2, Title = "Book 2" }; + + var category1 = new Category { Id = 1, Name = "Active Category", IsActive = true }; + var category2 = new Category { Id = 2, Name = "Inactive Category", IsActive = false }; + + context.AddRange(product1, product2, product3, premiumProduct1, order1, order2); + context.AddRange(student1, student2, course3); + context.AddRange(book1, book2); + context.AddRange(category1, category2); + + // Set shadow properties for books + context.Entry(book1).Property("CreatedDate").CurrentValue = DateTime.UtcNow.AddDays(-30); + context.Entry(book1).Property("Publisher").CurrentValue = "Publisher A"; + context.Entry(book2).Property("CreatedDate").CurrentValue = DateTime.UtcNow.AddDays(-15); + context.Entry(book2).Property("Publisher").CurrentValue = "Publisher B"; + + await context.SaveChangesAsync(); + } + } + + protected class MergeOptionContext(DbContextOptions options) : PoolableDbContext(options); + + protected class Product + { + public int Id { get; set; } + public string Name { get; set; } + public decimal Price { get; set; } + public int Quantity { get; set; } + public List Tags { get; set; } + public ProductStatus Status { get; set; } + } + + protected class PremiumProduct : Product + { + public int RewardPoints { get; set; } + } + + protected enum ProductStatus + { + Active, + Inactive, + Discontinued + } + + protected class Order + { + public int Id { get; set; } + public string CustomerName { get; set; } + public List OrderDetails { get; set; } + public Address ShippingAddress { get; set; } + } + + protected class Address + { + public string Street { get; set; } + public string City { get; set; } + public string PostalCode { get; set; } + } + + protected class OrderDetail + { + public int Id { get; set; } + public int OrderId { get; set; } + public Order Order { get; set; } + public int ProductId { get; set; } + public int Quantity { get; set; } + } + + protected class Student + { + public int Id { get; set; } + public string Name { get; set; } + public string Email { get; set; } + public List Courses { get; set; } + } + + protected class Course + { + public int Id { get; set; } + public string Name { get; set; } + public string Description { get; set; } + public List Students { get; set; } + } + + protected class Book + { + public int Id { get; set; } + public string Title { get; set; } + + } + + protected class Category + { + public int Id { get; set; } + public string Name { get; set; } + public bool IsActive { get; set; } + } +} diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/MergeOptionSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/MergeOptionSqlServerTest.cs new file mode 100644 index 00000000000..cc15d16dbb3 --- /dev/null +++ b/test/EFCore.SqlServer.FunctionalTests/Query/MergeOptionSqlServerTest.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.SqlServer.Storage.Internal; +using Microsoft.EntityFrameworkCore.TestUtilities; + +namespace Microsoft.EntityFrameworkCore.Query; + +public class MergeOptionSqlServerTest(MergeOptionSqlServerTest.MergeOptionSqlServerFixture fixture) + : MergeOptionTestBase(fixture) +{ + protected override void UseTransaction(DbContext context, Action testAction) + { + using var transaction = context.Database.BeginTransaction(); + testAction(context); + transaction.Rollback(); + } + + protected override async Task UseTransactionAsync(DbContext context, Func testAction) + { + await using var transaction = await context.Database.BeginTransactionAsync(); + await testAction(context); + await transaction.RollbackAsync(); + } + + protected override void UpdateProductNameInDatabase(DbContext context, int id, string newName) + => context.Database.ExecuteSql($"UPDATE [Product] SET Name = {newName} WHERE Id = {id}"); + + protected override Task UpdateProductNameInDatabaseAsync(DbContext context, int id, string newName) + => context.Database.ExecuteSqlAsync($"UPDATE [Product] SET Name = {newName} WHERE Id = {id}"); + + protected override void UpdateProductPriceInDatabase(DbContext context, int id, decimal newPrice) + => context.Database.ExecuteSql($"UPDATE [Product] SET Price = {newPrice} WHERE Id = {id}"); + + protected override void UpdateOrderCustomerNameInDatabase(DbContext context, int id, string newName) + => context.Database.ExecuteSql($"UPDATE [Order] SET CustomerName = {newName} WHERE Id = {id}"); + + protected override void UpdateProductInDatabase(DbContext context, int id, decimal newPrice, int newQuantity) + => context.Database.ExecuteSql($"UPDATE [Product] SET Price = {newPrice}, Quantity = {newQuantity} WHERE Id = {id}"); + + protected override void AddStudentCourseInDatabase(DbContext context, int studentId, int courseId) + => context.Database.ExecuteSql($"INSERT INTO [CourseStudent] (CoursesId, StudentsId) VALUES ({courseId}, {studentId})"); + + protected override Task AddStudentCourseInDatabaseAsync(DbContext context, int studentId, int courseId) + => context.Database.ExecuteSqlAsync($"INSERT INTO [CourseStudent] (CoursesId, StudentsId) VALUES ({courseId}, {studentId})"); + + protected override void UpdateBookPublisherInDatabase(DbContext context, int bookId, string newPublisher) + => context.Database.ExecuteSql($"UPDATE [Book] SET Publisher = {newPublisher} WHERE Id = {bookId}"); + + protected override void UpdateProductTagsInDatabase(DbContext context, int productId, List newTags) + { + var tagsJson = System.Text.Json.JsonSerializer.Serialize(newTags); + context.Database.ExecuteSql($"UPDATE [Product] SET Tags = {tagsJson} WHERE Id = {productId}"); + } + + protected override void UpdateProductStatusInDatabase(DbContext context, int productId, ProductStatus newStatus) + => context.Database.ExecuteSql($"UPDATE [Product] SET Status = {(int)newStatus} WHERE Id = {productId}"); + + protected override void UpdateStudentNameInDatabase(DbContext context, int studentId, string newName) + => context.Database.ExecuteSql($"UPDATE [Student] SET Name = {newName} WHERE Id = {studentId}"); + + protected override void UpdateOrderShippingCityInDatabase(DbContext context, int orderId, string newCity) + => context.Database.ExecuteSql($"UPDATE [Order] SET ShippingAddress_City = {newCity} WHERE Id = {orderId}"); + + protected override void UpdatePremiumProductRewardPointsInDatabase(DbContext context, int productId, int newRewardPoints) + => context.Database.ExecuteSql($"UPDATE [Product] SET RewardPoints = {newRewardPoints} WHERE Id = {productId}"); + + public class MergeOptionSqlServerFixture : MergeOptionFixtureBase + { + protected override ITestStoreFactory TestStoreFactory + => SqlServerTestStoreFactory.Instance; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + => base.AddOptions(builder) + .UseSqlServer(b => b.ExecutionStrategy(c => new SqlServerExecutionStrategy(c))) + .ConfigureWarnings(w => + { + w.Ignore(CoreEventId.FirstWithoutOrderByAndFilterWarning); + w.Ignore(SqlServerEventId.DecimalTypeDefaultWarning); + }); + + protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context) + { + base.OnModelCreating(modelBuilder, context); + + modelBuilder.Entity().Property(p => p.Price).HasPrecision(18, 2); + modelBuilder.Entity(); + } + } +} diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/MergeOptionSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/MergeOptionSqliteTest.cs new file mode 100644 index 00000000000..29c7aa2b383 --- /dev/null +++ b/test/EFCore.Sqlite.FunctionalTests/Query/MergeOptionSqliteTest.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.TestUtilities; + +namespace Microsoft.EntityFrameworkCore.Query; + +public class MergeOptionSqliteTest(MergeOptionSqliteTest.MergeOptionSqliteFixture fixture) + : MergeOptionTestBase(fixture) +{ + protected override void UseTransaction(DbContext context, Action testAction) + { + using var transaction = context.Database.BeginTransaction(); + testAction(context); + transaction.Rollback(); + } + + protected override async Task UseTransactionAsync(DbContext context, Func testAction) + { + await using var transaction = await context.Database.BeginTransactionAsync(); + await testAction(context); + await transaction.RollbackAsync(); + } + + protected override void UpdateProductNameInDatabase(DbContext context, int id, string newName) + => context.Database.ExecuteSql($"UPDATE \"Product\" SET Name = {newName} WHERE Id = {id}"); + + protected override Task UpdateProductNameInDatabaseAsync(DbContext context, int id, string newName) + => context.Database.ExecuteSqlAsync($"UPDATE \"Product\" SET Name = {newName} WHERE Id = {id}"); + + protected override void UpdateProductPriceInDatabase(DbContext context, int id, decimal newPrice) + => context.Database.ExecuteSql($"UPDATE \"Product\" SET Price = {newPrice} WHERE Id = {id}"); + + protected override void UpdateOrderCustomerNameInDatabase(DbContext context, int id, string newName) + => context.Database.ExecuteSql($"UPDATE \"Order\" SET CustomerName = {newName} WHERE Id = {id}"); + + protected override void UpdateProductInDatabase(DbContext context, int id, decimal newPrice, int newQuantity) + => context.Database.ExecuteSql($"UPDATE \"Product\" SET Price = {newPrice}, Quantity = {newQuantity} WHERE Id = {id}"); + + protected override void AddStudentCourseInDatabase(DbContext context, int studentId, int courseId) + => context.Database.ExecuteSql($"INSERT INTO \"CourseStudent\" (CoursesId, StudentsId) VALUES ({courseId}, {studentId})"); + + protected override Task AddStudentCourseInDatabaseAsync(DbContext context, int studentId, int courseId) + => context.Database.ExecuteSqlAsync($"INSERT INTO \"CourseStudent\" (CoursesId, StudentsId) VALUES ({courseId}, {studentId})"); + + protected override void UpdateBookPublisherInDatabase(DbContext context, int bookId, string newPublisher) + => context.Database.ExecuteSql($"UPDATE \"Book\" SET Publisher = {newPublisher} WHERE Id = {bookId}"); + + protected override void UpdateProductTagsInDatabase(DbContext context, int productId, List newTags) + { + var tagsJson = System.Text.Json.JsonSerializer.Serialize(newTags); + context.Database.ExecuteSql($"UPDATE \"Product\" SET Tags = {tagsJson} WHERE Id = {productId}"); + } + + protected override void UpdateProductStatusInDatabase(DbContext context, int productId, ProductStatus newStatus) + => context.Database.ExecuteSql($"UPDATE \"Product\" SET Status = {(int)newStatus} WHERE Id = {productId}"); + + protected override void UpdateStudentNameInDatabase(DbContext context, int studentId, string newName) + => context.Database.ExecuteSql($"UPDATE \"Student\" SET Name = {newName} WHERE Id = {studentId}"); + + protected override void UpdateOrderShippingCityInDatabase(DbContext context, int orderId, string newCity) + => context.Database.ExecuteSql($"UPDATE \"Order\" SET ShippingAddress_City = {newCity} WHERE Id = {orderId}"); + + protected override void UpdatePremiumProductRewardPointsInDatabase(DbContext context, int productId, int newRewardPoints) + => context.Database.ExecuteSql($"UPDATE \"Product\" SET RewardPoints = {newRewardPoints} WHERE Id = {productId}"); + + public class MergeOptionSqliteFixture : MergeOptionFixtureBase + { + protected override ITestStoreFactory TestStoreFactory + => SqliteTestStoreFactory.Instance; + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + => base.AddOptions(builder) + .ConfigureWarnings(w => w.Ignore(CoreEventId.FirstWithoutOrderByAndFilterWarning)); + } +}