Skip to content

Commit ffe5b3f

Browse files
committed
Add option to preserve CancellationToken parameters
It can still be useful to preserve the CancellationToken in some scenarios. For example, for [cancelling a synchronous bulk copy operation][1]. ```csharp bulkCopy.SqlRowsCopied += (_, e) => { if (cancellationToken.IsCancellationRequested) { e.Abort = true; } }; ``` I plan to submit a pull request at [PhenX.EntityFrameworkCore.BulkInsert][2] to use *Sync Method Generator* in order to greatly simplify async + sync methods implementations. 😉 [1]: https://github.com/PhenX/PhenX.EntityFrameworkCore.BulkInsert/blob/137d2fc8fed17b5aa7e6f11fccc079b7f463aff0/src/PhenX.EntityFrameworkCore.BulkInsert.SqlServer/SqlServerBulkInsertProvider.cs#L68-L74 [2]: https://github.com/PhenX/PhenX.EntityFrameworkCore.BulkInsert
1 parent 6a0b1ef commit ffe5b3f

9 files changed

Lines changed: 133 additions & 11 deletions

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ A list of changes applied to the new synchronized method:
6262
\*\* `Memory` and `ReadOnlyMemory` is preserved in sync methods if it is a type argument of a collection. This is due to a compiler limitation which states that a `ref struct` can't be the element type of an array.
6363

6464
- Remove parameters
65-
- [CancellationToken](https://learn.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken)
65+
- [CancellationToken](https://learn.microsoft.com/en-us/dotnet/api/system.threading.cancellationtoken), unless the `PreserveCancellationToken` property is set to `true`.
6666
- [IProgress\<T>](https://learn.microsoft.com/en-us/dotnet/api/system.iprogress-1), unless the `PreserveProgress` property is set to `true`.
6767
- Invocation changes
6868
- Remove `ConfigureAwait` from [Tasks](https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.task.configureawait) and [Asynchronous Enumerations](https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.taskasyncenumerableextensions.configureawait)
@@ -106,6 +106,18 @@ public async Task MethodAsync(IProgress<double> progress)
106106
}
107107
```
108108

109+
#### PreserveCancellationToken
110+
111+
By default, this source generator removes `CancellationToken` parameters from async methods. To preserve them, use the `PreserveCancellationToken` option.
112+
113+
```cs
114+
[Zomp.SyncMethodGenerator.CreateSyncVersion(PreserveCancellationToken = true)]
115+
public async Task MethodAsync(CancellationToken cancellationToken = default)
116+
{
117+
cancellationToken.ThrowIfCancellationRequested();
118+
}
119+
```
120+
109121
### CreateSyncVersionAttribute on a type
110122

111123
You can also decorate your type (class, struct, record, or interface) to generate a sync version for every asynchronous method.

src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ namespace Zomp.SyncMethodGenerator;
1111
/// <param name="semanticModel">The semantic model.</param>
1212
/// <param name="disableNullable">Instructs the source generator that nullable context should be disabled.</param>
1313
/// <param name="preserveProgress">Instructs the source generator to preserve <see cref="IProgress{T}"/> parameters.</param>
14+
/// <param name="preserveCancellationToken">Instructs the source generator to preserve <see cref="CancellationToken"/> parameters.</param>
1415
/// <param name="methodName">Method declaration syntax.</param>
15-
internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disableNullable, bool preserveProgress, MethodDeclarationSyntax methodName) : CSharpSyntaxRewriter
16+
internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel, bool disableNullable, bool preserveProgress, bool preserveCancellationToken, MethodDeclarationSyntax methodName) : CSharpSyntaxRewriter
1617
{
1718
public const string SyncOnly = "SYNC_ONLY";
1819

@@ -455,8 +456,19 @@ var entries
455456
}
456457
}
457458

459+
// Remove [EnumeratorCancellation] attribute
460+
var attributeLists = default(SyntaxList<AttributeListSyntax>);
461+
foreach (var attributeList in node.AttributeLists)
462+
{
463+
var attributes = SeparatedList(attributeList.Attributes.Where(a => !IsEnumeratorCancellationAttribute(a)));
464+
if (attributes.Count > 0)
465+
{
466+
attributeLists = attributeLists.Add(attributeList.WithAttributes(attributes));
467+
}
468+
}
469+
458470
return node.Type is null || TypeAlreadyQualified(node.Type) ? @base
459-
: @base.WithType(ProcessType(node.Type)).WithTriviaFrom(@base);
471+
: @base.WithType(ProcessType(node.Type)).WithAttributeLists(attributeLists).WithTriviaFrom(@base);
460472
}
461473

462474
/// <inheritdoc/>
@@ -1879,11 +1891,11 @@ private string GetNewName(IMethodSymbol methodSymbol)
18791891
private string ReplaceWithSpan(ISymbol symbol)
18801892
=> Regex.Replace(symbol.Name, Memory, Span);
18811893

1882-
private bool ShouldRemoveType(ITypeSymbol symbol)
1894+
private bool ShouldRemoveType(ITypeSymbol symbol, bool isArgument = false)
18831895
{
18841896
if (symbol is IArrayTypeSymbol at)
18851897
{
1886-
return ShouldRemoveType(at.ElementType);
1898+
return ShouldRemoveType(at.ElementType, isArgument);
18871899
}
18881900

18891901
if (symbol is not INamedTypeSymbol namedSymbol)
@@ -1903,7 +1915,7 @@ private bool ShouldRemoveType(ITypeSymbol symbol)
19031915
}
19041916
}
19051917

1906-
return (namedSymbol.IsIProgress && !preserveProgress) || namedSymbol.IsCancellationToken;
1918+
return (namedSymbol.IsIProgress && !preserveProgress) || (isArgument ? namedSymbol.IsCancellationToken : !preserveCancellationToken);
19071919
}
19081920

19091921
private bool ShouldRemoveArgument(ISymbol symbol, bool isNegated = false) => symbol switch
@@ -1912,10 +1924,10 @@ private bool ShouldRemoveType(ITypeSymbol symbol)
19121924
IPropertySymbol { IsCancellationRequested: true } => !isNegated,
19131925
IMethodSymbol ms =>
19141926
IsSpecialMethod(ms) is SpecialMethod.None or SpecialMethod.Drop
1915-
&& ((ShouldRemoveType(ms.ReturnType) && ms.MethodKind != MethodKind.LocalFunction)
1916-
|| (ms.ReceiverType is { } receiver && ShouldRemoveType(receiver)))
1927+
&& ((ShouldRemoveType(ms.ReturnType, isArgument: true) && ms.MethodKind != MethodKind.LocalFunction)
1928+
|| (ms.ReceiverType is { } receiver && ShouldRemoveType(receiver, isArgument: true)))
19171929
&& !HasSyncMethod(ms),
1918-
_ => ShouldRemoveType(GetReturnType(symbol)),
1930+
_ => ShouldRemoveType(GetReturnType(symbol), isArgument: true),
19191931
};
19201932

19211933
private bool PreProcess(
@@ -2155,6 +2167,12 @@ private SyncOnlyAttributeContext ProcessSyncOnlyAttributes(SyntaxTriviaList synt
21552167

21562168
private ISymbol? GetSymbol(SyntaxNode node) => semanticModel.GetSymbolInfo(node).Symbol;
21572169

2170+
private bool IsEnumeratorCancellationAttribute(AttributeSyntax attributeSyntax)
2171+
{
2172+
var type = GetSymbol(attributeSyntax)?.ContainingType;
2173+
return type?.IsEnumeratorCancellationAttribute == true;
2174+
}
2175+
21582176
private bool CanDropDeclaration(LocalDeclarationStatementSyntax local)
21592177
{
21602178
var symbol = GetSymbol(local.Declaration.Type);
@@ -2203,7 +2221,7 @@ private TypeSyntax ProcessSyntaxUsingSymbol(TypeSyntax typeSyntax)
22032221
};
22042222

22052223
private bool ChecksIfNegatedIsCancellationRequested(ExpressionSyntax condition)
2206-
=> RemoveParentheses(condition) is PrefixUnaryExpressionSyntax { OperatorToken.RawKind: (int)SyntaxKind.ExclamationToken } pe
2224+
=> !preserveCancellationToken && RemoveParentheses(condition) is PrefixUnaryExpressionSyntax { OperatorToken.RawKind: (int)SyntaxKind.ExclamationToken } pe
22072225
&& RemoveParentheses(pe.Operand) is MemberAccessExpressionSyntax { Name.Identifier.ValueText: IsCancellationRequested } mae
22082226
&& GetSymbol(mae) is { ContainingType.IsCancellationToken: true };
22092227

src/Zomp.SyncMethodGenerator/Extensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ internal static class Extensions
5252
ContainingNamespace: { Name: Threading, ContainingNamespace: { Name: System, ContainingNamespace.IsGlobalNamespace: true } }
5353
};
5454

55+
public bool IsEnumeratorCancellationAttribute => symbol is
56+
{
57+
Name: EnumeratorCancellationAttribute, IsGenericType: false,
58+
ContainingNamespace: { Name: CompilerServices, ContainingNamespace: { Name: Runtime, ContainingNamespace: { Name: System, ContainingNamespace.IsGlobalNamespace: true } } }
59+
};
60+
5561
public bool IsIProgress => symbol is
5662
{
5763
Name: IProgress, IsGenericType: true,
@@ -118,6 +124,7 @@ internal static class Extensions
118124
// Type names
119125
private const string Enumerator = nameof(Span<>.Enumerator);
120126
private const string CancellationToken = nameof(global::System.Threading.CancellationToken);
127+
private const string EnumeratorCancellationAttribute = nameof(global::System.Runtime.CompilerServices.EnumeratorCancellationAttribute);
121128
private const string IAsyncEnumerable = nameof(IAsyncEnumerable<>);
122129
private const string IAsyncEnumerator = nameof(IAsyncEnumerator<>);
123130
private const string Task = nameof(Task<>);

src/Zomp.SyncMethodGenerator/SourceGenerationHelper.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ internal class {{SyncMethodSourceGenerator.CreateSyncVersionAttribute}} : System
3636
/// Gets or sets a value indicating whether <see cref="System.IProgress{T}"/> parameters will be preserved in the generated code. False by default.
3737
/// </summary>
3838
public bool {{SyncMethodSourceGenerator.PreserveProgress}} { get; set; }
39+
40+
/// <summary>
41+
/// Gets or sets a value indicating whether <see cref="System.Threading.CancellationToken"/> parameters will be preserved in the generated code. False by default.
42+
/// </summary>
43+
public bool {{SyncMethodSourceGenerator.PreserveCancellationToken}} { get; set; }
3944
}
4045
#endif
4146
}

src/Zomp.SyncMethodGenerator/SyncMethodSourceGenerator.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class SyncMethodSourceGenerator : IIncrementalGenerator
2121

2222
internal const string OmitNullableDirective = "OmitNullableDirective";
2323
internal const string PreserveProgress = "PreserveProgress";
24+
internal const string PreserveCancellationToken = "PreserveCancellationToken";
2425

2526
/// <inheritdoc/>
2627
public void Initialize(IncrementalGeneratorInitializationContext context)
@@ -237,9 +238,10 @@ static string BuildClassName(MethodParentDeclaration c)
237238
disableNullable |= explicitDisableNullable;
238239

239240
var preserveProgress = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == PreserveProgress) is { Value.Value: true };
241+
var preserveCancellationToken = syncMethodGeneratorAttributeData.NamedArguments.FirstOrDefault(c => c.Key == PreserveCancellationToken) is { Value.Value: true };
240242

241243
var toVisit = extensionParent ?? (SyntaxNode)methodDeclarationSyntax;
242-
var rewriter = new AsyncToSyncRewriter(context.SemanticModel, disableNullable, preserveProgress, methodDeclarationSyntax);
244+
var rewriter = new AsyncToSyncRewriter(context.SemanticModel, disableNullable, preserveProgress, preserveCancellationToken, methodDeclarationSyntax);
243245
var sn = rewriter.Visit(toVisit);
244246
var content = sn.ToFullString();
245247

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
namespace Generator.Tests;
2+
3+
public class PreserveCancellationTokenTests
4+
{
5+
[Fact]
6+
public Task PreserveIsCancellationRequestedAndThrowIfCancellationRequested() => """
7+
[CreateSyncVersion(PreserveCancellationToken = true)]
8+
async Task MethodAsync(CancellationToken ct = default)
9+
{
10+
if (!ct.IsCancellationRequested)
11+
{
12+
ct.ThrowIfCancellationRequested();
13+
}
14+
}
15+
""".Verify();
16+
17+
[Fact]
18+
public Task RemoveCancellationTokenForSyncMethods() => """
19+
[CreateSyncVersion(PreserveCancellationToken = true)]
20+
async Task MethodAsync(CancellationToken ct = default)
21+
{
22+
Stream.Null.FlushAsync(ct);
23+
}
24+
""".Verify();
25+
26+
[Theory]
27+
[InlineData("[EnumeratorCancellation]")]
28+
[InlineData("[EnumeratorCancellationAttribute]")]
29+
[InlineData("[System.Runtime.CompilerServices.EnumeratorCancellation]")]
30+
[InlineData("[System.Runtime.CompilerServices.EnumeratorCancellationAttribute]")]
31+
[InlineData("[global::System.Runtime.CompilerServices.EnumeratorCancellationAttribute]")]
32+
public Task RemoveEnumeratorCancellationAttribute(string attribute) => $$"""
33+
[CreateSyncVersion(PreserveCancellationToken = true)]
34+
async IAsyncEnumerable<int> FibonacciAsync({{attribute}} CancellationToken ct = default)
35+
{
36+
var f0 = 0;
37+
var f1 = 1;
38+
yield return f0;
39+
yield return f1;
40+
while (!ct.IsCancellationRequested)
41+
{
42+
var fn = f0 + f1;
43+
yield return fn;
44+
await Task.Yield();
45+
f0 = f1;
46+
f1 = fn;
47+
}
48+
}
49+
""".Verify(disableUnique: true);
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//HintName: Test.Class.MethodAsync.g.cs
2+
void Method(global::System.Threading.CancellationToken ct = default)
3+
{
4+
if (!ct.IsCancellationRequested)
5+
{
6+
ct.ThrowIfCancellationRequested();
7+
}
8+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
//HintName: Test.Class.MethodAsync.g.cs
2+
void Method(global::System.Threading.CancellationToken ct = default)
3+
{
4+
global::System.IO.Stream.Null.Flush();
5+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//HintName: Test.Class.FibonacciAsync.g.cs
2+
global::System.Collections.Generic.IEnumerable<int> Fibonacci(global::System.Threading.CancellationToken ct = default)
3+
{
4+
var f0 = 0;
5+
var f1 = 1;
6+
yield return f0;
7+
yield return f1;
8+
while (!ct.IsCancellationRequested)
9+
{
10+
var fn = f0 + f1;
11+
yield return fn;
12+
f0 = f1;
13+
f1 = fn;
14+
}
15+
}

0 commit comments

Comments
 (0)