From c2d5e577a6df0e7887cf64a5276ca51c487eb5a4 Mon Sep 17 00:00:00 2001 From: Jonathan Peppers Date: Sat, 2 May 2026 20:20:18 -0500 Subject: [PATCH 1/4] Add hybrid sorting with [HybridSortingNetwork] attribute Implement a hybrid quicksort that uses 3-way partitioning for large arrays (n > 64) and sorting-network base cases for small sub-arrays (n <= 64). This allows sorting arrays of any size without requiring user-supplied OnFallback methods. New features: - [HybridSortingNetwork(typeof(T))] attribute for code generation - Sort(Span) / Sort(T[]) for arbitrary-length arrays - PartialSort(Span, int k) for top-k sorting - NthElement(Span, int n) for quickselect - 3-way Dutch National Flag partition (handles duplicates well) - Embedded optimal sorting network data for sizes 2-64 - Support for int, byte, short, long, float, double types - HybridSortEmitter generates all code via source generator SIMD partition using AVX-512F Compress / AVX-512 VBMI2 Compress is scaffolded as a future optimization on top of the scalar partition. Resolves #34 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../HybridSortEmitter.cs | 413 ++++++++++++++++++ .../SortingNetworkGenerator.cs | 117 +++++ SortingNetworks.Tests/GeneratorTests.cs | 36 ++ SortingNetworks.Tests/HybridSortTests.cs | 260 +++++++++++ SortingNetworks.Tests/HybridSorter.cs | 11 + .../HybridSortingNetworkAttribute.cs | 28 ++ 6 files changed, 865 insertions(+) create mode 100644 SortingNetworks.Generators/HybridSortEmitter.cs create mode 100644 SortingNetworks.Tests/HybridSortTests.cs create mode 100644 SortingNetworks.Tests/HybridSorter.cs create mode 100644 SortingNetworks/HybridSortingNetworkAttribute.cs diff --git a/SortingNetworks.Generators/HybridSortEmitter.cs b/SortingNetworks.Generators/HybridSortEmitter.cs new file mode 100644 index 0000000..8e269a6 --- /dev/null +++ b/SortingNetworks.Generators/HybridSortEmitter.cs @@ -0,0 +1,413 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace SortingNetworks.Generators +{ + /// + /// Emits hybrid sorting code that combines AVX-512 SIMD partitioning + /// for large arrays with sorting-network base cases for small sub-arrays. + /// Generates Sort, PartialSort, and NthElement methods. + /// + internal static class HybridSortEmitter + { + private const int BaseThreshold = 64; + + /// + /// Emits all hybrid sort methods for the given element type. + /// Does NOT emit the shared network data — call once separately. + /// + internal static string Emit(string typeName, SpecialType specialType) + { + var sb = new StringBuilder(); + + bool isString = typeName == "string"; + int elemSize = GetElementSize(specialType); + bool canSimdPartition = CanEmitSimdPartition(specialType); + + // Emit Sort(Span) and Sort(T[]) + EmitSortMethods(sb, typeName, isString); + + // Emit PartialSort(Span, int k) and overload + EmitPartialSortMethods(sb, typeName); + + // Emit NthElement(Span, int n) and overload + EmitNthElementMethods(sb, typeName); + + // Emit private HybridQuickSort + EmitHybridQuickSort(sb, typeName, isString); + + // Emit private HybridQuickSelect (for PartialSort/NthElement) + EmitHybridQuickSelect(sb, typeName, isString); + + // Emit private SortSmall + EmitSortSmall(sb, typeName, isString); + + // Emit MedianOfThree + EmitMedianOfThree(sb, typeName, isString); + + // Emit ScalarPartition3Way + EmitScalarPartition3Way(sb, typeName, isString); + + // Emit SIMD partition if applicable + if (canSimdPartition) + { + EmitSimdPartition(sb, typeName, specialType, elemSize); + } + + return sb.ToString(); + } + + private static void EmitSortMethods(StringBuilder sb, string typeName, bool isString) + { + sb.AppendLine($" /// Sorts a span of {typeName} using a hybrid SIMD quicksort with sorting network base case."); + sb.AppendLine($" public static void Sort(System.Span<{typeName}> span)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (span.Length <= 1) return;"); + sb.AppendLine($" if (span.Length <= {BaseThreshold})"); + sb.AppendLine(" {"); + sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span), span.Length);"); + sb.AppendLine(" return;"); + sb.AppendLine(" }"); + sb.AppendLine(" HybridQuickSort(span);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// Sorts an array of {typeName} using a hybrid SIMD quicksort with sorting network base case."); + sb.AppendLine($" public static void Sort({typeName}[] array)"); + sb.AppendLine(" {"); + sb.AppendLine(" System.ArgumentNullException.ThrowIfNull(array);"); + sb.AppendLine($" Sort((System.Span<{typeName}>)array);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitPartialSortMethods(StringBuilder sb, string typeName) + { + sb.AppendLine($" /// Partially sorts a span so that the first elements are the smallest in sorted order."); + sb.AppendLine($" public static void PartialSort(System.Span<{typeName}> span, int k)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (k < 0 || k > span.Length) throw new System.ArgumentOutOfRangeException(nameof(k));"); + sb.AppendLine(" if (k <= 1 || span.Length <= 1) return;"); + sb.AppendLine(" HybridQuickSelect(span, k);"); + sb.AppendLine($" var left = span.Slice(0, k);"); + sb.AppendLine($" if (left.Length <= {BaseThreshold})"); + sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(left), left.Length);"); + sb.AppendLine(" else"); + sb.AppendLine(" HybridQuickSort(left);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// Partially sorts an array so that the first elements are the smallest in sorted order."); + sb.AppendLine($" public static void PartialSort({typeName}[] array, int k)"); + sb.AppendLine(" {"); + sb.AppendLine(" System.ArgumentNullException.ThrowIfNull(array);"); + sb.AppendLine($" PartialSort((System.Span<{typeName}>)array, k);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitNthElementMethods(StringBuilder sb, string typeName) + { + sb.AppendLine($" /// Rearranges elements so that the element at index is the element that would be there if the span were sorted."); + sb.AppendLine($" public static void NthElement(System.Span<{typeName}> span, int n)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (n < 0 || n >= span.Length) throw new System.ArgumentOutOfRangeException(nameof(n));"); + sb.AppendLine(" if (span.Length <= 1) return;"); + sb.AppendLine(" HybridQuickSelect(span, n + 1);"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" /// Rearranges elements so that the element at index is the element that would be there if the array were sorted."); + sb.AppendLine($" public static void NthElement({typeName}[] array, int n)"); + sb.AppendLine(" {"); + sb.AppendLine(" System.ArgumentNullException.ThrowIfNull(array);"); + sb.AppendLine($" NthElement((System.Span<{typeName}>)array, n);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitHybridQuickSort(StringBuilder sb, string typeName, bool isString) + { + string gt = GetGreaterThan(typeName, isString); + + sb.AppendLine($" private static void HybridQuickSort(System.Span<{typeName}> span)"); + sb.AppendLine(" {"); + sb.AppendLine($" while (span.Length > {BaseThreshold})"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} first = ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span);"); + sb.AppendLine($" {typeName} pivot = HybridMedianOfThree("); + sb.AppendLine(" first,"); + sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.Add(ref first, span.Length / 2),"); + sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.Add(ref first, span.Length - 1));"); + sb.AppendLine(); + sb.AppendLine(" HybridPartition3Way(span, pivot, out int lt, out int gt);"); + sb.AppendLine(); + sb.AppendLine(" // Recurse on the smaller side, iterate on the larger"); + sb.AppendLine(" var left = span.Slice(0, lt);"); + sb.AppendLine(" var right = span.Slice(gt);"); + sb.AppendLine(" if (left.Length <= right.Length)"); + sb.AppendLine(" {"); + sb.AppendLine($" if (left.Length > 1)"); + sb.AppendLine(" {"); + sb.AppendLine($" if (left.Length <= {BaseThreshold})"); + sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(left), left.Length);"); + sb.AppendLine(" else"); + sb.AppendLine(" HybridQuickSort(left);"); + sb.AppendLine(" }"); + sb.AppendLine(" span = right;"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine($" if (right.Length > 1)"); + sb.AppendLine(" {"); + sb.AppendLine($" if (right.Length <= {BaseThreshold})"); + sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(right), right.Length);"); + sb.AppendLine(" else"); + sb.AppendLine(" HybridQuickSort(right);"); + sb.AppendLine(" }"); + sb.AppendLine(" span = left;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" if (span.Length > 1)"); + sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span), span.Length);"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitHybridQuickSelect(StringBuilder sb, string typeName, bool isString) + { + sb.AppendLine($" private static void HybridQuickSelect(System.Span<{typeName}> span, int k)"); + sb.AppendLine(" {"); + sb.AppendLine(" while (span.Length > 1)"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} first = ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span);"); + sb.AppendLine($" {typeName} pivot = HybridMedianOfThree("); + sb.AppendLine(" first,"); + sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.Add(ref first, span.Length / 2),"); + sb.AppendLine(" System.Runtime.CompilerServices.Unsafe.Add(ref first, span.Length - 1));"); + sb.AppendLine(); + sb.AppendLine(" HybridPartition3Way(span, pivot, out int lt, out int gt);"); + sb.AppendLine(); + sb.AppendLine(" if (k <= lt)"); + sb.AppendLine(" span = span.Slice(0, lt);"); + sb.AppendLine(" else if (k > gt)"); + sb.AppendLine(" { span = span.Slice(gt); k -= gt; }"); + sb.AppendLine(" else"); + sb.AppendLine(" return; // k-th element is in the pivot region"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitSortSmall(StringBuilder sb, string typeName, bool isString) + { + string condition = GetGreaterThanForRef(typeName, isString, "a", "b"); + + sb.AppendLine(" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]"); + sb.AppendLine($" private static void HybridSortSmall(ref {typeName} first, int length)"); + sb.AppendLine(" {"); + sb.AppendLine(" int offset = HybridGetNetworkOffset(length);"); + sb.AppendLine(" int pairCount = HybridGetNetworkOffset(length + 1) - offset;"); + sb.AppendLine(" System.ReadOnlySpan data = HybridNetworkData;"); + sb.AppendLine(" for (int i = 0; i < pairCount; i += 2)"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} a = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, data[offset + i]);"); + sb.AppendLine($" ref {typeName} b = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, data[offset + i + 1]);"); + sb.AppendLine($" if ({condition})"); + sb.AppendLine(" {"); + sb.AppendLine($" {typeName} temp = a;"); + sb.AppendLine(" a = b;"); + sb.AppendLine(" b = temp;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitMedianOfThree(StringBuilder sb, string typeName, bool isString) + { + string gt = GetGreaterThan(typeName, isString); + + sb.AppendLine(" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]"); + sb.AppendLine($" private static {typeName} HybridMedianOfThree({typeName} a, {typeName} b, {typeName} c)"); + sb.AppendLine(" {"); + sb.AppendLine($" if ({gt.Replace("$a", "a").Replace("$b", "b")}) {{ {typeName} t = a; a = b; b = t; }}"); + sb.AppendLine($" if ({gt.Replace("$a", "b").Replace("$b", "c")}) {{ b = c; if ({gt.Replace("$a", "a").Replace("$b", "b")}) b = a; }}"); + sb.AppendLine(" return b;"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitScalarPartition3Way(StringBuilder sb, string typeName, bool isString) + { + string lt = GetLessThan(typeName, isString); + string gt = GetGreaterThan(typeName, isString); + + sb.AppendLine($" private static void HybridPartition3Way(System.Span<{typeName}> span, {typeName} pivot, out int ltEnd, out int gtStart)"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} first = ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span);"); + sb.AppendLine(" int lo = 0;"); + sb.AppendLine(" int mid = 0;"); + sb.AppendLine(" int hi = span.Length - 1;"); + sb.AppendLine(" while (mid <= hi)"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} elem = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, mid);"); + sb.AppendLine($" if ({lt.Replace("$a", "elem").Replace("$b", "pivot")})"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} target = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, lo);"); + sb.AppendLine($" {typeName} temp = target;"); + sb.AppendLine(" target = elem;"); + sb.AppendLine(" elem = temp;"); + sb.AppendLine(" lo++;"); + sb.AppendLine(" mid++;"); + sb.AppendLine(" }"); + sb.AppendLine($" else if ({gt.Replace("$a", "elem").Replace("$b", "pivot")})"); + sb.AppendLine(" {"); + sb.AppendLine($" ref {typeName} target = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, hi);"); + sb.AppendLine($" {typeName} temp = target;"); + sb.AppendLine(" target = elem;"); + sb.AppendLine(" elem = temp;"); + sb.AppendLine(" hi--;"); + sb.AppendLine(" }"); + sb.AppendLine(" else"); + sb.AppendLine(" {"); + sb.AppendLine(" mid++;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" ltEnd = lo;"); + sb.AppendLine(" gtStart = hi + 1;"); + sb.AppendLine(" }"); + sb.AppendLine(); + } + + private static void EmitSimdPartition(StringBuilder sb, string typeName, SpecialType specialType, int elemSize) + { + // SIMD partition is a future optimization on top of the scalar 3-way partition. + // The scalar partition is always correct and sufficient for correctness. + // AVX-512F Compress (for 32/64-bit) and AVX-512 VBMI2 Compress (for 8/16-bit) + // can be used here for a vectorized partition step. + // TODO: Implement SIMD partition using Avx512F.Compress / Avx512Vbmi2.Compress + } + + internal static string EmitNetworkData() + { + var sb = new StringBuilder(); + // Build compact byte array of all sorting networks for sizes 2-64 + // Format: pairs are stored contiguously, offset table maps size -> start position + var allNetworkPairs = new List(); + var offsets = new int[BaseThreshold + 2]; // offsets[size] = start index for size, offsets[65] = end + + for (int size = 2; size <= BaseThreshold; size++) + { + offsets[size] = allNetworkPairs.Count; + var network = NetworkDatabase.GetNetwork(size); + if (network == null) + network = BatcherNetworkBuilder.Generate(size); + + for (int i = 0; i < network.Length; i++) + { + allNetworkPairs.Add((byte)network[i]); + } + } + offsets[BaseThreshold + 1] = allNetworkPairs.Count; + + // Emit the offset lookup method + sb.AppendLine(" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]"); + sb.AppendLine(" private static int HybridGetNetworkOffset(int size)"); + sb.AppendLine(" {"); + sb.AppendLine(" System.ReadOnlySpan offsets = HybridNetworkOffsets;"); + sb.AppendLine(" return offsets[size];"); + sb.AppendLine(" }"); + sb.AppendLine(); + + // Emit the offset table + sb.Append(" private static System.ReadOnlySpan HybridNetworkOffsets => new int[] { "); + for (int i = 0; i < offsets.Length; i++) + { + if (i > 0) sb.Append(", "); + sb.Append(offsets[i]); + } + sb.AppendLine(" };"); + sb.AppendLine(); + + // Emit the network data as a byte array + sb.Append(" private static System.ReadOnlySpan HybridNetworkData => new byte[] { "); + for (int i = 0; i < allNetworkPairs.Count; i++) + { + if (i > 0) sb.Append(", "); + sb.Append(allNetworkPairs[i]); + } + sb.AppendLine(" };"); + sb.AppendLine(); + return sb.ToString(); + } + + /// + /// Returns whether SIMD partition can be emitted for this type. + /// + internal static bool CanEmitSimdPartition(SpecialType specialType) + { + return specialType switch + { + SpecialType.System_Byte => true, + SpecialType.System_SByte => true, + SpecialType.System_Int16 => true, + SpecialType.System_UInt16 => true, + SpecialType.System_Char => true, + SpecialType.System_Int32 => true, + SpecialType.System_UInt32 => true, + SpecialType.System_Single => true, + SpecialType.System_Int64 => true, + SpecialType.System_UInt64 => true, + SpecialType.System_Double => true, + _ => false, + }; + } + + private static int GetElementSize(SpecialType specialType) + { + return specialType switch + { + SpecialType.System_Byte => 1, + SpecialType.System_SByte => 1, + SpecialType.System_Int16 => 2, + SpecialType.System_UInt16 => 2, + SpecialType.System_Char => 2, + SpecialType.System_Int32 => 4, + SpecialType.System_UInt32 => 4, + SpecialType.System_Single => 4, + SpecialType.System_Int64 => 8, + SpecialType.System_UInt64 => 8, + SpecialType.System_Double => 8, + _ => 0, + }; + } + + /// + /// Returns a "greater than" comparison expression template with $a and $b placeholders. + /// + private static string GetGreaterThan(string typeName, bool isString) + { + if (isString) return "string.CompareOrdinal($a, $b) > 0"; + return "$a > $b"; + } + + /// + /// Returns a "less than" comparison expression template with $a and $b placeholders. + /// + private static string GetLessThan(string typeName, bool isString) + { + if (isString) return "string.CompareOrdinal($a, $b) < 0"; + return "$a < $b"; + } + + /// + /// Returns a "greater than" comparison for named ref variables. + /// + private static string GetGreaterThanForRef(string typeName, bool isString, string a, string b) + { + if (isString) return $"string.CompareOrdinal({a}, {b}) > 0"; + return $"{a} > {b}"; + } + } +} diff --git a/SortingNetworks.Generators/SortingNetworkGenerator.cs b/SortingNetworks.Generators/SortingNetworkGenerator.cs index 1c80c9c..38fcaa0 100644 --- a/SortingNetworks.Generators/SortingNetworkGenerator.cs +++ b/SortingNetworks.Generators/SortingNetworkGenerator.cs @@ -66,6 +66,17 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Collect(); context.RegisterSourceOutput(classDeclarations, static (spc, infos) => Execute(spc, infos!)); + + // Find all class declarations with [HybridSortingNetwork] attributes + var hybridDeclarations = context.SyntaxProvider + .ForAttributeWithMetadataName( + "SortingNetworks.HybridSortingNetworkAttribute", + predicate: static (node, _) => node is ClassDeclarationSyntax, + transform: static (ctx, _) => GetHybridGenerationInfo(ctx)) + .Where(static info => info != null) + .Collect(); + + context.RegisterSourceOutput(hybridDeclarations, static (spc, infos) => ExecuteHybrid(spc, infos!)); } private static GenerationInfo? GetGenerationInfo(GeneratorAttributeSyntaxContext context) @@ -966,6 +977,86 @@ private static void EmitComparerOverloads(StringBuilder sb, string typeName, Lis sb.AppendLine(); } + private static HybridGenerationInfo? GetHybridGenerationInfo(GeneratorAttributeSyntaxContext context) + { + var classSymbol = (INamedTypeSymbol)context.TargetSymbol; + + var types = new List(); + foreach (var attr in context.Attributes) + { + if (attr.ConstructorArguments.Length < 1) + continue; + + var typeArg = attr.ConstructorArguments[0]; + if (typeArg.Value is INamedTypeSymbol typeSymbol) + { + var typeName = GetKeywordName(typeSymbol.SpecialType) + ?? typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (!SupportedSpecialTypes.Contains(typeSymbol.SpecialType)) + continue; + + types.Add(new HybridTypeRequest(typeName, typeSymbol.SpecialType)); + } + } + + if (types.Count == 0) + return null; + + var namespaceName = classSymbol.ContainingNamespace.IsGlobalNamespace + ? null + : classSymbol.ContainingNamespace.ToDisplayString(); + + return new HybridGenerationInfo(classSymbol.Name, namespaceName, types.ToArray()); + } + + private static void ExecuteHybrid(SourceProductionContext context, ImmutableArray infos) + { + foreach (var info in infos) + { + if (info == null) continue; + var source = GenerateHybridSource(info); + if (source != null) + { + context.AddSource($"{info.ClassName}.Hybrid.g.cs", SourceText.From(source, Encoding.UTF8)); + } + } + } + + private static string? GenerateHybridSource(HybridGenerationInfo info) + { + var sb = new StringBuilder(); + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine(); + + if (info.Namespace != null) + { + sb.AppendLine($"namespace {info.Namespace}"); + sb.AppendLine("{"); + } + + sb.AppendLine($" partial class {info.ClassName}"); + sb.AppendLine(" {"); + + foreach (var typeRequest in info.Types) + { + sb.Append(HybridSortEmitter.Emit(typeRequest.TypeName, typeRequest.SpecialType)); + } + + // Emit shared network data once (type-independent) + sb.Append(HybridSortEmitter.EmitNetworkData()); + + sb.AppendLine(" }"); + + if (info.Namespace != null) + { + sb.AppendLine("}"); + } + + return sb.ToString(); + } + private sealed class GenerationInfo { public string ClassName { get; } @@ -984,6 +1075,32 @@ public GenerationInfo(string className, string? ns, NetworkRequest[] requests, H } } + private sealed class HybridGenerationInfo + { + public string ClassName { get; } + public string? Namespace { get; } + public HybridTypeRequest[] Types { get; } + + public HybridGenerationInfo(string className, string? ns, HybridTypeRequest[] types) + { + ClassName = className; + Namespace = ns; + Types = types; + } + } + + private sealed class HybridTypeRequest + { + public string TypeName { get; } + public SpecialType SpecialType { get; } + + public HybridTypeRequest(string typeName, SpecialType specialType) + { + TypeName = typeName; + SpecialType = specialType; + } + } + /// /// Returns the fixed-size delegate types for nint/nuint SIMD dispatch. /// For nint: 32-bit=int, 64-bit=long. For nuint: 32-bit=uint, 64-bit=ulong. diff --git a/SortingNetworks.Tests/GeneratorTests.cs b/SortingNetworks.Tests/GeneratorTests.cs index 8fd4e8c..524ee0d 100644 --- a/SortingNetworks.Tests/GeneratorTests.cs +++ b/SortingNetworks.Tests/GeneratorTests.cs @@ -873,4 +873,40 @@ public partial class MySorter {{ }} Assert.Contains($"private static void Sort16(ref {type64} first)", generatedSource); Assert.Contains($"private static void Sort16(ref {type32} first)", generatedSource); } + + [Theory] + [InlineData("int")] + [InlineData("byte")] + [InlineData("short")] + [InlineData("long")] + [InlineData("float")] + [InlineData("double")] + public void HybridSort_GeneratesCode(string typeName) + { + var source = $@" +using SortingNetworks; + +[HybridSortingNetwork(typeof({typeName}))] +public partial class MySorter {{ }} +"; + var compilation = SourceGeneratorDriver.CreateCompilation(source); + var (result, updatedCompilation) = SourceGeneratorDriver.RunGeneratorWithCompilation(compilation); + + var errors = result.Diagnostics.Where(d => d.Severity == DiagnosticSeverity.Error).ToArray(); + Assert.Empty(errors); + + var compilationErrors = SourceGeneratorDriver.GetErrors(updatedCompilation); + Assert.Empty(compilationErrors); + + var generatedSource = result.GeneratedTrees + .Select(t => t.GetText().ToString()) + .FirstOrDefault(s => s.Contains("HybridQuickSort")); + Assert.NotNull(generatedSource); + Assert.Contains("HybridSortSmall", generatedSource); + Assert.Contains("HybridPartition3Way", generatedSource); + Assert.Contains("HybridMedianOfThree", generatedSource); + Assert.Contains("PartialSort", generatedSource); + Assert.Contains("NthElement", generatedSource); + Assert.Contains("HybridNetworkData", generatedSource); + } } diff --git a/SortingNetworks.Tests/HybridSortTests.cs b/SortingNetworks.Tests/HybridSortTests.cs new file mode 100644 index 0000000..87e1fac --- /dev/null +++ b/SortingNetworks.Tests/HybridSortTests.cs @@ -0,0 +1,260 @@ +using SortingNetworks; + +namespace SortingNetworks.Tests; + +public class HybridSortTests +{ + // --- Sort: various sizes --- + + private static void VerifyHybridSort(int size) + { + for (int seed = 0; seed < 50; seed++) + { + var rng = new Random(seed); + var input = Enumerable.Range(0, size).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var expected = (int[])input.Clone(); + Array.Sort(expected); + + var actual = (int[])input.Clone(); + HybridSorter.Sort(actual.AsSpan()); + + Assert.Equal(expected, actual); + } + } + + [Fact] public void Sort_2Elements() => VerifyHybridSort(2); + [Fact] public void Sort_3Elements() => VerifyHybridSort(3); + [Fact] public void Sort_10Elements() => VerifyHybridSort(10); + [Fact] public void Sort_27Elements() => VerifyHybridSort(27); + [Fact] public void Sort_64Elements() => VerifyHybridSort(64); + [Fact] public void Sort_100Elements() => VerifyHybridSort(100); + [Fact] public void Sort_256Elements() => VerifyHybridSort(256); + [Fact] public void Sort_1000Elements() => VerifyHybridSort(1000); + [Fact] public void Sort_10000Elements() => VerifyHybridSort(10000); + + // --- Sort: edge cases --- + + [Fact] + public void Sort_Empty() + { + var span = Span.Empty; + HybridSorter.Sort(span); // Should not throw + } + + [Fact] + public void Sort_SingleElement() + { + var data = new int[] { 42 }; + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(42, data[0]); + } + + [Fact] + public void Sort_AlreadySorted() + { + var data = Enumerable.Range(0, 200).ToArray(); + var expected = (int[])data.Clone(); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_ReverseSorted() + { + var data = Enumerable.Range(0, 200).Reverse().ToArray(); + var expected = Enumerable.Range(0, 200).ToArray(); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_AllDuplicates() + { + var data = Enumerable.Repeat(7, 200).ToArray(); + var expected = (int[])data.Clone(); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_DuplicateHeavy() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 500).Select(_ => rng.Next(0, 5)).ToArray(); + var expected = (int[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_NullArray_Throws() + { + Assert.Throws(() => HybridSorter.Sort((int[])null!)); + } + + // --- Sort: multiple types --- + + [Fact] + public void Sort_Byte_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (byte)rng.Next(0, 256)).ToArray(); + var expected = (byte[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_Short_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (short)rng.Next(-10000, 10000)).ToArray(); + var expected = (short[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_Long_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (long)rng.Next(-10000, 10000)).ToArray(); + var expected = (long[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_Float_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (float)(rng.NextDouble() * 2000 - 1000)).ToArray(); + var expected = (float[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_Double_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => rng.NextDouble() * 2000 - 1000).ToArray(); + var expected = (double[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + // --- Stress test --- + + [Fact] + public void Sort_Stress_Int_Various_Sizes() + { + foreach (int size in new[] { 65, 100, 128, 200, 500, 1000 }) + { + for (int seed = 0; seed < 20; seed++) + { + var rng = new Random(seed + size); + var data = Enumerable.Range(0, size).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var expected = (int[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + } + } + + // --- PartialSort tests --- + + [Fact] + public void PartialSort_Top10() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 1000).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var expected = (int[])data.Clone(); + Array.Sort(expected); + + HybridSorter.PartialSort(data.AsSpan(), 10); + + // First 10 elements should match the 10 smallest sorted elements + Assert.Equal(expected.Take(10), data.Take(10)); + } + + [Fact] + public void PartialSort_K_Equals_Length() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 100).Select(_ => rng.Next(-1000, 1000)).ToArray(); + var expected = (int[])data.Clone(); + Array.Sort(expected); + + HybridSorter.PartialSort(data.AsSpan(), data.Length); + Assert.Equal(expected, data); + } + + [Fact] + public void PartialSort_K_Zero() + { + var data = new int[] { 3, 1, 2 }; + HybridSorter.PartialSort(data.AsSpan(), 0); // Should not throw + } + + [Fact] + public void PartialSort_InvalidK_Throws() + { + var data = new int[] { 3, 1, 2 }; + Assert.Throws(() => HybridSorter.PartialSort(data.AsSpan(), -1)); + Assert.Throws(() => HybridSorter.PartialSort(data.AsSpan(), 4)); + } + + // --- NthElement tests --- + + [Fact] + public void NthElement_Median() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 101).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var sorted = (int[])data.Clone(); + Array.Sort(sorted); + + HybridSorter.NthElement(data.AsSpan(), 50); + Assert.Equal(sorted[50], data[50]); + } + + [Fact] + public void NthElement_First() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 100).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var sorted = (int[])data.Clone(); + Array.Sort(sorted); + + HybridSorter.NthElement(data.AsSpan(), 0); + Assert.Equal(sorted[0], data[0]); + } + + [Fact] + public void NthElement_Last() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 100).Select(_ => rng.Next(-10000, 10000)).ToArray(); + var sorted = (int[])data.Clone(); + Array.Sort(sorted); + + HybridSorter.NthElement(data.AsSpan(), 99); + Assert.Equal(sorted[99], data[99]); + } + + [Fact] + public void NthElement_InvalidN_Throws() + { + var data = new int[] { 3, 1, 2 }; + Assert.Throws(() => HybridSorter.NthElement(data.AsSpan(), -1)); + Assert.Throws(() => HybridSorter.NthElement(data.AsSpan(), 3)); + } +} diff --git a/SortingNetworks.Tests/HybridSorter.cs b/SortingNetworks.Tests/HybridSorter.cs new file mode 100644 index 0000000..3120e5d --- /dev/null +++ b/SortingNetworks.Tests/HybridSorter.cs @@ -0,0 +1,11 @@ +using SortingNetworks; + +namespace SortingNetworks.Tests; + +[HybridSortingNetwork(typeof(int))] +[HybridSortingNetwork(typeof(byte))] +[HybridSortingNetwork(typeof(short))] +[HybridSortingNetwork(typeof(long))] +[HybridSortingNetwork(typeof(float))] +[HybridSortingNetwork(typeof(double))] +partial class HybridSorter { } diff --git a/SortingNetworks/HybridSortingNetworkAttribute.cs b/SortingNetworks/HybridSortingNetworkAttribute.cs new file mode 100644 index 0000000..f4ea891 --- /dev/null +++ b/SortingNetworks/HybridSortingNetworkAttribute.cs @@ -0,0 +1,28 @@ +using System; + +namespace SortingNetworks +{ + /// + /// Marks a partial class for hybrid sorting network code generation. + /// The source generator will emit Sort, PartialSort, + /// and NthElement methods that use AVX-512 SIMD partitioning + /// for large arrays and optimal sorting networks for small sub-arrays. + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)] + public sealed class HybridSortingNetworkAttribute : Attribute + { + /// + /// The element type to sort (e.g., typeof(int)). + /// + public Type ElementType { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The element type to sort. + public HybridSortingNetworkAttribute(Type elementType) + { + ElementType = elementType; + } + } +} From a0f04f2933e612a81b711c18423917e9118bd71c Mon Sep 17 00:00:00 2001 From: Jonathan Peppers Date: Sat, 2 May 2026 21:29:00 -0500 Subject: [PATCH 2/4] Add hybrid sort benchmarks for int and long Add HybridIntSortingBenchmarks (sizes 32-10000, all InputKinds) and HybridLongSortingBenchmarks (sizes 32-10000, Random) to measure hybrid quicksort with sorting-network base cases against Array.Sort and Span.Sort baselines. - Add HybridSorters class with [HybridSortingNetwork] for int and long - Update CI benchmark summary to describe hybrid benchmarks correctly Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/ci.yml | 6 +- .../GeneratedSorters.cs | 4 ++ .../HybridIntSortingBenchmarks.cs | 64 +++++++++++++++++++ .../HybridLongSortingBenchmarks.cs | 57 +++++++++++++++++ 4 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 SortingNetworks.Benchmarks/HybridIntSortingBenchmarks.cs create mode 100644 SortingNetworks.Benchmarks/HybridLongSortingBenchmarks.cs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b67b3aa..ddcb27d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,7 +76,11 @@ jobs: name=${name%SortingBenchmarks} echo "## $name" >> benchmark-summary.md echo "" >> benchmark-summary.md - echo "> **ArraySort** = BCL baseline \`Array.Sort\`, **GeneratedSort** = this library" >> benchmark-summary.md + if [[ "$name" == Hybrid* ]]; then + echo "> **ArraySort** = BCL baseline \`Array.Sort\`, **HybridSort** = hybrid quicksort with sorting-network base cases" >> benchmark-summary.md + else + echo "> **ArraySort** = BCL baseline \`Array.Sort\`, **GeneratedSort** = this library" >> benchmark-summary.md + fi echo "" >> benchmark-summary.md cat "$f" >> benchmark-summary.md echo "" >> benchmark-summary.md diff --git a/SortingNetworks.Benchmarks/GeneratedSorters.cs b/SortingNetworks.Benchmarks/GeneratedSorters.cs index d8fb340..87277a2 100644 --- a/SortingNetworks.Benchmarks/GeneratedSorters.cs +++ b/SortingNetworks.Benchmarks/GeneratedSorters.cs @@ -173,3 +173,7 @@ namespace SortingNetworks.Benchmarks; [SortingNetwork(31, typeof(Record))] [SortingNetwork(32, typeof(Record))] partial class GeneratedSorters { } + +[HybridSortingNetwork(typeof(int))] +[HybridSortingNetwork(typeof(long))] +partial class HybridSorters { } diff --git a/SortingNetworks.Benchmarks/HybridIntSortingBenchmarks.cs b/SortingNetworks.Benchmarks/HybridIntSortingBenchmarks.cs new file mode 100644 index 0000000..edd3c12 --- /dev/null +++ b/SortingNetworks.Benchmarks/HybridIntSortingBenchmarks.cs @@ -0,0 +1,64 @@ +using BenchmarkDotNet.Attributes; + +namespace SortingNetworks.Benchmarks; + +[MemoryDiagnoser] +[SimpleJob(warmupCount: 5, iterationCount: 15)] +public class HybridIntSortingBenchmarks +{ + private const int OpsPerInvoke = 1000; + + [Params(32, 64, 100, 1000, 10000)] + public int Length { get; set; } + + [ParamsAllValues] + public InputKind Kind { get; set; } + + private int[] _source = null!; + private int[][] _batch = null!; + + [GlobalSetup] + public void Setup() + { + var rng = new Random(42); + _source = Kind switch + { + InputKind.Random => Enumerable.Range(0, Length).Select(_ => rng.Next(-1000, 1000)).ToArray(), + InputKind.Sorted => Enumerable.Range(0, Length).ToArray(), + InputKind.Reversed => Enumerable.Range(0, Length).Reverse().ToArray(), + InputKind.Duplicates => Enumerable.Range(0, Length).Select(_ => rng.Next(0, 3)).ToArray(), + _ => throw new ArgumentOutOfRangeException() + }; + _batch = new int[OpsPerInvoke][]; + for (int i = 0; i < OpsPerInvoke; i++) + _batch[i] = new int[Length]; + } + + [IterationSetup] + public void IterationSetup() + { + for (int i = 0; i < OpsPerInvoke; i++) + Array.Copy(_source, _batch[i], Length); + } + + [Benchmark(Baseline = true, OperationsPerInvoke = OpsPerInvoke)] + public void ArraySort() + { + for (int i = 0; i < OpsPerInvoke; i++) + Array.Sort(_batch[i]); + } + + [Benchmark(OperationsPerInvoke = OpsPerInvoke)] + public void SpanSort() + { + for (int i = 0; i < OpsPerInvoke; i++) + _batch[i].AsSpan().Sort(); + } + + [Benchmark(OperationsPerInvoke = OpsPerInvoke)] + public void HybridSort() + { + for (int i = 0; i < OpsPerInvoke; i++) + HybridSorters.Sort(_batch[i].AsSpan()); + } +} diff --git a/SortingNetworks.Benchmarks/HybridLongSortingBenchmarks.cs b/SortingNetworks.Benchmarks/HybridLongSortingBenchmarks.cs new file mode 100644 index 0000000..96c6360 --- /dev/null +++ b/SortingNetworks.Benchmarks/HybridLongSortingBenchmarks.cs @@ -0,0 +1,57 @@ +using BenchmarkDotNet.Attributes; + +namespace SortingNetworks.Benchmarks; + +[MemoryDiagnoser] +[SimpleJob(warmupCount: 5, iterationCount: 15)] +public class HybridLongSortingBenchmarks +{ + private const int OpsPerInvoke = 1000; + + [Params(32, 64, 100, 1000, 10000)] + public int Length { get; set; } + + [Params(InputKind.Random)] + public InputKind Kind { get; set; } + + private long[] _source = null!; + private long[][] _batch = null!; + + [GlobalSetup] + public void Setup() + { + var rng = new Random(42); + _source = Enumerable.Range(0, Length).Select(_ => (long)rng.Next(-1000, 1000)).ToArray(); + _batch = new long[OpsPerInvoke][]; + for (int i = 0; i < OpsPerInvoke; i++) + _batch[i] = new long[Length]; + } + + [IterationSetup] + public void IterationSetup() + { + for (int i = 0; i < OpsPerInvoke; i++) + Array.Copy(_source, _batch[i], Length); + } + + [Benchmark(Baseline = true, OperationsPerInvoke = OpsPerInvoke)] + public void ArraySort() + { + for (int i = 0; i < OpsPerInvoke; i++) + Array.Sort(_batch[i]); + } + + [Benchmark(OperationsPerInvoke = OpsPerInvoke)] + public void SpanSort() + { + for (int i = 0; i < OpsPerInvoke; i++) + _batch[i].AsSpan().Sort(); + } + + [Benchmark(OperationsPerInvoke = OpsPerInvoke)] + public void HybridSort() + { + for (int i = 0; i < OpsPerInvoke; i++) + HybridSorters.Sort(_batch[i].AsSpan()); + } +} From ecbc0f7eb68cef3f8216e68715f024659cbcc550 Mon Sep 17 00:00:00 2001 From: Jonathan Peppers Date: Sat, 2 May 2026 21:43:50 -0500 Subject: [PATCH 3/4] Address code review feedback and merge main - Fix PartialSort(span, 1) bug: k<=1 early return changed to k<=0 so k=1 correctly places the minimum element at index 0 - Fix network data allocation: use static readonly arrays instead of ReadOnlySpan property accessors that allocated per call - Add introsort depth limit (2*log2(n)) with fallback to span.Sort() to guarantee O(n log n) worst case - Deduplicate type requests in hybrid generator (HashSet) - Use fully qualified hint name (namespace.class) to avoid collisions - Handle nested classes in hint name generation - Fix XML doc: remove AVX-512 SIMD claim (not yet implemented) - Add HybridSortingNetworkAttribute to PublicAPI.Unshipped.txt - Add PartialSort k=1 test - Add tests for sbyte, ushort, uint, ulong, char types - Add HybridSorter coverage for all supported primitive types - Merge main branch Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../HybridSortEmitter.cs | 46 ++++++++----- .../SortingNetworkGenerator.cs | 25 ++++++- SortingNetworks.Tests/HybridSortTests.cs | 67 +++++++++++++++++++ SortingNetworks.Tests/HybridSorter.cs | 5 ++ .../HybridSortingNetworkAttribute.cs | 2 +- SortingNetworks/PublicAPI.Unshipped.txt | 3 + 6 files changed, 126 insertions(+), 22 deletions(-) diff --git a/SortingNetworks.Generators/HybridSortEmitter.cs b/SortingNetworks.Generators/HybridSortEmitter.cs index 8e269a6..87d0ec8 100644 --- a/SortingNetworks.Generators/HybridSortEmitter.cs +++ b/SortingNetworks.Generators/HybridSortEmitter.cs @@ -6,7 +6,7 @@ namespace SortingNetworks.Generators { /// - /// Emits hybrid sorting code that combines AVX-512 SIMD partitioning + /// Emits hybrid sorting code that combines quicksort partitioning /// for large arrays with sorting-network base cases for small sub-arrays. /// Generates Sort, PartialSort, and NthElement methods. /// @@ -61,7 +61,7 @@ internal static string Emit(string typeName, SpecialType specialType) private static void EmitSortMethods(StringBuilder sb, string typeName, bool isString) { - sb.AppendLine($" /// Sorts a span of {typeName} using a hybrid SIMD quicksort with sorting network base case."); + sb.AppendLine($" /// Sorts a span of {typeName} using a hybrid quicksort with sorting network base case."); sb.AppendLine($" public static void Sort(System.Span<{typeName}> span)"); sb.AppendLine(" {"); sb.AppendLine(" if (span.Length <= 1) return;"); @@ -70,10 +70,11 @@ private static void EmitSortMethods(StringBuilder sb, string typeName, bool isSt sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span), span.Length);"); sb.AppendLine(" return;"); sb.AppendLine(" }"); - sb.AppendLine(" HybridQuickSort(span);"); + sb.AppendLine(" int depthLimit = 2 * System.Numerics.BitOperations.Log2((uint)span.Length);"); + sb.AppendLine(" HybridQuickSort(span, depthLimit);"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" /// Sorts an array of {typeName} using a hybrid SIMD quicksort with sorting network base case."); + sb.AppendLine($" /// Sorts an array of {typeName} using a hybrid quicksort with sorting network base case."); sb.AppendLine($" public static void Sort({typeName}[] array)"); sb.AppendLine(" {"); sb.AppendLine(" System.ArgumentNullException.ThrowIfNull(array);"); @@ -88,13 +89,16 @@ private static void EmitPartialSortMethods(StringBuilder sb, string typeName) sb.AppendLine($" public static void PartialSort(System.Span<{typeName}> span, int k)"); sb.AppendLine(" {"); sb.AppendLine(" if (k < 0 || k > span.Length) throw new System.ArgumentOutOfRangeException(nameof(k));"); - sb.AppendLine(" if (k <= 1 || span.Length <= 1) return;"); + sb.AppendLine(" if (k <= 0 || span.Length <= 1) return;"); sb.AppendLine(" HybridQuickSelect(span, k);"); sb.AppendLine($" var left = span.Slice(0, k);"); sb.AppendLine($" if (left.Length <= {BaseThreshold})"); sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(left), left.Length);"); sb.AppendLine(" else"); - sb.AppendLine(" HybridQuickSort(left);"); + sb.AppendLine(" {"); + sb.AppendLine(" int depthLimit = 2 * System.Numerics.BitOperations.Log2((uint)left.Length);"); + sb.AppendLine(" HybridQuickSort(left, depthLimit);"); + sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); sb.AppendLine($" /// Partially sorts an array so that the first elements are the smallest in sorted order."); @@ -129,10 +133,17 @@ private static void EmitHybridQuickSort(StringBuilder sb, string typeName, bool { string gt = GetGreaterThan(typeName, isString); - sb.AppendLine($" private static void HybridQuickSort(System.Span<{typeName}> span)"); + sb.AppendLine($" private static void HybridQuickSort(System.Span<{typeName}> span, int depthLimit)"); sb.AppendLine(" {"); sb.AppendLine($" while (span.Length > {BaseThreshold})"); sb.AppendLine(" {"); + sb.AppendLine(" if (depthLimit == 0)"); + sb.AppendLine(" {"); + sb.AppendLine(" System.MemoryExtensions.Sort(span);"); + sb.AppendLine(" return;"); + sb.AppendLine(" }"); + sb.AppendLine(" depthLimit--;"); + sb.AppendLine(); sb.AppendLine($" ref {typeName} first = ref System.Runtime.InteropServices.MemoryMarshal.GetReference(span);"); sb.AppendLine($" {typeName} pivot = HybridMedianOfThree("); sb.AppendLine(" first,"); @@ -151,7 +162,7 @@ private static void EmitHybridQuickSort(StringBuilder sb, string typeName, bool sb.AppendLine($" if (left.Length <= {BaseThreshold})"); sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(left), left.Length);"); sb.AppendLine(" else"); - sb.AppendLine(" HybridQuickSort(left);"); + sb.AppendLine(" HybridQuickSort(left, depthLimit);"); sb.AppendLine(" }"); sb.AppendLine(" span = right;"); sb.AppendLine(" }"); @@ -162,7 +173,7 @@ private static void EmitHybridQuickSort(StringBuilder sb, string typeName, bool sb.AppendLine($" if (right.Length <= {BaseThreshold})"); sb.AppendLine($" HybridSortSmall(ref System.Runtime.InteropServices.MemoryMarshal.GetReference(right), right.Length);"); sb.AppendLine(" else"); - sb.AppendLine(" HybridQuickSort(right);"); + sb.AppendLine(" HybridQuickSort(right, depthLimit);"); sb.AppendLine(" }"); sb.AppendLine(" span = left;"); sb.AppendLine(" }"); @@ -208,7 +219,7 @@ private static void EmitSortSmall(StringBuilder sb, string typeName, bool isStri sb.AppendLine(" {"); sb.AppendLine(" int offset = HybridGetNetworkOffset(length);"); sb.AppendLine(" int pairCount = HybridGetNetworkOffset(length + 1) - offset;"); - sb.AppendLine(" System.ReadOnlySpan data = HybridNetworkData;"); + sb.AppendLine(" byte[] data = HybridNetworkData;"); sb.AppendLine(" for (int i = 0; i < pairCount; i += 2)"); sb.AppendLine(" {"); sb.AppendLine($" ref {typeName} a = ref System.Runtime.CompilerServices.Unsafe.Add(ref first, data[offset + i]);"); @@ -315,29 +326,28 @@ internal static string EmitNetworkData() sb.AppendLine(" [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]"); sb.AppendLine(" private static int HybridGetNetworkOffset(int size)"); sb.AppendLine(" {"); - sb.AppendLine(" System.ReadOnlySpan offsets = HybridNetworkOffsets;"); - sb.AppendLine(" return offsets[size];"); + sb.AppendLine(" return HybridNetworkOffsets[size];"); sb.AppendLine(" }"); sb.AppendLine(); - // Emit the offset table - sb.Append(" private static System.ReadOnlySpan HybridNetworkOffsets => new int[] { "); + // Emit the offset table as a static readonly array (avoids per-access allocation) + sb.Append(" private static readonly int[] HybridNetworkOffsets = ["); for (int i = 0; i < offsets.Length; i++) { if (i > 0) sb.Append(", "); sb.Append(offsets[i]); } - sb.AppendLine(" };"); + sb.AppendLine("];"); sb.AppendLine(); - // Emit the network data as a byte array - sb.Append(" private static System.ReadOnlySpan HybridNetworkData => new byte[] { "); + // Emit the network data as a static readonly byte array (avoids per-access allocation) + sb.Append(" private static readonly byte[] HybridNetworkData = ["); for (int i = 0; i < allNetworkPairs.Count; i++) { if (i > 0) sb.Append(", "); sb.Append(allNetworkPairs[i]); } - sb.AppendLine(" };"); + sb.AppendLine("];"); sb.AppendLine(); return sb.ToString(); } diff --git a/SortingNetworks.Generators/SortingNetworkGenerator.cs b/SortingNetworks.Generators/SortingNetworkGenerator.cs index 38fcaa0..fef0770 100644 --- a/SortingNetworks.Generators/SortingNetworkGenerator.cs +++ b/SortingNetworks.Generators/SortingNetworkGenerator.cs @@ -981,6 +981,7 @@ private static void EmitComparerOverloads(StringBuilder sb, string typeName, Lis { var classSymbol = (INamedTypeSymbol)context.TargetSymbol; + var seen = new HashSet(); var types = new List(); foreach (var attr in context.Attributes) { @@ -996,6 +997,9 @@ private static void EmitComparerOverloads(StringBuilder sb, string typeName, Lis if (!SupportedSpecialTypes.Contains(typeSymbol.SpecialType)) continue; + if (!seen.Add(typeSymbol.SpecialType)) + continue; + types.Add(new HybridTypeRequest(typeName, typeSymbol.SpecialType)); } } @@ -1007,7 +1011,20 @@ private static void EmitComparerOverloads(StringBuilder sb, string typeName, Lis ? null : classSymbol.ContainingNamespace.ToDisplayString(); - return new HybridGenerationInfo(classSymbol.Name, namespaceName, types.ToArray()); + // Build fully-qualified class name for hint (handles nested classes) + var classChain = new List(); + var current = classSymbol; + while (current != null) + { + classChain.Insert(0, current.Name); + current = current.ContainingType; + } + var qualifiedClassName = string.Join(".", classChain); + var hintName = namespaceName != null + ? $"{namespaceName}.{qualifiedClassName}" + : qualifiedClassName; + + return new HybridGenerationInfo(classSymbol.Name, namespaceName, types.ToArray(), hintName); } private static void ExecuteHybrid(SourceProductionContext context, ImmutableArray infos) @@ -1018,7 +1035,7 @@ private static void ExecuteHybrid(SourceProductionContext context, ImmutableArra var source = GenerateHybridSource(info); if (source != null) { - context.AddSource($"{info.ClassName}.Hybrid.g.cs", SourceText.From(source, Encoding.UTF8)); + context.AddSource($"{info.HintName}.Hybrid.g.cs", SourceText.From(source, Encoding.UTF8)); } } } @@ -1080,12 +1097,14 @@ private sealed class HybridGenerationInfo public string ClassName { get; } public string? Namespace { get; } public HybridTypeRequest[] Types { get; } + public string HintName { get; } - public HybridGenerationInfo(string className, string? ns, HybridTypeRequest[] types) + public HybridGenerationInfo(string className, string? ns, HybridTypeRequest[] types, string hintName) { ClassName = className; Namespace = ns; Types = types; + HintName = hintName; } } diff --git a/SortingNetworks.Tests/HybridSortTests.cs b/SortingNetworks.Tests/HybridSortTests.cs index 87e1fac..62e9ffe 100644 --- a/SortingNetworks.Tests/HybridSortTests.cs +++ b/SortingNetworks.Tests/HybridSortTests.cs @@ -150,6 +150,61 @@ public void Sort_Double_200Elements() Assert.Equal(expected, data); } + [Fact] + public void Sort_SByte_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (sbyte)rng.Next(-128, 128)).ToArray(); + var expected = (sbyte[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_UShort_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (ushort)rng.Next(0, 65536)).ToArray(); + var expected = (ushort[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_UInt_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (uint)rng.Next(0, int.MaxValue)).ToArray(); + var expected = (uint[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_ULong_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (ulong)rng.Next(0, int.MaxValue)).ToArray(); + var expected = (ulong[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + + [Fact] + public void Sort_Char_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (char)rng.Next('A', 'z')).ToArray(); + var expected = (char[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + // --- Stress test --- [Fact] @@ -204,6 +259,18 @@ public void PartialSort_K_Zero() HybridSorter.PartialSort(data.AsSpan(), 0); // Should not throw } + [Fact] + public void PartialSort_K_One() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 100).Select(_ => rng.Next(-1000, 1000)).ToArray(); + var expected = (int[])data.Clone(); + Array.Sort(expected); + + HybridSorter.PartialSort(data.AsSpan(), 1); + Assert.Equal(expected[0], data[0]); + } + [Fact] public void PartialSort_InvalidK_Throws() { diff --git a/SortingNetworks.Tests/HybridSorter.cs b/SortingNetworks.Tests/HybridSorter.cs index 3120e5d..a2dcf48 100644 --- a/SortingNetworks.Tests/HybridSorter.cs +++ b/SortingNetworks.Tests/HybridSorter.cs @@ -4,8 +4,13 @@ namespace SortingNetworks.Tests; [HybridSortingNetwork(typeof(int))] [HybridSortingNetwork(typeof(byte))] +[HybridSortingNetwork(typeof(sbyte))] [HybridSortingNetwork(typeof(short))] +[HybridSortingNetwork(typeof(ushort))] [HybridSortingNetwork(typeof(long))] +[HybridSortingNetwork(typeof(ulong))] +[HybridSortingNetwork(typeof(uint))] [HybridSortingNetwork(typeof(float))] [HybridSortingNetwork(typeof(double))] +[HybridSortingNetwork(typeof(char))] partial class HybridSorter { } diff --git a/SortingNetworks/HybridSortingNetworkAttribute.cs b/SortingNetworks/HybridSortingNetworkAttribute.cs index f4ea891..d083d43 100644 --- a/SortingNetworks/HybridSortingNetworkAttribute.cs +++ b/SortingNetworks/HybridSortingNetworkAttribute.cs @@ -5,7 +5,7 @@ namespace SortingNetworks /// /// Marks a partial class for hybrid sorting network code generation. /// The source generator will emit Sort, PartialSort, - /// and NthElement methods that use AVX-512 SIMD partitioning + /// and NthElement methods that use quicksort partitioning /// for large arrays and optimal sorting networks for small sub-arrays. /// [AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = false)] diff --git a/SortingNetworks/PublicAPI.Unshipped.txt b/SortingNetworks/PublicAPI.Unshipped.txt index f44ed24..a23877e 100644 --- a/SortingNetworks/PublicAPI.Unshipped.txt +++ b/SortingNetworks/PublicAPI.Unshipped.txt @@ -3,3 +3,6 @@ SortingNetworks.SortingNetworkAttribute SortingNetworks.SortingNetworkAttribute.ElementType.get -> System.Type! SortingNetworks.SortingNetworkAttribute.Size.get -> int SortingNetworks.SortingNetworkAttribute.SortingNetworkAttribute(int size, System.Type! elementType) -> void +SortingNetworks.HybridSortingNetworkAttribute +SortingNetworks.HybridSortingNetworkAttribute.ElementType.get -> System.Type! +SortingNetworks.HybridSortingNetworkAttribute.HybridSortingNetworkAttribute(System.Type! elementType) -> void From 209fca97662e1c3b91c2aa2a264b104849b7f598 Mon Sep 17 00:00:00 2001 From: Jonathan Peppers Date: Sun, 3 May 2026 14:47:27 -0500 Subject: [PATCH 4/4] Add decimal support to hybrid sorting path - Add decimal element size (16 bytes) in HybridSortEmitter.GetElementSize - Add [HybridSortingNetwork(typeof(decimal))] to test fixture - Add Sort_Decimal_200Elements functional test - Add decimal to HybridSort_GeneratesCode generator test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- SortingNetworks.Generators/HybridSortEmitter.cs | 1 + SortingNetworks.Tests/GeneratorTests.cs | 1 + SortingNetworks.Tests/HybridSortTests.cs | 11 +++++++++++ SortingNetworks.Tests/HybridSorter.cs | 1 + 4 files changed, 14 insertions(+) diff --git a/SortingNetworks.Generators/HybridSortEmitter.cs b/SortingNetworks.Generators/HybridSortEmitter.cs index 87d0ec8..d047530 100644 --- a/SortingNetworks.Generators/HybridSortEmitter.cs +++ b/SortingNetworks.Generators/HybridSortEmitter.cs @@ -389,6 +389,7 @@ private static int GetElementSize(SpecialType specialType) SpecialType.System_Int64 => 8, SpecialType.System_UInt64 => 8, SpecialType.System_Double => 8, + SpecialType.System_Decimal => 16, _ => 0, }; } diff --git a/SortingNetworks.Tests/GeneratorTests.cs b/SortingNetworks.Tests/GeneratorTests.cs index 6d40772..a6e0094 100644 --- a/SortingNetworks.Tests/GeneratorTests.cs +++ b/SortingNetworks.Tests/GeneratorTests.cs @@ -919,6 +919,7 @@ public partial class MySorter {{ }} [InlineData("long")] [InlineData("float")] [InlineData("double")] + [InlineData("decimal")] public void HybridSort_GeneratesCode(string typeName) { var source = $@" diff --git a/SortingNetworks.Tests/HybridSortTests.cs b/SortingNetworks.Tests/HybridSortTests.cs index 62e9ffe..d7eb508 100644 --- a/SortingNetworks.Tests/HybridSortTests.cs +++ b/SortingNetworks.Tests/HybridSortTests.cs @@ -205,6 +205,17 @@ public void Sort_Char_200Elements() Assert.Equal(expected, data); } + [Fact] + public void Sort_Decimal_200Elements() + { + var rng = new Random(42); + var data = Enumerable.Range(0, 200).Select(_ => (decimal)(rng.NextDouble() * 2000 - 1000)).ToArray(); + var expected = (decimal[])data.Clone(); + Array.Sort(expected); + HybridSorter.Sort(data.AsSpan()); + Assert.Equal(expected, data); + } + // --- Stress test --- [Fact] diff --git a/SortingNetworks.Tests/HybridSorter.cs b/SortingNetworks.Tests/HybridSorter.cs index a2dcf48..3db7c19 100644 --- a/SortingNetworks.Tests/HybridSorter.cs +++ b/SortingNetworks.Tests/HybridSorter.cs @@ -13,4 +13,5 @@ namespace SortingNetworks.Tests; [HybridSortingNetwork(typeof(float))] [HybridSortingNetwork(typeof(double))] [HybridSortingNetwork(typeof(char))] +[HybridSortingNetwork(typeof(decimal))] partial class HybridSorter { }