diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 942a2a1695..a78931c8e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -142,6 +142,72 @@ jobs: shell: bash run: python ./ci/run_ci.py java --version windows_java21 + csharp: + name: C# CI + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up .NET 8 + uses: actions/setup-dotnet@v4 + with: + dotnet-version: "8.0.x" + cache: true + cache-dependency-path: | + csharp/**/*.csproj + csharp/Fory.sln + - name: Restore C# dependencies + run: | + cd csharp + dotnet restore Fory.sln + - name: Build C# solution + run: | + cd csharp + dotnet build Fory.sln -c Release --no-restore + - name: Run C# tests + run: | + cd csharp + dotnet test tests/Fory.Tests/Fory.Tests.csproj -c Release --no-build + + csharp_xlang: + name: C# Xlang Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up .NET 8 + uses: actions/setup-dotnet@v4 + with: + dotnet-version: "8.0.x" + cache: true + cache-dependency-path: | + csharp/**/*.csproj + csharp/Fory.sln + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + java-version: 21 + distribution: "temurin" + - name: Cache Maven local repository + uses: actions/cache@v4 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + - name: Restore and build C# xlang peer + run: | + cd csharp + dotnet restore tests/Fory.XlangPeer/Fory.XlangPeer.csproj + dotnet build tests/Fory.XlangPeer/Fory.XlangPeer.csproj -c Debug --no-restore + - name: Run C# Xlang Test + env: + FORY_CSHARP_JAVA_CI: "1" + ENABLE_FORY_DEBUG_OUTPUT: "1" + run: | + cd java + mvn -T16 --no-transfer-progress clean install -DskipTests + cd fory-core + mvn -T16 --no-transfer-progress test -Dtest=org.apache.fory.xlang.CSharpXlangTest + swift: name: Swift CI runs-on: macos-latest diff --git a/.gitignore b/.gitignore index db9be7edff..2c0f134e6b 100644 --- 
a/.gitignore +++ b/.gitignore @@ -106,4 +106,13 @@ benchmarks/**/report/ ignored/** ci-logs/** **/*.log -swift/.build \ No newline at end of file +swift/.build + +csharp/src/Fory.Generator/bin/ +csharp/src/Fory.Generator/obj/ +csharp/src/Fory/bin/ +csharp/src/Fory/obj/ +csharp/tests/Fory.Tests/bin/ +csharp/tests/Fory.Tests/obj/ +csharp/tests/Fory.XlangPeer/bin/ +csharp/tests/Fory.XlangPeer/obj/ \ No newline at end of file diff --git a/csharp/Fory.sln b/csharp/Fory.sln new file mode 100644 index 0000000000..b2f9edf7b1 --- /dev/null +++ b/csharp/Fory.sln @@ -0,0 +1,50 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.31903.59 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{C3FE93C6-2294-475C-88DD-BE1FFFADC4D6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Fory", "src\Fory\Fory.csproj", "{309709EA-773D-449D-AE12-D1C2B9518793}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Fory.Generator", "src\Fory.Generator\Fory.Generator.csproj", "{4DBAD732-2820-4B15-B655-19384F100997}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{02DC4D41-0522-42EA-B643-471F4BB0747E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Fory.Tests", "tests\Fory.Tests\Fory.Tests.csproj", "{65701195-E254-4D93-9CD0-F587FDFCE769}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Fory.XlangPeer", "tests\Fory.XlangPeer\Fory.XlangPeer.csproj", "{8E1D5E47-AF72-46FF-B60F-B1C6210654A4}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {309709EA-773D-449D-AE12-D1C2B9518793}.Debug|Any 
CPU.ActiveCfg = Debug|Any CPU + {309709EA-773D-449D-AE12-D1C2B9518793}.Debug|Any CPU.Build.0 = Debug|Any CPU + {309709EA-773D-449D-AE12-D1C2B9518793}.Release|Any CPU.ActiveCfg = Release|Any CPU + {309709EA-773D-449D-AE12-D1C2B9518793}.Release|Any CPU.Build.0 = Release|Any CPU + {4DBAD732-2820-4B15-B655-19384F100997}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {4DBAD732-2820-4B15-B655-19384F100997}.Debug|Any CPU.Build.0 = Debug|Any CPU + {4DBAD732-2820-4B15-B655-19384F100997}.Release|Any CPU.ActiveCfg = Release|Any CPU + {4DBAD732-2820-4B15-B655-19384F100997}.Release|Any CPU.Build.0 = Release|Any CPU + {65701195-E254-4D93-9CD0-F587FDFCE769}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {65701195-E254-4D93-9CD0-F587FDFCE769}.Debug|Any CPU.Build.0 = Debug|Any CPU + {65701195-E254-4D93-9CD0-F587FDFCE769}.Release|Any CPU.ActiveCfg = Release|Any CPU + {65701195-E254-4D93-9CD0-F587FDFCE769}.Release|Any CPU.Build.0 = Release|Any CPU + {8E1D5E47-AF72-46FF-B60F-B1C6210654A4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8E1D5E47-AF72-46FF-B60F-B1C6210654A4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8E1D5E47-AF72-46FF-B60F-B1C6210654A4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8E1D5E47-AF72-46FF-B60F-B1C6210654A4}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {309709EA-773D-449D-AE12-D1C2B9518793} = {C3FE93C6-2294-475C-88DD-BE1FFFADC4D6} + {4DBAD732-2820-4B15-B655-19384F100997} = {C3FE93C6-2294-475C-88DD-BE1FFFADC4D6} + {65701195-E254-4D93-9CD0-F587FDFCE769} = {02DC4D41-0522-42EA-B643-471F4BB0747E} + {8E1D5E47-AF72-46FF-B60F-B1C6210654A4} = {02DC4D41-0522-42EA-B643-471F4BB0747E} + EndGlobalSection +EndGlobal diff --git a/csharp/README.md b/csharp/README.md new file mode 100644 index 0000000000..a3300bee6c --- /dev/null +++ b/csharp/README.md @@ -0,0 +1 @@ +# Apache Fory™ C\# diff --git a/csharp/src/Fory.Generator/Fory.Generator.csproj b/csharp/src/Fory.Generator/Fory.Generator.csproj new file mode 100644 index 
0000000000..48864e20cf --- /dev/null +++ b/csharp/src/Fory.Generator/Fory.Generator.csproj @@ -0,0 +1,16 @@ + + + netstandard2.0 + 12.0 + enable + enable + true + true + false + + + + + + + diff --git a/csharp/src/Fory.Generator/ForyObjectGenerator.cs b/csharp/src/Fory.Generator/ForyObjectGenerator.cs new file mode 100644 index 0000000000..6cb40c931c --- /dev/null +++ b/csharp/src/Fory.Generator/ForyObjectGenerator.cs @@ -0,0 +1,1546 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +using System.Collections.Immutable; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + +namespace Apache.Fory.Generator; + +[Generator(LanguageNames.CSharp)] +public sealed class ForyObjectGenerator : IIncrementalGenerator +{ + private static readonly SymbolDisplayFormat FullNameFormat = + SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions( + SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier); + + private static readonly DiagnosticDescriptor GenericTypeNotSupported = new( + id: "FORY001", + title: "Generic types are not supported by ForyObject generator", + messageFormat: "Type '{0}' is generic and is not supported by [ForyObject].", + category: "Fory", + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); + + private static readonly DiagnosticDescriptor MissingCtor = new( + id: "FORY002", + title: "Missing parameterless constructor", + messageFormat: "Class '{0}' must declare an accessible parameterless constructor for [ForyObject].", + category: "Fory", + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); + + private static readonly DiagnosticDescriptor UnsupportedEncoding = new( + id: "FORY003", + title: "Unsupported Field encoding", + messageFormat: "Member '{0}' uses unsupported [Field] encoding for type '{1}'.", + category: "Fory", + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + IncrementalValuesProvider typeModels = context.SyntaxProvider + .ForAttributeWithMetadataName( + "Apache.Fory.ForyObjectAttribute", + static (node, _) => node is TypeDeclarationSyntax || node is EnumDeclarationSyntax, + static (syntaxContext, ct) => BuildTypeModel(syntaxContext, ct)) + .Where(static m => m is not null); + + context.RegisterSourceOutput( + typeModels.Collect(), + static (spc, models) => Emit(spc, 
models)); + } + + private static void Emit(SourceProductionContext context, ImmutableArray maybeModels) + { + if (maybeModels.IsDefaultOrEmpty) + { + return; + } + + Dictionary models = new(StringComparer.Ordinal); + foreach (TypeModel? maybeModel in maybeModels) + { + if (maybeModel is null) + { + continue; + } + + models[maybeModel.TypeName] = maybeModel; + } + + if (models.Count == 0) + { + return; + } + + StringBuilder sb = new(); + sb.AppendLine("// "); + sb.AppendLine("#nullable enable"); + sb.AppendLine("namespace Apache.Fory.Generated;"); + sb.AppendLine(); + + foreach (KeyValuePair entry in models.OrderBy(kv => kv.Key, StringComparer.Ordinal)) + { + TypeModel model = entry.Value; + if (model.Kind == DeclKind.Struct || model.Kind == DeclKind.Class) + { + EmitObjectSerializer(sb, model); + sb.AppendLine(); + } + } + + sb.AppendLine("internal static class __ForyGeneratedModuleInitializer"); + sb.AppendLine("{"); + sb.AppendLine(" [global::System.Runtime.CompilerServices.ModuleInitializer]"); + sb.AppendLine(" internal static void Register()"); + sb.AppendLine(" {"); + foreach (KeyValuePair entry in models.OrderBy(kv => kv.Key, StringComparer.Ordinal)) + { + TypeModel model = entry.Value; + if (model.Kind == DeclKind.Enum) + { + sb.AppendLine( + $" global::Apache.Fory.TypeResolver.RegisterGenerated<{model.TypeName}, global::Apache.Fory.EnumSerializer<{model.TypeName}>>();"); + } + else + { + sb.AppendLine( + $" global::Apache.Fory.TypeResolver.RegisterGenerated<{model.TypeName}, {model.SerializerName}>();"); + } + } + + sb.AppendLine(" }"); + sb.AppendLine("}"); + + context.AddSource("Fory.GeneratedSerializers.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8)); + } + + private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) + { + sb.AppendLine($"file sealed class {model.SerializerName} : global::Apache.Fory.Serializer<{model.TypeName}>"); + sb.AppendLine("{"); + sb.AppendLine(" private static global::Apache.Fory.RefMode 
__ForyRefMode(bool nullable, bool trackRef)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (trackRef)"); + sb.AppendLine(" {"); + sb.AppendLine(" return global::Apache.Fory.RefMode.Tracking;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" return nullable ? global::Apache.Fory.RefMode.NullOnly : global::Apache.Fory.RefMode.None;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" private static bool __ForyNeedsTypeInfoForField(global::Apache.Fory.TypeId typeId)"); + sb.AppendLine(" {"); + sb.AppendLine(" return typeId switch"); + sb.AppendLine(" {"); + sb.AppendLine(" global::Apache.Fory.TypeId.Struct or"); + sb.AppendLine(" global::Apache.Fory.TypeId.CompatibleStruct or"); + sb.AppendLine(" global::Apache.Fory.TypeId.NamedStruct or"); + sb.AppendLine(" global::Apache.Fory.TypeId.NamedCompatibleStruct or"); + sb.AppendLine(" global::Apache.Fory.TypeId.Ext or"); + sb.AppendLine(" global::Apache.Fory.TypeId.NamedExt or"); + sb.AppendLine(" global::Apache.Fory.TypeId.Unknown => true,"); + sb.AppendLine(" _ => false,"); + sb.AppendLine(" };"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" private static uint __ForySchemaHash(bool trackRef, global::Apache.Fory.TypeResolver typeResolver)"); + sb.AppendLine(" {"); + sb.Append(" return global::Apache.Fory.SchemaHash.StructHash32("); + sb.Append(BuildSchemaFingerprintExpression(model.Members)); + sb.AppendLine(");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" public override global::Apache.Fory.TypeId StaticTypeId => global::Apache.Fory.TypeId.Struct;"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" public override bool IsNullableType => true;"); + sb.AppendLine(" public override bool IsReferenceTrackableType => true;"); + sb.AppendLine($" public override {model.TypeName} DefaultValue => null!;"); + sb.AppendLine($" public override bool IsNone(in {model.TypeName} value) => value is null;"); + } + else + { + sb.AppendLine($" public override {model.TypeName} 
DefaultValue => new {model.TypeName}();"); + } + + sb.AppendLine(); + sb.AppendLine(" public override global::System.Collections.Generic.IReadOnlyList CompatibleTypeMetaFields(bool trackRef)"); + sb.AppendLine(" {"); + if (model.SortedMembers.Length == 0) + { + sb.AppendLine(" return global::System.Array.Empty();"); + } + else + { + sb.AppendLine(" return new global::Apache.Fory.TypeMetaFieldInfo[]"); + sb.AppendLine(" {"); + foreach (MemberModel member in model.SortedMembers) + { + sb.AppendLine( + $" new global::Apache.Fory.TypeMetaFieldInfo(null, \"{EscapeString(member.Name)}\", {BuildCompatibleTypeMetaExpression(member.TypeMeta, "trackRef")}),"); + } + + sb.AppendLine(" };"); + } + + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine( + $" public override void WriteData(ref global::Apache.Fory.WriteContext context, in {model.TypeName} value, bool hasGenerics)"); + sb.AppendLine(" {"); + sb.AppendLine(" _ = hasGenerics;"); + sb.AppendLine(" if (context.Compatible)"); + sb.AppendLine(" {"); + if (model.SortedMembers.Length == 0) + { + sb.AppendLine(" return;"); + } + else + { + foreach (MemberModel member in model.SortedMembers) + { + EmitWriteMember(sb, member, true); + } + + sb.AppendLine(" return;"); + } + + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" context.Writer.WriteInt32(unchecked((int)__ForySchemaHash(context.TrackRef, context.TypeResolver)));"); + foreach (MemberModel member in model.SortedMembers) + { + EmitWriteMember(sb, member, false); + } + + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" public override {model.TypeName} ReadData(ref global::Apache.Fory.ReadContext context)"); + sb.AppendLine(" {"); + sb.AppendLine(" if (context.Compatible)"); + sb.AppendLine(" {"); + sb.AppendLine($" global::Apache.Fory.TypeMeta typeMeta = context.ConsumeCompatibleTypeMeta(typeof({model.TypeName}));"); + sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" 
context.BindPendingReference(value);"); + } + + sb.AppendLine(" foreach (global::Apache.Fory.TypeMetaFieldInfo remoteField in typeMeta.Fields)"); + sb.AppendLine(" {"); + sb.AppendLine(" global::Apache.Fory.RefMode remoteRefMode = __ForyRefMode(remoteField.FieldType.Nullable, remoteField.FieldType.TrackRef);"); + sb.AppendLine(" bool remoteReadTypeInfo = __ForyNeedsTypeInfoForField((global::Apache.Fory.TypeId)remoteField.FieldType.TypeId);"); + sb.AppendLine(" switch (remoteField.FieldName)"); + sb.AppendLine(" {"); + foreach (MemberModel member in model.SortedMembers) + { + sb.AppendLine($" case \"{EscapeString(member.FieldIdentifier)}\":"); + sb.AppendLine(" {"); + EmitReadMemberAssignment(sb, member, "remoteRefMode", "remoteReadTypeInfo", "value", "Compat", 7, false); + sb.AppendLine(" break;"); + sb.AppendLine(" }"); + } + + sb.AppendLine(" default:"); + sb.AppendLine(" global::Apache.Fory.FieldSkipper.SkipFieldValue(ref context, remoteField.FieldType);"); + sb.AppendLine(" break;"); + sb.AppendLine(" }"); + sb.AppendLine(" }"); + sb.AppendLine(" return value;"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine(" uint schemaHash = unchecked((uint)context.Reader.ReadInt32());"); + sb.AppendLine(" uint expectedHash = __ForySchemaHash(context.TrackRef, context.TypeResolver);"); + sb.AppendLine(" if (schemaHash != expectedHash)"); + sb.AppendLine(" {"); + sb.AppendLine(" throw new global::Apache.Fory.InvalidDataException($\"class version hash mismatch: expected {expectedHash}, got {schemaHash}\");"); + sb.AppendLine(" }"); + sb.AppendLine(); + sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" context.BindPendingReference(valueSchema);"); + } + + foreach (MemberModel member in model.SortedMembers) + { + EmitReadMemberAssignment(sb, member, BuildWriteRefModeExpression(member), "false", "valueSchema", "Schema", 2, true); + } + + sb.AppendLine(" return valueSchema;"); + 
sb.AppendLine(" }"); + sb.AppendLine("}"); + } + + private static void EmitWriteMember(StringBuilder sb, MemberModel member, bool compatibleMode) + { + string refModeExpr = BuildWriteRefModeExpression(member); + string memberAccess = $"value.{member.Name}"; + string hasGenerics = member.IsCollection ? "true" : "false"; + string writeTypeInfo = compatibleMode + ? $"__ForyNeedsTypeInfoForField(context.TypeResolver.GetTypeInfo<{member.TypeName}>().StaticTypeId)" + : "false"; + + switch (member.DynamicAnyKind) + { + case DynamicAnyKind.AnyValue: + sb.AppendLine( + $" global::Apache.Fory.DynamicAnyCodec.WriteAny(ref context, {memberAccess}, {refModeExpr}, true, false);"); + return; + case DynamicAnyKind.None: + break; + default: + throw new InvalidOperationException($"unsupported dynamic any kind {member.DynamicAnyKind}"); + } + + if (!member.IsNullable && TryBuildDirectFieldWrite(member, memberAccess, out string? directWriteCode)) + { + sb.AppendLine($" {directWriteCode}"); + return; + } + + if (TryBuildNullableFixedTaggedFieldWrite(member, memberAccess, out string? 
nullableWriteCode)) + { + sb.AppendLine($" {nullableWriteCode}"); + return; + } + + sb.AppendLine( + $" context.TypeResolver.GetSerializer<{member.TypeName}>().Write(ref context, {memberAccess}, {refModeExpr}, {writeTypeInfo}, {hasGenerics});"); + } + + private static void EmitReadMemberAssignment( + StringBuilder sb, + MemberModel member, + string refModeExpr, + string readTypeInfoExpr, + string valueVar, + string variableSuffix, + int indentLevel, + bool allowDirectRead) + { + string indent = new(' ', indentLevel * 2); + string assignmentTarget = $"{valueVar}.{member.Name}"; + string typeOfTypeName = StripNullableForTypeOf(member.TypeName); + switch (member.DynamicAnyKind) + { + case DynamicAnyKind.AnyValue: + sb.AppendLine( + $"{indent}{assignmentTarget} = ({member.TypeName})global::Apache.Fory.DynamicAnyCodec.CastAnyDynamicValue(global::Apache.Fory.DynamicAnyCodec.ReadAny(ref context, {refModeExpr}, true), typeof({typeOfTypeName}))!;"); + return; + case DynamicAnyKind.None: + break; + default: + throw new InvalidOperationException($"unsupported dynamic any kind {member.DynamicAnyKind}"); + } + + if (allowDirectRead && !member.IsNullable && TryBuildDirectFieldRead(member, out string? directReadExpr)) + { + sb.AppendLine($"{indent}{assignmentTarget} = {directReadExpr};"); + return; + } + + if (allowDirectRead && TryBuildNullableFixedTaggedFieldRead(member, assignmentTarget, variableSuffix, indent, out string? nullableReadCode)) + { + sb.AppendLine(nullableReadCode); + return; + } + + sb.AppendLine( + $"{indent}{assignmentTarget} = context.TypeResolver.GetSerializer<{member.TypeName}>().Read(ref context, {refModeExpr}, {readTypeInfoExpr});"); + } + + private static string StripNullableForTypeOf(string typeName) + { + return typeName.Replace("?", string.Empty); + } + + private static bool TryBuildDirectFieldWrite(MemberModel member, string memberAccess, out string? 
writeCode) + { + writeCode = null; + if (!CanUseDirectBuiltInFieldAccess(member)) + { + return false; + } + + return TryBuildDirectPayloadWrite(member.Classification.TypeId, memberAccess, out writeCode); + } + + private static bool TryBuildDirectFieldRead(MemberModel member, out string? readExpr) + { + readExpr = null; + if (!CanUseDirectBuiltInFieldAccess(member)) + { + return false; + } + + return TryBuildDirectPayloadRead(member.Classification.TypeId, out readExpr); + } + + private static bool TryBuildNullableFixedTaggedFieldWrite(MemberModel member, string memberAccess, out string? writeCode) + { + writeCode = null; + if (!member.IsNullableValueType || !IsFixedTaggedTypeId(member.Classification.TypeId)) + { + return false; + } + + if (!TryBuildDirectPayloadWrite(member.Classification.TypeId, $"{memberAccess}.Value", out string? payloadWriteCode)) + { + return false; + } + + writeCode = $"if (!{memberAccess}.HasValue) {{ context.Writer.WriteInt8((sbyte)global::Apache.Fory.RefFlag.Null); }} else {{ context.Writer.WriteInt8((sbyte)global::Apache.Fory.RefFlag.NotNullValue); {payloadWriteCode} }}"; + return true; + } + + private static bool TryBuildNullableFixedTaggedFieldRead( + MemberModel member, + string assignmentTarget, + string variableSuffix, + string indent, + out string code) + { + code = string.Empty; + if (!member.IsNullableValueType || !IsFixedTaggedTypeId(member.Classification.TypeId)) + { + return false; + } + + if (!TryBuildDirectPayloadRead(member.Classification.TypeId, out string? 
payloadReadExpr)) + { + return false; + } + + string refFlagVar = $"__{member.Name}RefFlag{variableSuffix}"; + string nestedIndent = indent + " "; + StringBuilder sb = new(); + sb.AppendLine($"{indent}sbyte {refFlagVar} = context.Reader.ReadInt8();"); + sb.AppendLine($"{indent}if ({refFlagVar} == (sbyte)global::Apache.Fory.RefFlag.Null)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{nestedIndent}{assignmentTarget} = ({member.TypeName})null!;"); + sb.AppendLine($"{indent}}}"); + sb.AppendLine($"{indent}else"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{nestedIndent}{assignmentTarget} = {payloadReadExpr};"); + sb.Append($"{indent}}}"); + code = sb.ToString(); + return true; + } + + private static bool IsFixedTaggedTypeId(uint typeId) + { + return typeId is 4 or 6 or 8 or 11 or 13 or 15; + } + + private static bool TryBuildDirectPayloadWrite(uint typeId, string valueExpr, out string? writeCode) + { + writeCode = null; + switch (typeId) + { + case 1: + writeCode = $"context.Writer.WriteUInt8({valueExpr} ? 
(byte)1 : (byte)0);"; + return true; + case 2: + writeCode = $"context.Writer.WriteInt8({valueExpr});"; + return true; + case 3: + writeCode = $"context.Writer.WriteInt16({valueExpr});"; + return true; + case 4: + writeCode = $"context.Writer.WriteInt32({valueExpr});"; + return true; + case 5: + writeCode = $"context.Writer.WriteVarInt32({valueExpr});"; + return true; + case 6: + writeCode = $"context.Writer.WriteInt64({valueExpr});"; + return true; + case 7: + writeCode = $"context.Writer.WriteVarInt64({valueExpr});"; + return true; + case 8: + writeCode = $"context.Writer.WriteTaggedInt64({valueExpr});"; + return true; + case 9: + writeCode = $"context.Writer.WriteUInt8({valueExpr});"; + return true; + case 10: + writeCode = $"context.Writer.WriteUInt16({valueExpr});"; + return true; + case 11: + writeCode = $"context.Writer.WriteUInt32({valueExpr});"; + return true; + case 12: + writeCode = $"context.Writer.WriteVarUInt32({valueExpr});"; + return true; + case 13: + writeCode = $"context.Writer.WriteUInt64({valueExpr});"; + return true; + case 14: + writeCode = $"context.Writer.WriteVarUInt64({valueExpr});"; + return true; + case 15: + writeCode = $"context.Writer.WriteTaggedUInt64({valueExpr});"; + return true; + case 19: + writeCode = $"context.Writer.WriteFloat32({valueExpr});"; + return true; + case 20: + writeCode = $"context.Writer.WriteFloat64({valueExpr});"; + return true; + case 21: + writeCode = $"global::Apache.Fory.StringSerializer.WriteString(ref context, {valueExpr});"; + return true; + default: + return false; + } + } + + private static bool TryBuildDirectPayloadRead(uint typeId, out string? 
readExpr) + { + readExpr = null; + switch (typeId) + { + case 1: + readExpr = "context.Reader.ReadUInt8() != 0"; + return true; + case 2: + readExpr = "context.Reader.ReadInt8()"; + return true; + case 3: + readExpr = "context.Reader.ReadInt16()"; + return true; + case 4: + readExpr = "context.Reader.ReadInt32()"; + return true; + case 5: + readExpr = "context.Reader.ReadVarInt32()"; + return true; + case 6: + readExpr = "context.Reader.ReadInt64()"; + return true; + case 7: + readExpr = "context.Reader.ReadVarInt64()"; + return true; + case 8: + readExpr = "context.Reader.ReadTaggedInt64()"; + return true; + case 9: + readExpr = "context.Reader.ReadUInt8()"; + return true; + case 10: + readExpr = "context.Reader.ReadUInt16()"; + return true; + case 11: + readExpr = "context.Reader.ReadUInt32()"; + return true; + case 12: + readExpr = "context.Reader.ReadVarUInt32()"; + return true; + case 13: + readExpr = "context.Reader.ReadUInt64()"; + return true; + case 14: + readExpr = "context.Reader.ReadVarUInt64()"; + return true; + case 15: + readExpr = "context.Reader.ReadTaggedUInt64()"; + return true; + case 19: + readExpr = "context.Reader.ReadFloat32()"; + return true; + case 20: + readExpr = "context.Reader.ReadFloat64()"; + return true; + case 21: + readExpr = "global::Apache.Fory.StringSerializer.ReadString(ref context)"; + return true; + default: + return false; + } + } + + private static bool CanUseDirectBuiltInFieldAccess(MemberModel member) + { + if (member.IsNullable || + member.DynamicAnyKind != DynamicAnyKind.None || + member.IsCollection || + member.Classification.IsMap) + { + return false; + } + + return member.Classification.IsPrimitive || member.Classification.TypeId == 21; + } + + private static string BuildSchemaFingerprintExpression(ImmutableArray members) + { + if (members.IsDefaultOrEmpty) + { + return "\"\""; + } + + IEnumerable ordered = members + .OrderBy(m => m.FieldIdentifier, StringComparer.Ordinal) + .ThenBy(m => m.OriginalIndex); + + 
StringBuilder sb = new(); + bool first = true; + foreach (MemberModel member in ordered) + { + uint fingerprintTypeId = (member.Classification.IsPrimitive || member.Classification.IsBuiltIn) + ? member.Classification.TypeId + : 0; + string trackRefExpr = member.DynamicAnyKind switch + { + DynamicAnyKind.AnyValue => "(trackRef ? 1 : 0)", + _ => member.Classification.IsBuiltIn + ? "0" + : $"((trackRef && typeResolver.GetTypeInfo<{member.TypeName}>().IsReferenceTrackableType) ? 1 : 0)", + }; + string nullable = member.IsNullable ? "1" : "0"; + string piece = $"\"{EscapeString(member.FieldIdentifier)},{fingerprintTypeId},\" + {trackRefExpr} + \",{nullable};\""; + if (!first) + { + sb.Append(" + "); + } + + first = false; + sb.Append(piece); + } + + return sb.ToString(); + } + + private static string BuildCompatibleTypeMetaExpression(TypeMetaFieldTypeModel model, string trackRefExpr) + { + string localTrackRefExpr = model.TrackRefByContext ? trackRefExpr : "false"; + if (model.Generics.Length > 0) + { + string generics = string.Join( + ", ", + model.Generics.Select(g => BuildCompatibleTypeMetaExpression(g, "false"))); + return + $"new global::Apache.Fory.TypeMetaFieldType({model.TypeIdExpr}, {BoolLiteral(model.Nullable)}, {localTrackRefExpr}, new global::Apache.Fory.TypeMetaFieldType[] {{ {generics} }})"; + } + + return $"new global::Apache.Fory.TypeMetaFieldType({model.TypeIdExpr}, {BoolLiteral(model.Nullable)}, {localTrackRefExpr})"; + } + + private static string BuildWriteRefModeExpression(MemberModel member) + { + return member.DynamicAnyKind switch + { + DynamicAnyKind.AnyValue => $"__ForyRefMode({BoolLiteral(member.IsNullable)}, context.TrackRef)", + _ => member.Classification.IsBuiltIn + ? $"__ForyRefMode({BoolLiteral(member.IsNullable)}, false)" + : $"__ForyRefMode({BoolLiteral(member.IsNullable)}, context.TrackRef && context.TypeResolver.GetTypeInfo<{member.TypeName}>().IsReferenceTrackableType)", + }; + } + + private static TypeModel? 
BuildTypeModel(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) + { + _ = cancellationToken; + if (context.TargetSymbol is not INamedTypeSymbol typeSymbol) + { + return null; + } + + if (typeSymbol.TypeParameters.Length > 0) + { + return null; + } + + string typeName = typeSymbol.ToDisplayString(FullNameFormat); + string serializerName = "__ForySerializer_" + Sanitize(typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + if (typeSymbol.TypeKind == TypeKind.Enum) + { + return new TypeModel( + typeName, + serializerName, + DeclKind.Enum, + ImmutableArray.Empty, + ImmutableArray.Empty); + } + + DeclKind kind = typeSymbol.TypeKind switch + { + TypeKind.Struct => DeclKind.Struct, + TypeKind.Class => DeclKind.Class, + _ => DeclKind.Unknown, + }; + + if (kind == DeclKind.Unknown) + { + return null; + } + + if (kind == DeclKind.Class && !HasAccessibleParameterlessCtor(typeSymbol)) + { + return null; + } + + List members = []; + foreach (ISymbol member in typeSymbol.GetMembers()) + { + if (member.IsStatic) + { + continue; + } + + if (member is IFieldSymbol field) + { + if (field.IsConst || field.IsReadOnly || !IsReadableWritableAccessibility(field.DeclaredAccessibility)) + { + continue; + } + + MemberModel? parsedField = BuildMemberModel(field.Name, field.Type, field, MemberDeclKind.Field); + if (parsedField is not null) + { + members.Add(parsedField); + } + + continue; + } + + if (member is IPropertySymbol property) + { + if (property.IsIndexer || property.GetMethod is null || property.SetMethod is null) + { + continue; + } + + if (property.SetMethod.IsInitOnly) + { + continue; + } + + if (!IsReadableWritableAccessibility(property.GetMethod.DeclaredAccessibility) || + !IsReadableWritableAccessibility(property.SetMethod.DeclaredAccessibility)) + { + continue; + } + + MemberModel? 
parsedProperty = BuildMemberModel( + property.Name, + property.Type, + property, + MemberDeclKind.Property); + if (parsedProperty is not null) + { + members.Add(parsedProperty); + } + } + } + + ImmutableArray ordered = members + .OrderBy(m => m.OriginalIndex) + .ToImmutableArray(); + ImmutableArray sorted = SortMembers(ordered); + + return new TypeModel(typeName, serializerName, kind, ordered, sorted); + } + + private static MemberModel? BuildMemberModel( + string name, + ITypeSymbol memberType, + ISymbol memberSymbol, + MemberDeclKind memberDeclKind) + { + (bool isOptional, ITypeSymbol unwrappedType) = UnwrapNullable(memberType); + FieldEncoding fieldEncoding = FieldEncoding.None; + foreach (AttributeData attribute in memberSymbol.GetAttributes()) + { + string? attrName = attribute.AttributeClass?.ToDisplayString(); + if (!string.Equals(attrName, "Apache.Fory.FieldAttribute", StringComparison.Ordinal)) + { + continue; + } + + foreach (KeyValuePair namedArg in attribute.NamedArguments) + { + if (!string.Equals(namedArg.Key, "Encoding", StringComparison.Ordinal)) + { + continue; + } + + if (namedArg.Value.Value is int encoding) + { + fieldEncoding = (FieldEncoding)encoding; + } + } + } + + DynamicAnyKind dynamicAnyKind = ResolveDynamicAnyKind(unwrappedType); + TypeResolution resolution = ResolveTypeResolution(unwrappedType, fieldEncoding); + if (!resolution.Supported) + { + return null; + } + + TypeClassification classification = resolution.Classification; + int group = classification.IsPrimitive + ? (isOptional ? 2 : 1) + : classification.IsMap + ? 5 + : classification.IsCollection + ? 4 + : classification.IsBuiltIn + ? 3 + : 6; + + int index = int.MaxValue; + Location? 
sourceLocation = memberSymbol.Locations.FirstOrDefault(loc => loc.IsInSource); + if (sourceLocation is not null) + { + index = sourceLocation.SourceSpan.Start; + } + + string typeName = memberType.ToDisplayString(FullNameFormat); + TypeMetaFieldTypeModel typeMeta = BuildTypeMetaFieldTypeModel( + memberType, + isOptional, + dynamicAnyKind, + resolution.Classification.TypeId); + + return new MemberModel( + name, + ToSnakeCase(name), + index, + memberDeclKind, + typeName, + isOptional, + memberType is INamedTypeSymbol nts && + nts.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T, + classification, + group, + classification.IsCollection || classification.IsMap, + dynamicAnyKind == DynamicAnyKind.None ? DynamicAnyKind.None : dynamicAnyKind, + typeMeta); + } + + private static TypeMetaFieldTypeModel BuildTypeMetaFieldTypeModel( + ITypeSymbol memberType, + bool nullable, + DynamicAnyKind dynamicAnyKind, + uint explicitTypeId) + { + (bool _, ITypeSymbol unwrapped) = UnwrapNullable(memberType); + + if (TryGetListElementType(unwrapped, out ITypeSymbol? listElementType)) + { + bool elementNullable = GenericNullable(listElementType!); + TypeMetaFieldTypeModel element = BuildTypeMetaFieldTypeModel( + listElementType!, + elementNullable, + ResolveDynamicAnyKind(UnwrapNullable(listElementType!).Item2), + 0); + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.List", + nullable, + false, + ImmutableArray.Create(element)); + } + + if (TryGetSetElementType(unwrapped, out ITypeSymbol? setElementType)) + { + bool elementNullable = GenericNullable(setElementType!); + TypeMetaFieldTypeModel element = BuildTypeMetaFieldTypeModel( + setElementType!, + elementNullable, + ResolveDynamicAnyKind(UnwrapNullable(setElementType!).Item2), + 0); + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.Set", + nullable, + false, + ImmutableArray.Create(element)); + } + + if (TryGetMapTypeArguments(unwrapped, out ITypeSymbol? 
keyType, out ITypeSymbol? valueType)) + { + bool keyNullable = GenericNullable(keyType!); + bool valueNullable = GenericNullable(valueType!); + TypeMetaFieldTypeModel key = BuildTypeMetaFieldTypeModel( + keyType!, + keyNullable, + ResolveDynamicAnyKind(UnwrapNullable(keyType!).Item2), + 0); + TypeMetaFieldTypeModel value = BuildTypeMetaFieldTypeModel( + valueType!, + valueNullable, + ResolveDynamicAnyKind(UnwrapNullable(valueType!).Item2), + 0); + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.Map", + nullable, + false, + ImmutableArray.Create(key, value)); + } + + TypeClassification classification = ClassifyType(unwrapped); + if (explicitTypeId != 0 && classification.IsPrimitive && classification.TypeId != explicitTypeId) + { + return new TypeMetaFieldTypeModel( + explicitTypeId.ToString(), + nullable, + false, + ImmutableArray.Empty); + } + + if (IsUnionType(unwrapped)) + { + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.Union", + nullable, + true, + ImmutableArray.Empty); + } + + if (dynamicAnyKind == DynamicAnyKind.AnyValue) + { + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.Unknown", + nullable, + true, + ImmutableArray.Empty); + } + + if (unwrapped.TypeKind == TypeKind.Enum) + { + return new TypeMetaFieldTypeModel( + "(uint)global::Apache.Fory.TypeId.Enum", + nullable, + false, + ImmutableArray.Empty); + } + + return new TypeMetaFieldTypeModel( + $"(uint){classification.TypeId}", + nullable, + !classification.IsBuiltIn && unwrapped.TypeKind != TypeKind.Enum, + ImmutableArray.Empty); + } + + private static ImmutableArray SortMembers(ImmutableArray members) + { + return members + .OrderBy(m => m.Group) + .ThenBy(m => + { + if (m.Group is 1 or 2) + { + return m.Classification.IsCompressedNumeric ? 1 : 0; + } + + return 0; + }) + .ThenByDescending(m => m.Group is 1 or 2 ? 
m.Classification.PrimitiveSize : 0) + .ThenBy(m => + { + if (m.Group is 1 or 2) + { + return (int)(uint.MaxValue - m.Classification.TypeId); + } + + if (m.Group is 3 or 4 or 5) + { + return (int)m.Classification.TypeId; + } + + return 0; + }) + .ThenBy(m => m.FieldIdentifier, StringComparer.Ordinal) + .ThenBy(m => m.Name, StringComparer.Ordinal) + .ThenBy(m => m.OriginalIndex) + .ToImmutableArray(); + } + + private static bool GenericNullable(ITypeSymbol type) + { + (bool optional, ITypeSymbol unwrapped) = UnwrapNullable(type); + if (optional) + { + return true; + } + + TypeClassification c = ClassifyType(unwrapped); + return !c.IsPrimitive; + } + + private static bool IsReadableWritableAccessibility(Accessibility accessibility) + { + return accessibility is Accessibility.Public or Accessibility.Internal or Accessibility.ProtectedOrInternal; + } + + private static bool HasAccessibleParameterlessCtor(INamedTypeSymbol type) + { + foreach (IMethodSymbol ctor in type.InstanceConstructors) + { + if (ctor.Parameters.Length != 0) + { + continue; + } + + if (ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal or Accessibility.ProtectedOrInternal) + { + return true; + } + } + + return false; + } + + private static TypeResolution ResolveTypeResolution(ITypeSymbol type, FieldEncoding encoding) + { + TypeClassification baseType = ClassifyType(type); + if (encoding == FieldEncoding.None) + { + return new TypeResolution(true, baseType); + } + + bool isInt32 = type.SpecialType == SpecialType.System_Int32; + bool isUInt32 = type.SpecialType == SpecialType.System_UInt32; + bool isInt64 = type.SpecialType == SpecialType.System_Int64; + bool isUInt64 = type.SpecialType == SpecialType.System_UInt64; + + if (isInt32) + { + return encoding switch + { + FieldEncoding.Varint => new TypeResolution(true, baseType), + FieldEncoding.Fixed => new TypeResolution( + true, + new TypeClassification(4, true, true, false, false, false, 4)), + _ => new TypeResolution(false, 
baseType), + }; + } + + if (isUInt32) + { + return encoding switch + { + FieldEncoding.Varint => new TypeResolution(true, baseType), + FieldEncoding.Fixed => new TypeResolution( + true, + new TypeClassification(11, true, true, false, false, false, 4)), + _ => new TypeResolution(false, baseType), + }; + } + + if (isInt64) + { + return encoding switch + { + FieldEncoding.Varint => new TypeResolution(true, baseType), + FieldEncoding.Fixed => new TypeResolution( + true, + new TypeClassification(6, true, true, false, false, false, 8)), + FieldEncoding.Tagged => new TypeResolution( + true, + new TypeClassification(8, true, true, false, false, true, 8)), + _ => new TypeResolution(false, baseType), + }; + } + + if (isUInt64) + { + return encoding switch + { + FieldEncoding.Varint => new TypeResolution(true, baseType), + FieldEncoding.Fixed => new TypeResolution( + true, + new TypeClassification(13, true, true, false, false, false, 8)), + FieldEncoding.Tagged => new TypeResolution( + true, + new TypeClassification(15, true, true, false, false, true, 8)), + _ => new TypeResolution(false, baseType), + }; + } + + return new TypeResolution(false, baseType); + } + + private static TypeClassification ClassifyType(ITypeSymbol type) + { + if (ResolveDynamicAnyKind(type) == DynamicAnyKind.AnyValue) + { + return new TypeClassification(0, false, true, false, false, false, 0); + } + + if (type.SpecialType == SpecialType.System_Boolean) + { + return new TypeClassification(1, true, true, false, false, false, 1); + } + + if (type.SpecialType == SpecialType.System_SByte) + { + return new TypeClassification(2, true, true, false, false, false, 1); + } + + if (type.SpecialType == SpecialType.System_Int16) + { + return new TypeClassification(3, true, true, false, false, false, 2); + } + + if (type.SpecialType == SpecialType.System_Int32) + { + return new TypeClassification(5, true, true, false, false, true, 4); + } + + if (type.SpecialType == SpecialType.System_Int64) + { + return new 
TypeClassification(7, true, true, false, false, true, 8); + } + + if (type.SpecialType == SpecialType.System_Byte) + { + return new TypeClassification(9, true, true, false, false, false, 1); + } + + if (type.SpecialType == SpecialType.System_UInt16) + { + return new TypeClassification(10, true, true, false, false, false, 2); + } + + if (type.SpecialType == SpecialType.System_UInt32) + { + return new TypeClassification(12, true, true, false, false, true, 4); + } + + if (type.SpecialType == SpecialType.System_UInt64) + { + return new TypeClassification(14, true, true, false, false, true, 8); + } + + if (type.SpecialType == SpecialType.System_Single) + { + return new TypeClassification(19, true, true, false, false, false, 4); + } + + if (type.SpecialType == SpecialType.System_Double) + { + return new TypeClassification(20, true, true, false, false, false, 8); + } + + if (type.SpecialType == SpecialType.System_String) + { + return new TypeClassification(21, false, true, false, false, false, 0); + } + + if (IsDateType(type)) + { + return new TypeClassification(39, false, true, false, false, false, 0); + } + + if (IsTimestampType(type)) + { + return new TypeClassification(38, false, true, false, false, false, 0); + } + + if (IsDurationType(type)) + { + return new TypeClassification(37, false, true, false, false, false, 0); + } + + if (type is IArrayTypeSymbol arrayType) + { + TypeClassification elem = ClassifyType(arrayType.ElementType); + uint typeId = elem.TypeId switch + { + 9 => 41, + 1 => 43, + 2 => 44, + 3 => 45, + 5 => 46, + 7 => 47, + 10 => 49, + 12 => 50, + 14 => 51, + 19 => 55, + 20 => 56, + _ => 22, + }; + return new TypeClassification(typeId, false, true, true, false, false, 0); + } + + if (TryGetListElementType(type, out _)) + { + return new TypeClassification(22, false, true, true, false, false, 0); + } + + if (TryGetSetElementType(type, out _)) + { + return new TypeClassification(23, false, true, true, false, false, 0); + } + + if 
(TryGetMapTypeArguments(type, out _, out _)) + { + return new TypeClassification(24, false, true, false, true, false, 0); + } + + if (IsUnionType(type)) + { + return new TypeClassification(33, false, false, false, false, false, 0); + } + + return new TypeClassification(27, false, false, false, false, false, 0); + } + + private static DynamicAnyKind ResolveDynamicAnyKind(ITypeSymbol type) + { + if (type.SpecialType == SpecialType.System_Object) + { + return DynamicAnyKind.AnyValue; + } + + return DynamicAnyKind.None; + } + + private static bool IsDateType(ITypeSymbol symbol) + { + return string.Equals(symbol.ToDisplayString(), "System.DateOnly", StringComparison.Ordinal); + } + + private static bool IsTimestampType(ITypeSymbol symbol) + { + string name = symbol.ToDisplayString(); + return string.Equals(name, "System.DateTime", StringComparison.Ordinal) || + string.Equals(name, "System.DateTimeOffset", StringComparison.Ordinal); + } + + private static bool IsDurationType(ITypeSymbol symbol) + { + return string.Equals(symbol.ToDisplayString(), "System.TimeSpan", StringComparison.Ordinal); + } + + private static bool IsUnionType(ITypeSymbol symbol) + { + INamedTypeSymbol? current = symbol as INamedTypeSymbol; + while (current is not null) + { + if (string.Equals(current.ToDisplayString(), "Apache.Fory.Union", StringComparison.Ordinal)) + { + return true; + } + + current = current.BaseType; + } + + return false; + } + + private static bool TryGetListElementType(ITypeSymbol type, out ITypeSymbol? 
elementType) + { + elementType = null; + if (type is IArrayTypeSymbol arrayType) + { + elementType = arrayType.ElementType; + return true; + } + + if (type is not INamedTypeSymbol named) + { + return false; + } + + string genericName = named.ConstructedFrom.ToDisplayString(); + if (genericName is + "System.Collections.Generic.List" or + "System.Collections.Generic.IList" or + "System.Collections.Generic.IReadOnlyList") + { + elementType = named.TypeArguments[0]; + return true; + } + + return false; + } + + private static bool TryGetSetElementType(ITypeSymbol type, out ITypeSymbol? elementType) + { + elementType = null; + if (type is not INamedTypeSymbol named) + { + return false; + } + + string genericName = named.ConstructedFrom.ToDisplayString(); + if (genericName is + "System.Collections.Generic.HashSet" or + "System.Collections.Generic.ISet" or + "System.Collections.Generic.IReadOnlySet") + { + elementType = named.TypeArguments[0]; + return true; + } + + return false; + } + + private static bool TryGetMapTypeArguments(ITypeSymbol type, out ITypeSymbol? keyType, out ITypeSymbol? 
valueType) + { + keyType = null; + valueType = null; + if (type is not INamedTypeSymbol named) + { + return false; + } + + string genericName = named.ConstructedFrom.ToDisplayString(); + if (genericName is + "System.Collections.Generic.Dictionary" or + "System.Collections.Generic.IDictionary" or + "System.Collections.Generic.IReadOnlyDictionary" or + "Apache.Fory.NullableKeyDictionary") + { + keyType = named.TypeArguments[0]; + valueType = named.TypeArguments[1]; + return true; + } + + return false; + } + + private static (bool, ITypeSymbol) UnwrapNullable(ITypeSymbol type) + { + if (type is INamedTypeSymbol named && + named.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + return (true, named.TypeArguments[0]); + } + + if (type.IsReferenceType && type.NullableAnnotation == NullableAnnotation.Annotated) + { + return (true, type.WithNullableAnnotation(NullableAnnotation.NotAnnotated)); + } + + return (false, type); + } + + private static string BoolLiteral(bool value) => value ? "true" : "false"; + + private static string EscapeString(string value) => value.Replace("\\", "\\\\").Replace("\"", "\\\""); + + private static string ToSnakeCase(string name) + { + if (string.IsNullOrEmpty(name)) + { + return name; + } + + StringBuilder sb = new(name.Length + 4); + for (int i = 0; i < name.Length; i++) + { + char c = name[i]; + if (char.IsUpper(c)) + { + if (i > 0) + { + bool prevUpper = char.IsUpper(name[i - 1]); + bool nextUpperOrEnd = i + 1 >= name.Length || char.IsUpper(name[i + 1]); + if (!prevUpper || !nextUpperOrEnd) + { + sb.Append('_'); + } + } + + sb.Append(char.ToLowerInvariant(c)); + } + else + { + sb.Append(c); + } + } + + return sb.ToString(); + } + + private static string Sanitize(string name) + { + StringBuilder sb = new(name.Length + 8); + foreach (char c in name) + { + sb.Append(char.IsLetterOrDigit(c) ? 
c : '_'); + } + + return sb.ToString(); + } + + private sealed class TypeResolution + { + public TypeResolution(bool supported, TypeClassification classification) + { + Supported = supported; + Classification = classification; + } + + public bool Supported { get; } + public TypeClassification Classification { get; } + } + + private sealed class TypeClassification + { + public TypeClassification( + uint typeId, + bool isPrimitive, + bool isBuiltIn, + bool isCollection, + bool isMap, + bool isCompressedNumeric, + int primitiveSize) + { + TypeId = typeId; + IsPrimitive = isPrimitive; + IsBuiltIn = isBuiltIn; + IsCollection = isCollection; + IsMap = isMap; + IsCompressedNumeric = isCompressedNumeric; + PrimitiveSize = primitiveSize; + } + + public uint TypeId { get; } + public bool IsPrimitive { get; } + public bool IsBuiltIn { get; } + public bool IsCollection { get; } + public bool IsMap { get; } + public bool IsCompressedNumeric { get; } + public int PrimitiveSize { get; } + } + + private sealed class TypeMetaFieldTypeModel + { + public TypeMetaFieldTypeModel( + string typeIdExpr, + bool nullable, + bool trackRefByContext, + ImmutableArray generics) + { + TypeIdExpr = typeIdExpr; + Nullable = nullable; + TrackRefByContext = trackRefByContext; + Generics = generics; + } + + public string TypeIdExpr { get; } + public bool Nullable { get; } + public bool TrackRefByContext { get; } + public ImmutableArray Generics { get; } + } + + private sealed class TypeModel + { + public TypeModel( + string typeName, + string serializerName, + DeclKind kind, + ImmutableArray members, + ImmutableArray sortedMembers) + { + TypeName = typeName; + SerializerName = serializerName; + Kind = kind; + Members = members; + SortedMembers = sortedMembers; + } + + public string TypeName { get; } + public string SerializerName { get; } + public DeclKind Kind { get; } + public ImmutableArray Members { get; } + public ImmutableArray SortedMembers { get; } + } + + private sealed class MemberModel + { 
+ public MemberModel( + string name, + string fieldIdentifier, + int originalIndex, + MemberDeclKind declKind, + string typeName, + bool isNullable, + bool isNullableValueType, + TypeClassification classification, + int group, + bool isCollection, + DynamicAnyKind dynamicAnyKind, + TypeMetaFieldTypeModel typeMeta) + { + Name = name; + FieldIdentifier = fieldIdentifier; + OriginalIndex = originalIndex; + DeclKind = declKind; + TypeName = typeName; + IsNullable = isNullable; + IsNullableValueType = isNullableValueType; + Classification = classification; + Group = group; + IsCollection = isCollection; + DynamicAnyKind = dynamicAnyKind; + TypeMeta = typeMeta; + } + + public string Name { get; } + public string FieldIdentifier { get; } + public int OriginalIndex { get; } + public MemberDeclKind DeclKind { get; } + public string TypeName { get; } + public bool IsNullable { get; } + public bool IsNullableValueType { get; } + public TypeClassification Classification { get; } + public int Group { get; } + public bool IsCollection { get; } + public DynamicAnyKind DynamicAnyKind { get; } + public TypeMetaFieldTypeModel TypeMeta { get; } + } + + private enum MemberDeclKind + { + Field, + Property, + } + + private enum DeclKind + { + Unknown, + Class, + Struct, + Enum, + } + + private enum DynamicAnyKind + { + None, + AnyValue, + } + + private enum FieldEncoding + { + None = -1, + Varint = 0, + Fixed = 1, + Tagged = 2, + } +} diff --git a/csharp/src/Fory/AnySerializer.cs b/csharp/src/Fory/AnySerializer.cs new file mode 100644 index 0000000000..9ca4008b09 --- /dev/null +++ b/csharp/src/Fory/AnySerializer.cs @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +public sealed class DynamicAnyObjectSerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Unknown; + public override bool IsNullableType => true; + public override bool IsReferenceTrackableType => true; + public override object? DefaultValue => null; + public override bool IsNone(in object? value) => value is null; + + public override void WriteData(ref WriteContext context, in object? value, bool hasGenerics) + { + if (IsNone(value)) + { + return; + } + + DynamicAnyCodec.WriteAnyPayload(value!, ref context, hasGenerics); + } + + public override object? ReadData(ref ReadContext context) + { + DynamicTypeInfo? dynamicTypeInfo = context.DynamicTypeInfo(typeof(object)); + if (dynamicTypeInfo is null) + { + throw new InvalidDataException("dynamic Any value requires type info"); + } + + return context.TypeResolver.ReadDynamicValue(dynamicTypeInfo, ref context); + } + + public override void WriteTypeInfo(ref WriteContext context) + { + throw new InvalidDataException("dynamic Any value type info is runtime-only"); + } + + public override void ReadTypeInfo(ref ReadContext context) + { + DynamicTypeInfo typeInfo = context.TypeResolver.ReadDynamicTypeInfo(ref context); + context.SetDynamicTypeInfo(typeof(object), typeInfo); + } + + public override void Write(ref WriteContext context, in object? 
value, RefMode refMode, bool writeTypeInfo, bool hasGenerics) + { + if (refMode != RefMode.None) + { + if (IsNone(value)) + { + context.Writer.WriteInt8((sbyte)RefFlag.Null); + return; + } + + bool wroteTrackingRefFlag = false; + if (refMode == RefMode.Tracking && AnyValueIsReferenceTrackable(value!, context.TypeResolver)) + { + if (context.RefWriter.TryWriteReference(context.Writer, value!)) + { + return; + } + + wroteTrackingRefFlag = true; + } + + if (!wroteTrackingRefFlag) + { + context.Writer.WriteInt8((sbyte)RefFlag.NotNullValue); + } + } + + if (writeTypeInfo) + { + DynamicAnyCodec.WriteAnyTypeInfo(value!, ref context); + } + + WriteData(ref context, value, hasGenerics); + } + + public override object? Read(ref ReadContext context, RefMode refMode, bool readTypeInfo) + { + if (refMode != RefMode.None) + { + sbyte rawFlag = context.Reader.ReadInt8(); + RefFlag flag = (RefFlag)rawFlag; + switch (flag) + { + case RefFlag.Null: + return null; + case RefFlag.Ref: + { + uint refId = context.Reader.ReadVarUInt32(); + return context.RefReader.ReadRefValue(refId); + } + case RefFlag.RefValue: + { + uint reservedRefId = context.RefReader.ReserveRefId(); + context.PushPendingReference(reservedRefId); + if (readTypeInfo) + { + ReadTypeInfo(ref context); + } + + object? value = ReadData(ref context); + if (readTypeInfo) + { + context.ClearDynamicTypeInfo(typeof(object)); + } + + context.FinishPendingReferenceIfNeeded(value); + context.PopPendingReference(); + return value; + } + case RefFlag.NotNullValue: + break; + default: + throw new RefException($"invalid ref flag {rawFlag}"); + } + } + + if (readTypeInfo) + { + ReadTypeInfo(ref context); + } + + object? 
result = ReadData(ref context); + if (readTypeInfo) + { + context.ClearDynamicTypeInfo(typeof(object)); + } + + return result; + } + + private static bool AnyValueIsReferenceTrackable(object value, TypeResolver typeResolver) + { + Serializer serializer = typeResolver.GetSerializer(value.GetType()); + return serializer.IsReferenceTrackableType; + } +} + +public static class DynamicAnyCodec +{ + internal static void WriteAnyTypeInfo(object value, ref WriteContext context) + { + if (DynamicContainerCodec.TryGetTypeId(value, out TypeId containerTypeId)) + { + context.Writer.WriteUInt8((byte)containerTypeId); + return; + } + + if (TryWriteKnownTypeInfo(value, ref context)) + { + return; + } + + Serializer serializer = context.TypeResolver.GetSerializer(value.GetType()); + serializer.WriteTypeInfo(ref context); + } + + public static object? CastAnyDynamicValue(object? value, Type targetType) + { + if (value is null) + { + if (targetType == typeof(object)) + { + return null; + } + + if (!targetType.IsValueType || Nullable.GetUnderlyingType(targetType) is not null) + { + return null; + } + + throw new InvalidDataException($"cannot cast null dynamic Any value to non-nullable {targetType}"); + } + + if (targetType.IsInstanceOfType(value)) + { + return value; + } + + throw new InvalidDataException($"cannot cast dynamic Any value to {targetType}"); + } + + public static void WriteAny(ref WriteContext context, object? value, RefMode refMode, bool writeTypeInfo = true, bool hasGenerics = false) + { + context.TypeResolver.GetSerializer().Write(ref context, value, refMode, writeTypeInfo, hasGenerics); + } + + public static object? 
ReadAny(ref ReadContext context, RefMode refMode, bool readTypeInfo = true) + { + return context.TypeResolver.GetSerializer().Read(ref context, refMode, readTypeInfo); + } + + public static void WriteAnyPayload(object value, ref WriteContext context, bool hasGenerics) + { + if (DynamicContainerCodec.TryWritePayload(value, ref context, hasGenerics)) + { + return; + } + + if (TryWriteKnownPayload(value, ref context)) + { + return; + } + + Serializer serializer = context.TypeResolver.GetSerializer(value.GetType()); + serializer.WriteDataObject(ref context, value, hasGenerics); + } + + private static bool TryWriteKnownTypeInfo(object value, ref WriteContext context) + { + switch (value) + { + case bool: + context.Writer.WriteUInt8((byte)TypeId.Bool); + return true; + case sbyte: + context.Writer.WriteUInt8((byte)TypeId.Int8); + return true; + case short: + context.Writer.WriteUInt8((byte)TypeId.Int16); + return true; + case int: + context.Writer.WriteUInt8((byte)TypeId.VarInt32); + return true; + case long: + context.Writer.WriteUInt8((byte)TypeId.VarInt64); + return true; + case byte: + context.Writer.WriteUInt8((byte)TypeId.UInt8); + return true; + case ushort: + context.Writer.WriteUInt8((byte)TypeId.UInt16); + return true; + case uint: + context.Writer.WriteUInt8((byte)TypeId.VarUInt32); + return true; + case ulong: + context.Writer.WriteUInt8((byte)TypeId.VarUInt64); + return true; + case float: + context.Writer.WriteUInt8((byte)TypeId.Float32); + return true; + case double: + context.Writer.WriteUInt8((byte)TypeId.Float64); + return true; + case string: + context.Writer.WriteUInt8((byte)TypeId.String); + return true; + case byte[]: + context.Writer.WriteUInt8((byte)TypeId.Binary); + return true; + case bool[]: + context.Writer.WriteUInt8((byte)TypeId.BoolArray); + return true; + case sbyte[]: + context.Writer.WriteUInt8((byte)TypeId.Int8Array); + return true; + case short[]: + context.Writer.WriteUInt8((byte)TypeId.Int16Array); + return true; + case int[]: + 
context.Writer.WriteUInt8((byte)TypeId.Int32Array); + return true; + case long[]: + context.Writer.WriteUInt8((byte)TypeId.Int64Array); + return true; + case ushort[]: + context.Writer.WriteUInt8((byte)TypeId.UInt16Array); + return true; + case uint[]: + context.Writer.WriteUInt8((byte)TypeId.UInt32Array); + return true; + case ulong[]: + context.Writer.WriteUInt8((byte)TypeId.UInt64Array); + return true; + case float[]: + context.Writer.WriteUInt8((byte)TypeId.Float32Array); + return true; + case double[]: + context.Writer.WriteUInt8((byte)TypeId.Float64Array); + return true; + case DateOnly: + context.Writer.WriteUInt8((byte)TypeId.Date); + return true; + case DateTimeOffset: + case DateTime: + context.Writer.WriteUInt8((byte)TypeId.Timestamp); + return true; + case TimeSpan: + context.Writer.WriteUInt8((byte)TypeId.Duration); + return true; + default: + return false; + } + } + + private static bool TryWriteKnownPayload(object value, ref WriteContext context) + { + switch (value) + { + case bool v: + context.Writer.WriteUInt8(v ? 
(byte)1 : (byte)0); + return true; + case sbyte v: + context.Writer.WriteInt8(v); + return true; + case short v: + context.Writer.WriteInt16(v); + return true; + case int v: + context.Writer.WriteVarInt32(v); + return true; + case long v: + context.Writer.WriteVarInt64(v); + return true; + case byte v: + context.Writer.WriteUInt8(v); + return true; + case ushort v: + context.Writer.WriteUInt16(v); + return true; + case uint v: + context.Writer.WriteVarUInt32(v); + return true; + case ulong v: + context.Writer.WriteVarUInt64(v); + return true; + case float v: + context.Writer.WriteFloat32(v); + return true; + case double v: + context.Writer.WriteFloat64(v); + return true; + case string v: + StringSerializer.WriteString(ref context, v); + return true; + case DateOnly v: + TimeCodec.WriteDate(ref context, v); + return true; + case DateTimeOffset v: + TimeCodec.WriteTimestamp(ref context, v); + return true; + case DateTime v: + TimeCodec.WriteTimestamp(ref context, TimeCodec.ToDateTimeOffset(v)); + return true; + case TimeSpan v: + TimeCodec.WriteDuration(ref context, v); + return true; + case byte[] v: + context.Writer.WriteVarUInt32((uint)v.Length); + context.Writer.WriteBytes(v); + return true; + case bool[] v: + context.Writer.WriteVarUInt32((uint)v.Length); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt8(v[i] ? 
(byte)1 : (byte)0); + } + return true; + case sbyte[] v: + context.Writer.WriteVarUInt32((uint)v.Length); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteInt8(v[i]); + } + return true; + case short[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 2)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteInt16(v[i]); + } + return true; + case int[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 4)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteInt32(v[i]); + } + return true; + case long[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 8)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteInt64(v[i]); + } + return true; + case ushort[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 2)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt16(v[i]); + } + return true; + case uint[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 4)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt32(v[i]); + } + return true; + case ulong[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 8)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt64(v[i]); + } + return true; + case float[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 4)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteFloat32(v[i]); + } + return true; + case double[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 8)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteFloat64(v[i]); + } + return true; + default: + return false; + } + } +} diff --git a/csharp/src/Fory/Attributes.cs b/csharp/src/Fory/Attributes.cs new file mode 100644 index 0000000000..84fe5dc380 --- /dev/null +++ b/csharp/src/Fory/Attributes.cs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Enum)] +public sealed class ForyObjectAttribute : Attribute +{ +} + +public enum FieldEncoding +{ + Varint, + Fixed, + Tagged, +} + +[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] +public sealed class FieldAttribute : Attribute +{ + public FieldEncoding Encoding { get; set; } = FieldEncoding.Varint; +} diff --git a/csharp/src/Fory/ByteBuffer.cs b/csharp/src/Fory/ByteBuffer.cs new file mode 100644 index 0000000000..4d7e83e14a --- /dev/null +++ b/csharp/src/Fory/ByteBuffer.cs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Buffers.Binary; + +namespace Apache.Fory; + +public sealed class ByteWriter +{ + private readonly List _storage; + + public ByteWriter(int capacity = 256) + { + _storage = new List(capacity); + } + + public int Count => _storage.Count; + + public IReadOnlyList Storage => _storage; + + public void Reserve(int additional) + { + _storage.Capacity = Math.Max(_storage.Capacity, _storage.Count + additional); + } + + public void WriteUInt8(byte value) + { + _storage.Add(value); + } + + public void WriteInt8(sbyte value) + { + _storage.Add(unchecked((byte)value)); + } + + public void WriteUInt16(ushort value) + { + Span tmp = stackalloc byte[2]; + BinaryPrimitives.WriteUInt16LittleEndian(tmp, value); + WriteBytes(tmp); + } + + public void WriteInt16(short value) + { + WriteUInt16(unchecked((ushort)value)); + } + + public void WriteUInt32(uint value) + { + Span tmp = stackalloc byte[4]; + BinaryPrimitives.WriteUInt32LittleEndian(tmp, value); + WriteBytes(tmp); + } + + public void WriteInt32(int value) + { + WriteUInt32(unchecked((uint)value)); + } + + public void WriteUInt64(ulong value) + { + Span tmp = stackalloc byte[8]; + BinaryPrimitives.WriteUInt64LittleEndian(tmp, value); + WriteBytes(tmp); + } + + public void WriteInt64(long value) + { + WriteUInt64(unchecked((ulong)value)); + } + + public void WriteVarUInt32(uint value) + { + uint remaining = value; + while (remaining >= 0x80) + { + WriteUInt8((byte)((remaining & 0x7F) | 0x80)); + remaining >>= 7; + } + + WriteUInt8((byte)remaining); + } + + public void WriteVarUInt64(ulong value) + { + ulong remaining = value; + for (var i = 0; i < 8; i++) + { + if (remaining < 0x80) + { + WriteUInt8((byte)remaining); + return; + } + + WriteUInt8((byte)((remaining & 0x7F) | 0x80)); + remaining >>= 7; + } + + WriteUInt8((byte)(remaining & 0xFF)); + } + + public void WriteVarUInt36Small(ulong value) + { + if 
(value >= (1UL << 36)) + { + throw new EncodingException("varuint36small overflow"); + } + + WriteVarUInt64(value); + } + + public void WriteVarInt32(int value) + { + uint zigzag = unchecked((uint)((value << 1) ^ (value >> 31))); + WriteVarUInt32(zigzag); + } + + public void WriteVarInt64(long value) + { + ulong zigzag = unchecked((ulong)((value << 1) ^ (value >> 63))); + WriteVarUInt64(zigzag); + } + + public void WriteTaggedInt64(long value) + { + if (value >= -1_073_741_824L && value <= 1_073_741_823L) + { + WriteInt32(unchecked((int)value << 1)); + return; + } + + WriteUInt8(0x01); + WriteInt64(value); + } + + public void WriteTaggedUInt64(ulong value) + { + if (value <= int.MaxValue) + { + WriteUInt32(unchecked((uint)value << 1)); + return; + } + + WriteUInt8(0x01); + WriteUInt64(value); + } + + public void WriteFloat32(float value) + { + WriteUInt32(unchecked((uint)BitConverter.SingleToInt32Bits(value))); + } + + public void WriteFloat64(double value) + { + WriteUInt64(unchecked((ulong)BitConverter.DoubleToInt64Bits(value))); + } + + public void WriteBytes(ReadOnlySpan bytes) + { + for (int i = 0; i < bytes.Length; i++) + { + _storage.Add(bytes[i]); + } + } + + public void SetByte(int index, byte value) + { + _storage[index] = value; + } + + public void SetBytes(int index, ReadOnlySpan bytes) + { + for (var i = 0; i < bytes.Length; i++) + { + _storage[index + i] = bytes[i]; + } + } + + public byte[] ToArray() + { + return _storage.ToArray(); + } + + public void Reset() + { + _storage.Clear(); + } +} + +public sealed class ByteReader +{ + private readonly byte[] _storage; + private int _cursor; + + public ByteReader(ReadOnlySpan data) + { + _storage = data.ToArray(); + _cursor = 0; + } + + public ByteReader(byte[] bytes) + { + _storage = bytes; + _cursor = 0; + } + + public byte[] Storage => _storage; + + public int Cursor => _cursor; + + public int Remaining => _storage.Length - _cursor; + + public void SetCursor(int value) + { + _cursor = value; + } + + 
/// <summary>Moves the cursor back by <paramref name="amount"/> bytes.</summary>
/// <exception cref="ArgumentOutOfRangeException">The move would leave the cursor outside [0, length].</exception>
public void MoveBack(int amount)
{
    long target = (long)_cursor - amount;
    if (target < 0 || target > _storage.Length)
    {
        throw new ArgumentOutOfRangeException(nameof(amount));
    }

    _cursor = (int)target;
}

/// <summary>Verifies that <paramref name="need"/> more bytes are available at the cursor.</summary>
/// <exception cref="OutOfBoundsException">Fewer than <paramref name="need"/> bytes remain, or it is negative.</exception>
public void CheckBound(int need)
{
    // Compare against the remaining byte count rather than computing _cursor + need:
    // that sum wraps around for a large (attacker-controlled) length and would pass the check.
    if (need < 0 || _cursor < 0 || need > _storage.Length - _cursor)
    {
        throw new OutOfBoundsException(_cursor, need, _storage.Length);
    }
}

/// <summary>Reads one unsigned byte.</summary>
public byte ReadUInt8()
{
    CheckBound(1);
    byte value = _storage[_cursor];
    _cursor += 1;
    return value;
}

/// <summary>Reads one signed byte.</summary>
public sbyte ReadInt8()
{
    return unchecked((sbyte)ReadUInt8());
}

/// <summary>Reads a little-endian 16-bit unsigned integer.</summary>
public ushort ReadUInt16()
{
    CheckBound(2);
    ushort value = BinaryPrimitives.ReadUInt16LittleEndian(_storage.AsSpan(_cursor, 2));
    _cursor += 2;
    return value;
}

/// <summary>Reads a little-endian 16-bit signed integer.</summary>
public short ReadInt16()
{
    return unchecked((short)ReadUInt16());
}

/// <summary>Reads a little-endian 32-bit unsigned integer.</summary>
public uint ReadUInt32()
{
    CheckBound(4);
    uint value = BinaryPrimitives.ReadUInt32LittleEndian(_storage.AsSpan(_cursor, 4));
    _cursor += 4;
    return value;
}

/// <summary>Reads a little-endian 32-bit signed integer.</summary>
public int ReadInt32()
{
    return unchecked((int)ReadUInt32());
}

/// <summary>Reads a little-endian 64-bit unsigned integer.</summary>
public ulong ReadUInt64()
{
    CheckBound(8);
    ulong value = BinaryPrimitives.ReadUInt64LittleEndian(_storage.AsSpan(_cursor, 8));
    _cursor += 8;
    return value;
}

/// <summary>Reads a little-endian 64-bit signed integer.</summary>
public long ReadInt64()
{
    return unchecked((long)ReadUInt64());
}

/// <summary>Reads a base-128 varint of at most five bytes.</summary>
/// <exception cref="EncodingException">The fifth byte still has its continuation bit set.</exception>
public uint ReadVarUInt32()
{
    uint result = 0;
    var shift = 0;
    while (true)
    {
        byte b = ReadUInt8();
        result |= (uint)(b & 0x7F) << shift;
        if ((b & 0x80) == 0)
        {
            return result;
        }

        shift += 7;
        if (shift > 28)
        {
            throw new EncodingException("varuint32 overflow");
        }
    }
}

/// <summary>
/// Reads a varint of at most nine bytes: up to eight 7-bit groups, then one full byte
/// supplying bits 56..63 when all eight groups had their continuation bit set.
/// </summary>
public ulong ReadVarUInt64()
{
    ulong result = 0;
    var shift = 0;
    for (var i = 0; i < 8; i++)
    {
        byte b = ReadUInt8();
        result |= (ulong)(b & 0x7F) << shift;
        if ((b & 0x80) == 0)
        {
            return result;
        }

        shift += 7;
    }

    byte last = ReadUInt8();
    result |= (ulong)last << 56;
    return result;
}

/// <summary>Reads a varint that must fit in 36 bits.</summary>
/// <exception cref="EncodingException">The decoded value is 2^36 or larger.</exception>
public ulong ReadVarUInt36Small()
{
    ulong value = ReadVarUInt64();
    if (value >= (1UL << 36))
    {
        throw new EncodingException("varuint36small overflow");
    }

    return value;
}

/// <summary>Reads a zigzag-encoded 32-bit varint.</summary>
public int ReadVarInt32()
{
    uint encoded = ReadVarUInt32();
    return unchecked((int)((encoded >> 1) ^ (~(encoded & 1) + 1)));
}

/// <summary>Reads a zigzag-encoded 64-bit varint.</summary>
public long ReadVarInt64()
{
    ulong encoded = ReadVarUInt64();
    return unchecked((long)((encoded >> 1) ^ (~(encoded & 1UL) + 1UL)));
}

/// <summary>
/// Reads a tagged 64-bit signed integer: a four-byte value with low bit 0 is the compact form;
/// otherwise the first byte was a tag and the next eight bytes hold the value.
/// </summary>
public long ReadTaggedInt64()
{
    int first = ReadInt32();
    if ((first & 1) == 0)
    {
        return first >> 1;
    }

    // Four bytes were consumed but only the first was the tag; step back to the payload start.
    MoveBack(3);
    return ReadInt64();
}

/// <summary>Unsigned counterpart of <see cref="ReadTaggedInt64"/>.</summary>
public ulong ReadTaggedUInt64()
{
    uint first = ReadUInt32();
    if ((first & 1) == 0)
    {
        return first >> 1;
    }

    // Four bytes were consumed but only the first was the tag; step back to the payload start.
    MoveBack(3);
    return ReadUInt64();
}

/// <summary>Reads a 32-bit float from its little-endian bit pattern.</summary>
public float ReadFloat32()
{
    return BitConverter.Int32BitsToSingle(unchecked((int)ReadUInt32()));
}

/// <summary>Reads a 64-bit float from its little-endian bit pattern.</summary>
public double ReadFloat64()
{
    return BitConverter.Int64BitsToDouble(unchecked((long)ReadUInt64()));
}

/// <summary>Reads <paramref name="count"/> bytes into a new array.</summary>
public byte[] ReadBytes(int count)
{
    CheckBound(count);
    byte[] result = new byte[count];
    Array.Copy(_storage, _cursor, result, 0, count);
    _cursor += count;
    return result;
}

/// <summary>Returns a view over the next <paramref name="count"/> bytes of the underlying array (no copy).</summary>
public ReadOnlySpan<byte> ReadSpan(int count)
{
    CheckBound(count);
    ReadOnlySpan<byte> span = _storage.AsSpan(_cursor, count);
    _cursor += count;
    return span;
}

/// <summary>Advances the cursor by <paramref name="count"/> bytes.</summary>
public void Skip(int count)
{
    CheckBound(count);
    _cursor += count;
}
}
diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs new file mode 100644 index 0000000000..c1b026e33a --- /dev/null +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -0,0 +1,457 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.

using System.Collections;

namespace Apache.Fory;

// NOTE(review): generic type arguments throughout this file were missing from the extracted text
// and have been restored from usage; the object/object? arguments in DynamicContainerCodec are
// an inference and should be confirmed against the upstream file.

/// <summary>Bit flags of the one-byte header written after a non-empty collection's length.</summary>
internal static class CollectionBits
{
    public const byte TrackingRef = 0b0000_0001;
    public const byte HasNull = 0b0000_0010;
    public const byte DeclaredElementType = 0b0000_0100;
    public const byte SameType = 0b0000_1000;
}


/// <summary>Shared wire encoding for list-like collections.</summary>
internal static class CollectionCodec
{
    /// <summary>
    /// Writes the element count, then (for non-empty input) a header byte, optional element type info,
    /// and the elements in the mode selected by the header.
    /// </summary>
    public static void WriteCollectionData<T>(
        IEnumerable<T> values,
        Serializer<T> elementSerializer,
        ref WriteContext context,
        bool hasGenerics)
    {
        List<T> list = values as List<T> ?? [.. values];
        context.Writer.WriteVarUInt32((uint)list.Count);
        if (list.Count == 0)
        {
            return;
        }

        // Only nullable element types need the scan for a null entry.
        bool hasNull = false;
        if (elementSerializer.IsNullableType)
        {
            for (int i = 0; i < list.Count; i++)
            {
                if (!elementSerializer.IsNoneObject(list[i]))
                {
                    continue;
                }

                hasNull = true;
                break;
            }
        }

        bool trackRef = context.TrackRef && elementSerializer.IsReferenceTrackableType;
        bool declaredElementType = hasGenerics && !elementSerializer.StaticTypeId.NeedsTypeInfoForField();
        bool dynamicElementType = elementSerializer.StaticTypeId == TypeId.Unknown;

        byte header = dynamicElementType ? (byte)0 : CollectionBits.SameType;
        if (trackRef)
        {
            header |= CollectionBits.TrackingRef;
        }

        if (hasNull)
        {
            header |= CollectionBits.HasNull;
        }

        if (declaredElementType)
        {
            header |= CollectionBits.DeclaredElementType;
        }

        context.Writer.WriteUInt8(header);
        if (!dynamicElementType && !declaredElementType)
        {
            elementSerializer.WriteTypeInfo(ref context);
        }

        if (dynamicElementType)
        {
            // Dynamic elements carry their own type info, so each one is written with writeTypeInfo = true.
            RefMode refMode = trackRef ? RefMode.Tracking : hasNull ? RefMode.NullOnly : RefMode.None;
            foreach (T element in list)
            {
                elementSerializer.Write(ref context, element, refMode, true, hasGenerics);
            }

            return;
        }

        if (trackRef)
        {
            foreach (T element in list)
            {
                elementSerializer.Write(ref context, element, RefMode.Tracking, false, hasGenerics);
            }

            return;
        }

        if (hasNull)
        {
            foreach (T element in list)
            {
                if (elementSerializer.IsNoneObject(element))
                {
                    context.Writer.WriteInt8((sbyte)RefFlag.Null);
                }
                else
                {
                    context.Writer.WriteInt8((sbyte)RefFlag.NotNullValue);
                    elementSerializer.WriteData(ref context, element, hasGenerics);
                }
            }

            return;
        }

        foreach (T element in list)
        {
            elementSerializer.WriteData(ref context, element, hasGenerics);
        }
    }

    /// <summary>Reads a collection written by <see cref="WriteCollectionData{T}"/>.</summary>
    /// <exception cref="RefException">A per-element nullability flag is neither Null nor NotNullValue.</exception>
    public static List<T> ReadCollectionData<T>(Serializer<T> elementSerializer, ref ReadContext context)
    {
        int length = checked((int)context.Reader.ReadVarUInt32());
        if (length == 0)
        {
            return [];
        }

        byte header = context.Reader.ReadUInt8();
        bool trackRef = (header & CollectionBits.TrackingRef) != 0;
        bool hasNull = (header & CollectionBits.HasNull) != 0;
        bool declared = (header & CollectionBits.DeclaredElementType) != 0;
        bool sameType = (header & CollectionBits.SameType) != 0;
        bool canonicalizeElements = context.TrackRef && !trackRef && elementSerializer.IsReferenceTrackableType;

        // The length comes from the wire; cap the up-front allocation by the bytes actually left
        // so a forged length cannot force a huge allocation. The list still grows on demand.
        List<T> values = new(Math.Min(length, context.Reader.Remaining));
        if (!sameType)
        {
            if (trackRef)
            {
                for (int i = 0; i < length; i++)
                {
                    values.Add(elementSerializer.Read(ref context, RefMode.Tracking, true));
                }

                return values;
            }

            if (hasNull)
            {
                for (int i = 0; i < length; i++)
                {
                    sbyte refFlag = context.Reader.ReadInt8();
                    if (refFlag == (sbyte)RefFlag.Null)
                    {
                        values.Add((T)elementSerializer.DefaultObject!);
                    }
                    else if (refFlag == (sbyte)RefFlag.NotNullValue)
                    {
                        values.Add(ReadCollectionElementWithCanonicalization(elementSerializer, ref context, true, canonicalizeElements));
                    }
                    else
                    {
                        throw new RefException($"invalid nullability flag {refFlag}");
                    }
                }
            }
            else
            {
                for (int i = 0; i < length; i++)
                {
                    values.Add(ReadCollectionElementWithCanonicalization(elementSerializer, ref context, true, canonicalizeElements));
                }
            }

            return values;
        }

        if (!declared)
        {
            elementSerializer.ReadTypeInfo(ref context);
        }

        if (trackRef)
        {
            for (int i = 0; i < length; i++)
            {
                values.Add(elementSerializer.Read(ref context, RefMode.Tracking, false));
            }

            if (!declared)
            {
                context.ClearDynamicTypeInfo(typeof(T));
            }

            return values;
        }

        if (hasNull)
        {
            for (int i = 0; i < length; i++)
            {
                sbyte refFlag = context.Reader.ReadInt8();
                if (refFlag == (sbyte)RefFlag.Null)
                {
                    values.Add((T)elementSerializer.DefaultObject!);
                }
                else if (refFlag == (sbyte)RefFlag.NotNullValue)
                {
                    values.Add(ReadCollectionElementDataWithCanonicalization(elementSerializer, ref context, canonicalizeElements));
                }
                else
                {
                    // Same validation as the mixed-type path above; an unknown flag means a corrupt stream.
                    throw new RefException($"invalid nullability flag {refFlag}");
                }
            }
        }
        else
        {
            for (int i = 0; i < length; i++)
            {
                values.Add(ReadCollectionElementDataWithCanonicalization(elementSerializer, ref context, canonicalizeElements));
            }
        }

        if (!declared)
        {
            context.ClearDynamicTypeInfo(typeof(T));
        }

        return values;
    }

    // Reads one element (optionally with type info) and, when requested, hands the consumed byte range
    // to the context so equal payloads can resolve to one shared instance.
    private static T ReadCollectionElementWithCanonicalization<T>(
        Serializer<T> elementSerializer,
        ref ReadContext context,
        bool readTypeInfo,
        bool canonicalize)
    {
        if (!canonicalize)
        {
            return elementSerializer.Read(ref context, RefMode.None, readTypeInfo);
        }

        int start = context.Reader.Cursor;
        T value = elementSerializer.Read(ref context, RefMode.None, readTypeInfo);
        int end = context.Reader.Cursor;
        return context.CanonicalizeNonTrackingReference(value, start, end);
    }

    // Same as above for the payload-only path (element type info already consumed or declared).
    private static T ReadCollectionElementDataWithCanonicalization<T>(
        Serializer<T> elementSerializer,
        ref ReadContext context,
        bool canonicalize)
    {
        if (!canonicalize)
        {
            return elementSerializer.ReadData(ref context);
        }

        int start = context.Reader.Cursor;
        T value = elementSerializer.ReadData(ref context);
        int end = context.Reader.Cursor;
        return context.CanonicalizeNonTrackingReference(value, start, end);
    }
}

/// <summary>Maps untyped runtime containers onto the Map / List / Set wire types.</summary>
internal static class DynamicContainerCodec
{
    /// <summary>Classifies <paramref name="value"/> as Map, List (non-array IList) or Set.</summary>
    /// <returns><c>false</c> when the value is none of those.</returns>
    public static bool TryGetTypeId(object value, out TypeId typeId)
    {
        if (value is IDictionary)
        {
            typeId = TypeId.Map;
            return true;
        }

        Type valueType = value.GetType();
        if (value is IList && !valueType.IsArray)
        {
            typeId = TypeId.List;
            return true;
        }

        if (IsSet(valueType))
        {
            typeId = TypeId.Set;
            return true;
        }

        typeId = default;
        return false;
    }

    /// <summary>Copies the container into its object-typed equivalent and writes it with that serializer.</summary>
    /// <returns><c>false</c> when <paramref name="value"/> is not a supported container.</returns>
    public static bool TryWritePayload(object value, ref WriteContext context, bool hasGenerics)
    {
        if (value is IDictionary dictionary)
        {
            NullableKeyDictionary<object, object?> map = new();
            foreach (DictionaryEntry entry in dictionary)
            {
                map.Add(entry.Key, entry.Value);
            }

            context.TypeResolver.GetSerializer<NullableKeyDictionary<object, object?>>().WriteData(ref context, map, false);
            return true;
        }

        Type valueType = value.GetType();
        if (value is IList list && !valueType.IsArray)
        {
            List<object?> values = new(list.Count);
            for (int i = 0; i < list.Count; i++)
            {
                values.Add(list[i]);
            }

            context.TypeResolver.GetSerializer<List<object?>>().WriteData(ref context, values, hasGenerics);
            return true;
        }

        if (!IsSet(valueType))
        {
            return false;
        }

        HashSet<object?> set = [];
        foreach (object? item in (IEnumerable)value)
        {
            set.Add(item);
        }

        context.TypeResolver.GetSerializer<HashSet<object?>>().WriteData(ref context, set, hasGenerics);
        return true;
    }

    public static List<object?> ReadListPayload(ref ReadContext context)
    {
        return context.TypeResolver.GetSerializer<List<object?>>().ReadData(ref context);
    }

    public static HashSet<object?> ReadSetPayload(ref ReadContext context)
    {
        return context.TypeResolver.GetSerializer<HashSet<object?>>().ReadData(ref context);
    }

    /// <summary>Reads a map; returns a plain dictionary unless the payload contained a null key.</summary>
    public static object ReadMapPayload(ref ReadContext context)
    {
        NullableKeyDictionary<object, object?> map = context.TypeResolver.GetSerializer<NullableKeyDictionary<object, object?>>().ReadData(ref context);
        if (map.HasNullKey)
        {
            return map;
        }

        return new Dictionary<object, object?>(map.NonNullEntries);
    }

    // True when the type is, or implements, some closed ISet<T>. Non-generic classes are inspected too:
    // a type such as `class Tags : HashSet<string>` is not generic itself but still implements ISet<string>.
    private static bool IsSet(Type valueType)
    {
        if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(ISet<>))
        {
            return true;
        }

        foreach (Type iface in valueType.GetInterfaces())
        {
            if (iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(ISet<>))
            {
                return true;
            }
        }

        return false;
    }
}

/// <summary>Serializes arrays using the List wire type.</summary>
public sealed class ArraySerializer<T> : Serializer<T[]>
{
    public override TypeId StaticTypeId => TypeId.List;
    public override bool IsNullableType => true;
    public override bool IsReferenceTrackableType => true;
    public override T[] DefaultValue => null!;
    public override bool IsNone(in T[] value) => value is null;

    public override void WriteData(ref WriteContext context, in T[] value, bool hasGenerics)
    {
        // A null array is written as an empty collection.
        T[] safe = value ?? [];
        CollectionCodec.WriteCollectionData(
            safe,
            context.TypeResolver.GetSerializer<T>(),
            ref context,
            hasGenerics);
    }

    public override T[] ReadData(ref ReadContext context)
    {
        List<T> values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer<T>(), ref context);
        return values.ToArray();
    }
}

/// <summary>Serializes <c>List&lt;T&gt;</c> using the List wire type.</summary>
public class ListSerializer<T> : Serializer<List<T>>
{
    public override TypeId StaticTypeId => TypeId.List;
    public override bool IsNullableType => true;
    public override bool IsReferenceTrackableType => true;
    public override List<T> DefaultValue => null!;
    public override bool IsNone(in List<T> value) => value is null;

    public override void WriteData(ref WriteContext context, in List<T> value, bool hasGenerics)
    {
        // A null list is written as an empty collection.
        List<T> safe = value ?? [];
        CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer<T>(), ref context, hasGenerics);
    }

    public override List<T> ReadData(ref ReadContext context)
    {
        return CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer<T>(), ref context);
    }
}

/// <summary>Serializes <c>HashSet&lt;T&gt;</c> by delegating to the list serializer for the same element type.</summary>
public sealed class SetSerializer<T> : Serializer<HashSet<T>> where T : notnull
{
    public override TypeId StaticTypeId => TypeId.Set;
    public override bool IsNullableType => true;
    public override bool IsReferenceTrackableType => true;
    public override HashSet<T> DefaultValue => null!;
    public override bool IsNone(in HashSet<T> value) => value is null;

    public override void WriteData(ref WriteContext context, in HashSet<T> value, bool hasGenerics)
    {
        List<T> list = value is null ? [] : [.. value];
        context.TypeResolver.GetSerializer<List<T>>().WriteData(ref context, list, hasGenerics);
    }

    public override HashSet<T> ReadData(ref ReadContext context)
    {
        return [.. context.TypeResolver.GetSerializer<List<T>>().ReadData(ref context)];
    }
}
diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs new file mode 100644 index 0000000000..5fd8f2609c --- /dev/null +++ b/csharp/src/Fory/Config.cs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
namespace Apache.Fory;

/// <summary>Immutable runtime options; every parameter has a default.</summary>
public sealed record Config(
    bool Xlang = true,
    bool TrackRef = false,
    bool Compatible = false,
    bool CheckStructVersion = false,
    bool EnableReflectionFallback = false,
    int MaxDepth = 512);

/// <summary>
/// Fluent builder that collects option values and produces a <c>Fory</c> instance from them.
/// Each setter stores its argument and returns the same builder so calls can be chained.
/// </summary>
public sealed class ForyBuilder
{
    private bool _xlangEnabled = true;
    private bool _trackRefEnabled;
    private bool _compatibleEnabled;
    private bool _structVersionChecked;
    private bool _reflectionFallbackEnabled;
    private int _depthLimit = 512;

    /// <summary>Sets the <c>Xlang</c> option.</summary>
    public ForyBuilder Xlang(bool enabled = true)
    {
        _xlangEnabled = enabled;
        return this;
    }

    /// <summary>Sets the <c>TrackRef</c> option. Note that the parameter default is <c>false</c>.</summary>
    public ForyBuilder TrackRef(bool enabled = false)
    {
        _trackRefEnabled = enabled;
        return this;
    }

    /// <summary>Sets the <c>Compatible</c> option. Note that the parameter default is <c>false</c>.</summary>
    public ForyBuilder Compatible(bool enabled = false)
    {
        _compatibleEnabled = enabled;
        return this;
    }

    /// <summary>Sets the <c>CheckStructVersion</c> option. Note that the parameter default is <c>false</c>.</summary>
    public ForyBuilder CheckStructVersion(bool enabled = false)
    {
        _structVersionChecked = enabled;
        return this;
    }

    /// <summary>Sets the <c>EnableReflectionFallback</c> option. Note that the parameter default is <c>false</c>.</summary>
    public ForyBuilder EnableReflectionFallback(bool enabled = false)
    {
        _reflectionFallbackEnabled = enabled;
        return this;
    }

    /// <summary>Sets the <c>MaxDepth</c> option.</summary>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="value"/> is zero or negative.</exception>
    public ForyBuilder MaxDepth(int value)
    {
        if (value > 0)
        {
            _depthLimit = value;
            return this;
        }

        throw new ArgumentOutOfRangeException(nameof(value), "MaxDepth must be greater than 0.");
    }

    /// <summary>Creates a <c>Fory</c> configured with the values collected so far.</summary>
    public Fory Build()
    {
        Config config = new(
            _xlangEnabled,
            _trackRefEnabled,
            _compatibleEnabled,
            _structVersionChecked,
            _reflectionFallbackEnabled,
            _depthLimit);
        return new Fory(config);
    }
}

diff --git a/csharp/src/Fory/Context.cs b/csharp/src/Fory/Context.cs new file mode 100644 index 0000000000..9dafa1c1d2 --- /dev/null +++ b/csharp/src/Fory/Context.cs @@ -0,0 +1,397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.

namespace Apache.Fory;

// NOTE(review): generic type arguments in this file were missing from the extracted text and have
// been restored from how each collection is used; confirm against the upstream file.

/// <summary>Write-side table assigning consecutive indices to types whose compatible type definition was emitted.</summary>
public sealed class CompatibleTypeDefWriteState
{
    private readonly Dictionary<Type, uint> _typeIndexByType = [];
    private uint _nextIndex;

    /// <summary>Returns the index assigned to <paramref name="type"/>, or <c>null</c> if none.</summary>
    public uint? LookupIndex(Type type)
    {
        return _typeIndexByType.TryGetValue(type, out uint idx) ? idx : null;
    }

    /// <summary>Returns the existing index for <paramref name="type"/>, or assigns the next free one.</summary>
    public (uint Index, bool IsNew) AssignIndexIfAbsent(Type type)
    {
        if (_typeIndexByType.TryGetValue(type, out uint existing))
        {
            return (existing, false);
        }

        uint index = _nextIndex;
        _nextIndex += 1;
        _typeIndexByType[type] = index;
        return (index, true);
    }

    public void Reset()
    {
        _typeIndexByType.Clear();
        _nextIndex = 0;
    }
}

/// <summary>Read-side table of type metas, addressed by the index found on the wire.</summary>
public sealed class CompatibleTypeDefReadState
{
    private readonly List<TypeMeta> _typeMetas = [];

    /// <summary>Returns the meta stored at <paramref name="index"/>, or <c>null</c> if out of range.</summary>
    public TypeMeta? TypeMetaAt(int index)
    {
        return index >= 0 && index < _typeMetas.Count ? _typeMetas[index] : null;
    }

    /// <summary>Stores a meta at <paramref name="index"/>: appends at the end or overwrites an existing slot.</summary>
    /// <exception cref="InvalidDataException">The index is negative or would leave a gap.</exception>
    public void StoreTypeMeta(TypeMeta typeMeta, int index)
    {
        if (index < 0)
        {
            throw new InvalidDataException("negative compatible type definition index");
        }

        if (index == _typeMetas.Count)
        {
            _typeMetas.Add(typeMeta);
            return;
        }

        if (index < _typeMetas.Count)
        {
            _typeMetas[index] = typeMeta;
            return;
        }

        throw new InvalidDataException(
            $"compatible type definition index gap: index={index}, count={_typeMetas.Count}");
    }

    public void Reset()
    {
        _typeMetas.Clear();
    }
}

/// <summary>Write-side table assigning consecutive indices to meta strings.</summary>
public sealed class MetaStringWriteState
{
    private readonly Dictionary<MetaString, uint> _stringIndexByKey = [];
    private uint _nextIndex;

    /// <summary>Returns the index assigned to <paramref name="value"/>, or <c>null</c> if none.</summary>
    public uint? Index(MetaString value)
    {
        return _stringIndexByKey.TryGetValue(value, out uint index) ? index : null;
    }

    /// <summary>Returns the existing index for <paramref name="value"/>, or assigns the next free one.</summary>
    public (uint Index, bool IsNew) AssignIndexIfAbsent(MetaString value)
    {
        if (_stringIndexByKey.TryGetValue(value, out uint existing))
        {
            return (existing, false);
        }

        uint index = _nextIndex;
        _nextIndex += 1;
        _stringIndexByKey[value] = index;
        return (index, true);
    }

    public void Reset()
    {
        _stringIndexByKey.Clear();
        _nextIndex = 0;
    }
}

/// <summary>Read-side list of meta strings in the order they were appended.</summary>
public sealed class MetaStringReadState
{
    private readonly List<MetaString> _values = [];

    /// <summary>Returns the string stored at <paramref name="index"/>, or <c>null</c> if out of range.</summary>
    public MetaString? ValueAt(int index)
    {
        return index >= 0 && index < _values.Count ? _values[index] : null;
    }

    public void Append(MetaString value)
    {
        _values.Add(value);
    }

    public void Reset()
    {
        _values.Clear();
    }
}

public sealed record DynamicTypeInfo(
    TypeId WireTypeId,
    uint? UserTypeId,
    MetaString? NamespaceName,
    MetaString? TypeName,
    TypeMeta? CompatibleTypeMeta);

/// <summary>Bundle of writer, resolver, flags and per-stream tables passed to serializers while writing.</summary>
public readonly struct WriteContext
{
    /// <summary>Creates a context; a fresh <c>RefWriter</c> is always created, and the two state tables are created when not supplied.</summary>
    public WriteContext(
        ByteWriter writer,
        TypeResolver typeResolver,
        bool trackRef,
        bool compatible = false,
        CompatibleTypeDefWriteState? compatibleTypeDefState = null,
        MetaStringWriteState? metaStringWriteState = null)
    {
        Writer = writer;
        TypeResolver = typeResolver;
        TrackRef = trackRef;
        Compatible = compatible;
        RefWriter = new RefWriter();
        CompatibleTypeDefState = compatibleTypeDefState ?? new CompatibleTypeDefWriteState();
        MetaStringWriteState = metaStringWriteState ?? new MetaStringWriteState();
    }

    public ByteWriter Writer { get; }

    public TypeResolver TypeResolver { get; }

    public bool TrackRef { get; }

    public bool Compatible { get; }

    public RefWriter RefWriter { get; }

    public CompatibleTypeDefWriteState CompatibleTypeDefState { get; }

    public MetaStringWriteState MetaStringWriteState { get; }

    /// <summary>
    /// Writes a type-definition marker: <c>index &lt;&lt; 1</c> followed by the encoded meta the first time a type
    /// is seen, or <c>(index &lt;&lt; 1) | 1</c> alone as a back-reference afterwards.
    /// </summary>
    public void WriteCompatibleTypeMeta(Type type, TypeMeta typeMeta)
    {
        (uint index, bool isNew) = CompatibleTypeDefState.AssignIndexIfAbsent(type);
        if (isNew)
        {
            Writer.WriteVarUInt32(index << 1);
            Writer.WriteBytes(typeMeta.Encode());
        }
        else
        {
            Writer.WriteVarUInt32((index << 1) | 1);
        }
    }

    /// <summary>Resets only the reference writer.</summary>
    public void ResetObjectState()
    {
        RefWriter.Reset();
    }

    /// <summary>Resets the reference writer and both index tables.</summary>
    public void Reset()
    {
        ResetObjectState();
        CompatibleTypeDefState.Reset();
        MetaStringWriteState.Reset();
    }
}

internal readonly record struct PendingRefSlot(uint RefId, bool Bound);

internal readonly record struct CanonicalReferenceSignature(
    Type Type,
    ulong HashLo,
    ulong HashHi,
    int Length);

internal sealed class CanonicalReferenceEntry
{
    public required byte[] Bytes { get; init; }
    public required object Object { get; init; }
}

/// <summary>Bundle of reader, resolver, flags and per-stream tables passed to serializers while reading.</summary>
public sealed class ReadContext
{
    private readonly List<PendingRefSlot> _pendingRefStack = [];
    private readonly Dictionary<Type, List<TypeMeta>> _pendingCompatibleTypeMeta = [];
    private readonly Dictionary<Type, DynamicTypeInfo> _pendingDynamicTypeInfo = [];
    private readonly Dictionary<CanonicalReferenceSignature, List<CanonicalReferenceEntry>> _canonicalReferenceCache = [];

    /// <summary>Creates a context; a fresh <c>RefReader</c> is always created, and the two state tables are created when not supplied.</summary>
    public ReadContext(
        ByteReader reader,
        TypeResolver typeResolver,
        bool trackRef,
        bool compatible = false,
        CompatibleTypeDefReadState? compatibleTypeDefState = null,
        MetaStringReadState? metaStringReadState = null)
    {
        Reader = reader;
        TypeResolver = typeResolver;
        TrackRef = trackRef;
        Compatible = compatible;
        RefReader = new RefReader();
        CompatibleTypeDefState = compatibleTypeDefState ?? new CompatibleTypeDefReadState();
        MetaStringReadState = metaStringReadState ?? new MetaStringReadState();
    }

    public ByteReader Reader { get; }

    public TypeResolver TypeResolver { get; }

    public bool TrackRef { get; }

    public bool Compatible { get; }

    public RefReader RefReader { get; }

    public CompatibleTypeDefReadState CompatibleTypeDefState { get; }

    public MetaStringReadState MetaStringReadState { get; }

    /// <summary>Pushes an unbound slot for <paramref name="refId"/>.</summary>
    public void PushPendingReference(uint refId)
    {
        _pendingRefStack.Add(new PendingRefSlot(refId, false));
    }

    /// <summary>Marks the top slot as bound and stores <paramref name="value"/> under its ref id; no-op on an empty stack.</summary>
    public void BindPendingReference(object? value)
    {
        if (_pendingRefStack.Count == 0)
        {
            return;
        }

        PendingRefSlot last = _pendingRefStack[^1];
        // Update the top slot in place rather than removing and re-adding it.
        _pendingRefStack[^1] = last with { Bound = true };
        RefReader.StoreRef(value, last.RefId);
    }

    /// <summary>Stores <paramref name="value"/> under the top slot's ref id unless that slot was already bound.</summary>
    public void FinishPendingReferenceIfNeeded(object? value)
    {
        if (_pendingRefStack.Count == 0)
        {
            return;
        }

        PendingRefSlot last = _pendingRefStack[^1];
        if (!last.Bound)
        {
            RefReader.StoreRef(value, last.RefId);
        }
    }

    /// <summary>Removes the top slot, if any.</summary>
    public void PopPendingReference()
    {
        if (_pendingRefStack.Count > 0)
        {
            _pendingRefStack.RemoveAt(_pendingRefStack.Count - 1);
        }
    }

    /// <summary>
    /// Reads a type-definition marker: low bit 1 means a back-reference to an already stored meta,
    /// low bit 0 means a new meta follows and is stored at the given index.
    /// </summary>
    /// <exception cref="InvalidDataException">A back-reference points at an index that was never stored.</exception>
    public TypeMeta ReadCompatibleTypeMeta()
    {
        uint indexMarker = Reader.ReadVarUInt32();
        bool isRef = (indexMarker & 1) == 1;
        int index = checked((int)(indexMarker >> 1));
        if (isRef)
        {
            TypeMeta? cached = CompatibleTypeDefState.TypeMetaAt(index);
            if (cached is null)
            {
                throw new InvalidDataException($"unknown compatible type definition ref index {index}");
            }

            return cached;
        }

        TypeMeta typeMeta = TypeMeta.Decode(Reader);
        CompatibleTypeDefState.StoreTypeMeta(typeMeta, index);
        return typeMeta;
    }

    /// <summary>Sets the pending meta for <paramref name="type"/>, replacing whatever was pending before.</summary>
    /// <remarks>NOTE(review): despite the name this replaces rather than pushes — confirm that is intended.</remarks>
    public void PushCompatibleTypeMeta(Type type, TypeMeta typeMeta)
    {
        _pendingCompatibleTypeMeta[type] = [typeMeta];
    }

    /// <summary>Returns the pending meta for <paramref name="type"/> without removing it.</summary>
    /// <exception cref="InvalidDataException">Nothing is pending for the type.</exception>
    public TypeMeta ConsumeCompatibleTypeMeta(Type type)
    {
        if (!_pendingCompatibleTypeMeta.TryGetValue(type, out List<TypeMeta>? stack) || stack.Count == 0)
        {
            throw new InvalidDataException($"missing compatible type metadata for {type}");
        }

        return stack[^1];
    }

    public void SetDynamicTypeInfo(Type type, DynamicTypeInfo typeInfo)
    {
        _pendingDynamicTypeInfo[type] = typeInfo;
    }

    /// <summary>Returns the dynamic type info set for <paramref name="type"/>, or <c>null</c>.</summary>
    public DynamicTypeInfo? DynamicTypeInfo(Type type)
    {
        return _pendingDynamicTypeInfo.TryGetValue(type, out DynamicTypeInfo? typeInfo) ? typeInfo : null;
    }

    public void ClearDynamicTypeInfo(Type type)
    {
        _pendingDynamicTypeInfo.Remove(type);
    }

    /// <summary>
    /// When ref tracking is on, returns a previously seen object of the same runtime type that was decoded
    /// from an identical byte range; otherwise remembers <paramref name="value"/> and returns it unchanged.
    /// </summary>
    /// <param name="start">Reader position before the value was decoded.</param>
    /// <param name="end">Reader position after the value was decoded.</param>
    public T CanonicalizeNonTrackingReference<T>(T value, int start, int end)
    {
        // `is not object` already covers null, so no separate null test is needed.
        if (!TrackRef || end <= start || value is not object obj)
        {
            return value;
        }

        byte[] bytes = new byte[end - start];
        Array.Copy(Reader.Storage, start, bytes, 0, bytes.Length);
        (ulong hashLo, ulong hashHi) = MurmurHash3.X64_128(bytes, 47);
        CanonicalReferenceSignature signature = new(obj.GetType(), hashLo, hashHi, bytes.Length);

        if (_canonicalReferenceCache.TryGetValue(signature, out List<CanonicalReferenceEntry>? bucket))
        {
            // The hash only selects the bucket; equality is decided on the raw bytes.
            foreach (CanonicalReferenceEntry entry in bucket)
            {
                if (entry.Bytes.AsSpan().SequenceEqual(bytes))
                {
                    return (T)entry.Object;
                }
            }

            bucket.Add(new CanonicalReferenceEntry { Bytes = bytes, Object = obj });
            return value;
        }

        _canonicalReferenceCache[signature] =
        [
            new CanonicalReferenceEntry { Bytes = bytes, Object = obj },
        ];
        return value;
    }

    /// <summary>Clears reference state, pending type info and the canonicalization cache.</summary>
    public void ResetObjectState()
    {
        RefReader.Reset();
        _pendingRefStack.Clear();
        _pendingCompatibleTypeMeta.Clear();
        _pendingDynamicTypeInfo.Clear();
        _canonicalReferenceCache.Clear();
    }

    /// <summary>Clears everything <see cref="ResetObjectState"/> does plus both read-state tables.</summary>
    public void Reset()
    {
        ResetObjectState();
        CompatibleTypeDefState.Reset();
        MetaStringReadState.Reset();
    }
}

diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs new file mode 100644 index 0000000000..9b1e923a14 --- /dev/null +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -0,0 +1,473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +namespace Apache.Fory; + +internal static class DictionaryBits +{ + public const byte TrackingKeyRef = 0b0000_0001; + public const byte KeyNull = 0b0000_0010; + public const byte DeclaredKeyType = 0b0000_0100; + public const byte TrackingValueRef = 0b0000_1000; + public const byte ValueNull = 0b0001_0000; + public const byte DeclaredValueType = 0b0010_0000; +} + +public class DictionarySerializer : Serializer> + where TKey : notnull +{ + public override TypeId StaticTypeId => TypeId.Map; + public override bool IsNullableType => true; + public override bool IsReferenceTrackableType => true; + public override Dictionary DefaultValue => null!; + public override bool IsNone(in Dictionary value) => value is null; + + public override void WriteData(ref WriteContext context, in Dictionary value, bool hasGenerics) + { + Serializer keySerializer = context.TypeResolver.GetSerializer(); + Serializer valueSerializer = context.TypeResolver.GetSerializer(); + Dictionary map = value ?? []; + context.Writer.WriteVarUInt32((uint)map.Count); + if (map.Count == 0) + { + return; + } + + bool trackKeyRef = context.TrackRef && keySerializer.IsReferenceTrackableType; + bool trackValueRef = context.TrackRef && valueSerializer.IsReferenceTrackableType; + bool keyDeclared = hasGenerics && !keySerializer.StaticTypeId.NeedsTypeInfoForField(); + bool valueDeclared = hasGenerics && !valueSerializer.StaticTypeId.NeedsTypeInfoForField(); + bool keyDynamicType = keySerializer.StaticTypeId == TypeId.Unknown; + bool valueDynamicType = valueSerializer.StaticTypeId == TypeId.Unknown; + + KeyValuePair[] pairs = [.. 
map]; + if (keyDynamicType || valueDynamicType) + { + WriteDynamicMapPairs( + pairs, + ref context, + hasGenerics, + trackKeyRef, + trackValueRef, + keyDeclared, + valueDeclared, + keyDynamicType, + valueDynamicType, + keySerializer, + valueSerializer); + return; + } + + int index = 0; + while (index < pairs.Length) + { + KeyValuePair pair = pairs[index]; + bool keyIsNull = keySerializer.IsNoneObject(pair.Key); + bool valueIsNull = valueSerializer.IsNoneObject(pair.Value); + if (keyIsNull || valueIsNull) + { + byte header = 0; + if (trackKeyRef) + { + header |= DictionaryBits.TrackingKeyRef; + } + + if (trackValueRef) + { + header |= DictionaryBits.TrackingValueRef; + } + + if (keyIsNull) + { + header |= DictionaryBits.KeyNull; + } + + if (valueIsNull) + { + header |= DictionaryBits.ValueNull; + } + + if (!keyIsNull && keyDeclared) + { + header |= DictionaryBits.DeclaredKeyType; + } + + if (!valueIsNull && valueDeclared) + { + header |= DictionaryBits.DeclaredValueType; + } + + context.Writer.WriteUInt8(header); + if (!keyIsNull) + { + if (!keyDeclared) + { + keySerializer.WriteTypeInfo(ref context); + } + + keySerializer.Write(ref context, pair.Key, trackKeyRef ? RefMode.Tracking : RefMode.None, false, hasGenerics); + } + + if (!valueIsNull) + { + if (!valueDeclared) + { + valueSerializer.WriteTypeInfo(ref context); + } + + valueSerializer.Write(ref context, pair.Value, trackValueRef ? 
RefMode.Tracking : RefMode.None, false, hasGenerics); + } + + index += 1; + continue; + } + + byte blockHeader = 0; + if (trackKeyRef) + { + blockHeader |= DictionaryBits.TrackingKeyRef; + } + + if (trackValueRef) + { + blockHeader |= DictionaryBits.TrackingValueRef; + } + + if (keyDeclared) + { + blockHeader |= DictionaryBits.DeclaredKeyType; + } + + if (valueDeclared) + { + blockHeader |= DictionaryBits.DeclaredValueType; + } + + context.Writer.WriteUInt8(blockHeader); + int chunkSizeOffset = context.Writer.Count; + context.Writer.WriteUInt8(0); + if (!keyDeclared) + { + keySerializer.WriteTypeInfo(ref context); + } + + if (!valueDeclared) + { + valueSerializer.WriteTypeInfo(ref context); + } + + byte chunkSize = 0; + while (index < pairs.Length && chunkSize < byte.MaxValue) + { + KeyValuePair current = pairs[index]; + if (keySerializer.IsNoneObject(current.Key) || valueSerializer.IsNoneObject(current.Value)) + { + break; + } + + keySerializer.Write(ref context, current.Key, trackKeyRef ? RefMode.Tracking : RefMode.None, false, hasGenerics); + valueSerializer.Write(ref context, current.Value, trackValueRef ? 
RefMode.Tracking : RefMode.None, false, hasGenerics); + chunkSize += 1; + index += 1; + } + + context.Writer.SetByte(chunkSizeOffset, chunkSize); + } + } + + public override Dictionary ReadData(ref ReadContext context) + { + Serializer keySerializer = context.TypeResolver.GetSerializer(); + Serializer valueSerializer = context.TypeResolver.GetSerializer(); + int totalLength = checked((int)context.Reader.ReadVarUInt32()); + if (totalLength == 0) + { + return []; + } + + Dictionary map = new(totalLength); + bool keyDynamicType = keySerializer.StaticTypeId == TypeId.Unknown; + bool valueDynamicType = valueSerializer.StaticTypeId == TypeId.Unknown; + bool canonicalizeValues = context.TrackRef && valueSerializer.IsReferenceTrackableType; + + int readCount = 0; + while (readCount < totalLength) + { + byte header = context.Reader.ReadUInt8(); + bool trackKeyRef = (header & DictionaryBits.TrackingKeyRef) != 0; + bool keyNull = (header & DictionaryBits.KeyNull) != 0; + bool keyDeclared = (header & DictionaryBits.DeclaredKeyType) != 0; + bool trackValueRef = (header & DictionaryBits.TrackingValueRef) != 0; + bool valueNull = (header & DictionaryBits.ValueNull) != 0; + bool valueDeclared = (header & DictionaryBits.DeclaredValueType) != 0; + + if (keyNull && valueNull) + { + // Dictionary cannot represent a null key. + // Drop this entry instead of mapping it to default(TKey), which would corrupt key semantics. + readCount += 1; + continue; + } + + if (keyNull) + { + TValue value = ReadValueElement( + ref context, + trackValueRef, + !valueDeclared, + canonicalizeValues, + valueSerializer); + + // Preserve stream/reference state by reading value payload, then skip null-key entry. + // This avoids injecting a fake default(TKey) key into Dictionary. + readCount += 1; + continue; + } + + if (valueNull) + { + TKey key = keySerializer.Read( + ref context, + trackKeyRef ? 
RefMode.Tracking : RefMode.None, + !keyDeclared); + + map[key] = (TValue)valueSerializer.DefaultObject!; + readCount += 1; + continue; + } + + int chunkSize = context.Reader.ReadUInt8(); + if (keyDynamicType || valueDynamicType) + { + for (int i = 0; i < chunkSize; i++) + { + DynamicTypeInfo? keyDynamicInfo = null; + DynamicTypeInfo? valueDynamicInfo = null; + + if (!keyDeclared) + { + if (keyDynamicType) + { + keyDynamicInfo = context.TypeResolver.ReadDynamicTypeInfo(ref context); + } + else + { + keySerializer.ReadTypeInfo(ref context); + } + } + + if (!valueDeclared) + { + if (valueDynamicType) + { + valueDynamicInfo = context.TypeResolver.ReadDynamicTypeInfo(ref context); + } + else + { + valueSerializer.ReadTypeInfo(ref context); + } + } + + if (keyDynamicInfo is not null) + { + context.SetDynamicTypeInfo(typeof(TKey), keyDynamicInfo); + } + + TKey key = keySerializer.Read(ref context, trackKeyRef ? RefMode.Tracking : RefMode.None, false); + if (keyDynamicInfo is not null) + { + context.ClearDynamicTypeInfo(typeof(TKey)); + } + + if (valueDynamicInfo is not null) + { + context.SetDynamicTypeInfo(typeof(TValue), valueDynamicInfo); + } + + TValue value = ReadValueElement( + ref context, + trackValueRef, + false, + canonicalizeValues, + valueSerializer); + if (valueDynamicInfo is not null) + { + context.ClearDynamicTypeInfo(typeof(TValue)); + } + + map[key] = value; + } + + readCount += chunkSize; + continue; + } + + if (!keyDeclared) + { + keySerializer.ReadTypeInfo(ref context); + } + + if (!valueDeclared) + { + valueSerializer.ReadTypeInfo(ref context); + } + + for (int i = 0; i < chunkSize; i++) + { + TKey key = keySerializer.Read(ref context, trackKeyRef ? 
RefMode.Tracking : RefMode.None, false); + TValue value = ReadValueElement(ref context, trackValueRef, false, canonicalizeValues, valueSerializer); + map[key] = value; + } + + if (!keyDeclared) + { + context.ClearDynamicTypeInfo(typeof(TKey)); + } + + if (!valueDeclared) + { + context.ClearDynamicTypeInfo(typeof(TValue)); + } + + readCount += chunkSize; + } + + return map; + } + + private static void WriteDynamicMapPairs( + KeyValuePair[] pairs, + ref WriteContext context, + bool hasGenerics, + bool trackKeyRef, + bool trackValueRef, + bool keyDeclared, + bool valueDeclared, + bool keyDynamicType, + bool valueDynamicType, + Serializer keySerializer, + Serializer valueSerializer) + { + foreach (KeyValuePair pair in pairs) + { + bool keyIsNull = keySerializer.IsNoneObject(pair.Key); + bool valueIsNull = valueSerializer.IsNoneObject(pair.Value); + byte header = 0; + if (trackKeyRef) + { + header |= DictionaryBits.TrackingKeyRef; + } + + if (trackValueRef) + { + header |= DictionaryBits.TrackingValueRef; + } + + if (keyIsNull) + { + header |= DictionaryBits.KeyNull; + } + else if (!keyDynamicType && keyDeclared) + { + header |= DictionaryBits.DeclaredKeyType; + } + + if (valueIsNull) + { + header |= DictionaryBits.ValueNull; + } + else if (!valueDynamicType && valueDeclared) + { + header |= DictionaryBits.DeclaredValueType; + } + + context.Writer.WriteUInt8(header); + if (keyIsNull && valueIsNull) + { + continue; + } + + if (keyIsNull) + { + valueSerializer.Write( + ref context, + pair.Value, + trackValueRef ? RefMode.Tracking : RefMode.None, + !valueDeclared, + hasGenerics); + continue; + } + + if (valueIsNull) + { + keySerializer.Write( + ref context, + pair.Key, + trackKeyRef ? 
RefMode.Tracking : RefMode.None, + !keyDeclared, + hasGenerics); + continue; + } + + context.Writer.WriteUInt8(1); + if (!keyDeclared) + { + if (keyDynamicType) + { + DynamicAnyCodec.WriteAnyTypeInfo(pair.Key!, ref context); + } + else + { + keySerializer.WriteTypeInfo(ref context); + } + } + + if (!valueDeclared) + { + if (valueDynamicType) + { + DynamicAnyCodec.WriteAnyTypeInfo(pair.Value!, ref context); + } + else + { + valueSerializer.WriteTypeInfo(ref context); + } + } + + keySerializer.Write(ref context, pair.Key, trackKeyRef ? RefMode.Tracking : RefMode.None, false, hasGenerics); + valueSerializer.Write(ref context, pair.Value, trackValueRef ? RefMode.Tracking : RefMode.None, false, hasGenerics); + } + } + + private static TValue ReadValueElement( + ref ReadContext context, + bool trackValueRef, + bool readTypeInfo, + bool canonicalizeValues, + Serializer valueSerializer) + { + if (trackValueRef || !canonicalizeValues) + { + return valueSerializer.Read(ref context, trackValueRef ? RefMode.Tracking : RefMode.None, readTypeInfo); + } + + int start = context.Reader.Cursor; + TValue value = valueSerializer.Read(ref context, RefMode.None, readTypeInfo); + int end = context.Reader.Cursor; + return context.CanonicalizeNonTrackingReference(value, start, end); + } +} diff --git a/csharp/src/Fory/EnumSerializer.cs b/csharp/src/Fory/EnumSerializer.cs new file mode 100644 index 0000000000..ba269899e2 --- /dev/null +++ b/csharp/src/Fory/EnumSerializer.cs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace Apache.Fory;
+
+/// <summary>
+/// Serializer for enum types: the enum value is written as a var-length unsigned 32-bit integer.
+/// </summary>
+/// <typeparam name="TEnum">The enum type handled by this serializer.</typeparam>
+public sealed class EnumSerializer<TEnum> : Serializer<TEnum> where TEnum : struct, Enum
+{
+    public override TypeId StaticTypeId => TypeId.Enum;
+    public override TEnum DefaultValue => default;
+
+    /// <summary>Writes <paramref name="value"/> as a var-uint32.</summary>
+    /// <remarks>
+    /// <c>Convert.ToUInt32</c> throws <see cref="OverflowException"/> when the underlying
+    /// value does not fit in a <see cref="uint"/> (for example a negative enum member).
+    /// </remarks>
+    public override void WriteData(ref WriteContext context, in TEnum value, bool hasGenerics)
+    {
+        _ = hasGenerics; // enums carry no generic payload
+        uint ordinal = Convert.ToUInt32(value);
+        context.Writer.WriteVarUInt32(ordinal);
+    }
+
+    /// <summary>Reads a var-uint32 and converts it back to <typeparamref name="TEnum"/>.</summary>
+    /// <exception cref="InvalidDataException">The decoded number is not a defined member of the enum.</exception>
+    public override TEnum ReadData(ref ReadContext context)
+    {
+        uint ordinal = context.Reader.ReadVarUInt32();
+        TEnum value = (TEnum)Enum.ToObject(typeof(TEnum), ordinal);
+        if (!Enum.IsDefined(typeof(TEnum), value))
+        {
+            throw new InvalidDataException($"unknown enum ordinal {ordinal}");
+        }
+
+        return value;
+    }
+}
diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs
new file mode 100644
index 0000000000..22121372cd
--- /dev/null
+++ b/csharp/src/Fory/FieldSkipper.cs
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace Apache.Fory;
+
+/// <summary>
+/// Consumes the payload of a field from the read stream by reading it with the serializer
+/// selected from the field's type metadata and discarding the result.
+/// </summary>
+public static class FieldSkipper
+{
+    /// <summary>Reads and discards one field value described by <paramref name="fieldType"/>.</summary>
+    /// <exception cref="InvalidDataException">The field type (or its generic arguments) is not supported here.</exception>
+    public static void SkipFieldValue(ref ReadContext context, TypeMetaFieldType fieldType)
+    {
+        _ = ReadFieldValue(ref context, fieldType);
+    }
+
+    /// <summary>Reads an enum ordinal (var-uint32), honouring the null flag in <c>NullOnly</c> mode.</summary>
+    private static uint? ReadEnumOrdinal(ref ReadContext context, RefMode refMode)
+    {
+        return refMode switch
+        {
+            RefMode.None => context.Reader.ReadVarUInt32(),
+            RefMode.NullOnly => ReadNullableEnumOrdinal(ref context),
+            RefMode.Tracking => throw new InvalidDataException("enum tracking ref mode is not supported"),
+            _ => throw new InvalidDataException($"unsupported ref mode {refMode}"),
+        };
+    }
+
+    /// <summary>Reads a one-byte ref flag followed, when not null, by the enum ordinal.</summary>
+    private static uint? ReadNullableEnumOrdinal(ref ReadContext context)
+    {
+        sbyte flag = context.Reader.ReadInt8();
+        if (flag == (sbyte)RefFlag.Null)
+        {
+            return null;
+        }
+
+        if (flag != (sbyte)RefFlag.NotNullValue)
+        {
+            throw new InvalidDataException($"unexpected enum nullOnly flag {flag}");
+        }
+
+        return context.Reader.ReadVarUInt32();
+    }
+
+    /// <summary>
+    /// Dispatches on <c>fieldType.TypeId</c> and reads the value with the matching serializer.
+    /// Collections are only supported with string element/key/value types.
+    /// </summary>
+    // NOTE(review): the generic type arguments of GetSerializer<...>() were missing from this
+    // patch text and are reconstructed from the TypeId of each case — confirm against the
+    // original source, in particular HashSet<string> for Set and Union for the union cases.
+    private static object? ReadFieldValue(ref ReadContext context, TypeMetaFieldType fieldType)
+    {
+        RefMode refMode = RefModeExtensions.From(fieldType.Nullable, fieldType.TrackRef);
+        switch (fieldType.TypeId)
+        {
+            case (uint)TypeId.Bool:
+                return context.TypeResolver.GetSerializer<bool>().Read(ref context, refMode, false);
+            case (uint)TypeId.Int8:
+                return context.TypeResolver.GetSerializer<sbyte>().Read(ref context, refMode, false);
+            case (uint)TypeId.Int16:
+                return context.TypeResolver.GetSerializer<short>().Read(ref context, refMode, false);
+            case (uint)TypeId.VarInt32:
+                return context.TypeResolver.GetSerializer<int>().Read(ref context, refMode, false);
+            case (uint)TypeId.VarInt64:
+                return context.TypeResolver.GetSerializer<long>().Read(ref context, refMode, false);
+            case (uint)TypeId.Float32:
+                return context.TypeResolver.GetSerializer<float>().Read(ref context, refMode, false);
+            case (uint)TypeId.Float64:
+                return context.TypeResolver.GetSerializer<double>().Read(ref context, refMode, false);
+            case (uint)TypeId.String:
+                return context.TypeResolver.GetSerializer<string>().Read(ref context, refMode, false);
+            case (uint)TypeId.List:
+            {
+                if (fieldType.Generics.Count != 1 || fieldType.Generics[0].TypeId != (uint)TypeId.String)
+                {
+                    throw new InvalidDataException("unsupported compatible list element type");
+                }
+
+                return context.TypeResolver.GetSerializer<List<string>>().Read(ref context, refMode, false);
+            }
+            case (uint)TypeId.Set:
+            {
+                if (fieldType.Generics.Count != 1 || fieldType.Generics[0].TypeId != (uint)TypeId.String)
+                {
+                    throw new InvalidDataException("unsupported compatible set element type");
+                }
+
+                return context.TypeResolver.GetSerializer<HashSet<string>>().Read(ref context, refMode, false);
+            }
+            case (uint)TypeId.Map:
+            {
+                if (fieldType.Generics.Count != 2 ||
+                    fieldType.Generics[0].TypeId != (uint)TypeId.String ||
+                    fieldType.Generics[1].TypeId != (uint)TypeId.String)
+                {
+                    throw new InvalidDataException("unsupported compatible map key/value type");
+                }
+
+                return context.TypeResolver.GetSerializer<Dictionary<string, string>>().Read(ref context, refMode, false);
+            }
+            case (uint)TypeId.Enum:
+                return ReadEnumOrdinal(ref context, refMode);
+            case (uint)TypeId.Union:
+            case (uint)TypeId.TypedUnion:
+            case (uint)TypeId.NamedUnion:
+                return context.TypeResolver.GetSerializer<Union>().Read(ref context, refMode, false);
+            default:
+                throw new InvalidDataException($"unsupported compatible field type id {fieldType.TypeId}");
+        }
+    }
+}
diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs
new file mode 100644
index 0000000000..124fa0ceda
--- /dev/null
+++ b/csharp/src/Fory/Fory.cs
@@ -0,0 +1,242 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System.Buffers;
+
+namespace Apache.Fory;
+
+/// <summary>
+/// Entry point for serialization: owns the type registry and writes/reads the one-byte
+/// header bitmap followed by the root value.
+/// </summary>
+public sealed class Fory
+{
+    private readonly TypeResolver _typeResolver;
+
+    internal Fory(Config config)
+    {
+        Config = config;
+        _typeResolver = new TypeResolver();
+    }
+
+    /// <summary>Configuration this instance was built with.</summary>
+    public Config Config { get; }
+
+    /// <summary>Creates a builder for a new <see cref="Fory"/> instance.</summary>
+    public static ForyBuilder Builder()
+    {
+        return new ForyBuilder();
+    }
+
+    /// <summary>Registers <typeparamref name="T"/> under a numeric type id.</summary>
+    public Fory Register<T>(uint typeId)
+    {
+        _typeResolver.Register(typeof(T), typeId);
+        return this;
+    }
+
+    /// <summary>Registers <typeparamref name="T"/> by name with an empty namespace.</summary>
+    public Fory Register<T>(string typeName)
+    {
+        _typeResolver.Register(typeof(T), string.Empty, typeName);
+        return this;
+    }
+
+    /// <summary>Registers <typeparamref name="T"/> by namespace and name.</summary>
+    public Fory Register<T>(string typeNamespace, string typeName)
+    {
+        _typeResolver.Register(typeof(T), typeNamespace, typeName);
+        return this;
+    }
+
+    /// <summary>Registers <typeparamref name="T"/> under a numeric id with a custom serializer.</summary>
+    // NOTE(review): the type arguments of RegisterSerializer<...>() were missing from this
+    // patch text and are reconstructed here — confirm against TypeResolver.
+    public Fory Register<T, TSerializer>(uint typeId)
+        where TSerializer : Serializer<T>, new()
+    {
+        Serializer<T> serializerBinding = _typeResolver.RegisterSerializer<T, TSerializer>();
+        _typeResolver.Register(typeof(T), typeId, serializerBinding);
+        return this;
+    }
+
+    /// <summary>Registers <typeparamref name="T"/> by namespace and name with a custom serializer.</summary>
+    public Fory Register<T, TSerializer>(string typeNamespace, string typeName)
+        where TSerializer : Serializer<T>, new()
+    {
+        Serializer<T> serializerBinding = _typeResolver.RegisterSerializer<T, TSerializer>();
+        _typeResolver.Register(typeof(T), typeNamespace, typeName, serializerBinding);
+        return this;
+    }
+
+    /// <summary>Serializes <paramref name="value"/> into a new byte array (header + payload).</summary>
+    public byte[] Serialize<T>(in T value)
+    {
+        ByteWriter writer = new();
+        Serializer<T> binding = _typeResolver.GetSerializer<T>();
+        bool isNone = binding.IsNone(value);
+        WriteHead(writer, isNone);
+        if (!isNone)
+        {
+            // A "none" root is fully described by the header; only real values get a payload.
+            WriteContext context = new(
+                writer,
+                _typeResolver,
+                Config.TrackRef,
+                Config.Compatible,
+                new CompatibleTypeDefWriteState(),
+                new MetaStringWriteState());
+            RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly;
+            binding.Write(ref context, value, refMode, true, false);
+            context.ResetObjectState();
+        }
+
+        return writer.ToArray();
+    }
+
+    /// <summary>Serializes <paramref name="value"/> and appends the bytes to <paramref name="output"/>.</summary>
+    public void Serialize<T>(IBufferWriter<byte> output, in T value)
+    {
+        byte[] payload = Serialize(value);
+        output.Write(payload);
+    }
+
+    /// <summary>Deserializes a <typeparamref name="T"/> that must occupy the whole of <paramref name="payload"/>.</summary>
+    /// <exception cref="InvalidDataException">Bytes remain after the value was read.</exception>
+    public T Deserialize<T>(ReadOnlySpan<byte> payload)
+    {
+        ByteReader reader = new(payload);
+        T value = DeserializeFromReader<T>(reader);
+        if (reader.Remaining != 0)
+        {
+            throw new InvalidDataException($"unexpected trailing bytes after deserializing {typeof(T)}");
+        }
+
+        return value;
+    }
+
+    /// <summary>
+    /// Deserializes one <typeparamref name="T"/> from the front of <paramref name="payload"/> and
+    /// advances <paramref name="payload"/> past the consumed bytes.
+    /// </summary>
+    public T Deserialize<T>(ref ReadOnlySequence<byte> payload)
+    {
+        byte[] bytes = payload.ToArray();
+        ByteReader reader = new(bytes);
+        T value = DeserializeFromReader<T>(reader);
+        payload = payload.Slice(reader.Cursor);
+        return value;
+    }
+
+    /// <summary>Serializes a dynamically typed value (type info is written with the value).</summary>
+    public byte[] SerializeObject(object? value)
+    {
+        ByteWriter writer = new();
+        bool isNone = value is null;
+        WriteHead(writer, isNone);
+        if (!isNone)
+        {
+            WriteContext context = new(
+                writer,
+                _typeResolver,
+                Config.TrackRef,
+                Config.Compatible,
+                new CompatibleTypeDefWriteState(),
+                new MetaStringWriteState());
+            RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly;
+            DynamicAnyCodec.WriteAny(ref context, value, refMode, true, false);
+            context.ResetObjectState();
+        }
+
+        return writer.ToArray();
+    }
+
+    /// <summary>Serializes a dynamically typed value and appends the bytes to <paramref name="output"/>.</summary>
+    public void SerializeObject(IBufferWriter<byte> output, object? value)
+    {
+        byte[] payload = SerializeObject(value);
+        output.Write(payload);
+    }
+
+    /// <summary>Deserializes a dynamically typed value that must occupy the whole of <paramref name="payload"/>.</summary>
+    /// <exception cref="InvalidDataException">Bytes remain after the value was read.</exception>
+    public object? DeserializeObject(ReadOnlySpan<byte> payload)
+    {
+        ByteReader reader = new(payload);
+        object? value = DeserializeObjectFromReader(reader);
+        if (reader.Remaining != 0)
+        {
+            throw new InvalidDataException("unexpected trailing bytes after deserializing dynamic object");
+        }
+
+        return value;
+    }
+
+    /// <summary>
+    /// Deserializes one dynamically typed value from the front of <paramref name="payload"/> and
+    /// advances <paramref name="payload"/> past the consumed bytes.
+    /// </summary>
+    public object? DeserializeObject(ref ReadOnlySequence<byte> payload)
+    {
+        byte[] bytes = payload.ToArray();
+        ByteReader reader = new(bytes);
+        object? value = DeserializeObjectFromReader(reader);
+        payload = payload.Slice(reader.Cursor);
+        return value;
+    }
+
+    /// <summary>Writes the one-byte header bitmap (xlang flag from config, null flag from <paramref name="isNone"/>).</summary>
+    public void WriteHead(ByteWriter writer, bool isNone)
+    {
+        byte bitmap = 0;
+        if (Config.Xlang)
+        {
+            bitmap |= ForyHeaderFlag.IsXlang;
+        }
+
+        if (isNone)
+        {
+            bitmap |= ForyHeaderFlag.IsNull;
+        }
+
+        writer.WriteUInt8(bitmap);
+    }
+
+    /// <summary>Reads the header bitmap and returns whether the root value is null.</summary>
+    /// <exception cref="InvalidDataException">The peer's xlang flag differs from <c>Config.Xlang</c>.</exception>
+    public bool ReadHead(ByteReader reader)
+    {
+        byte bitmap = reader.ReadUInt8();
+        bool peerIsXlang = (bitmap & ForyHeaderFlag.IsXlang) != 0;
+        if (peerIsXlang != Config.Xlang)
+        {
+            throw new InvalidDataException("xlang bitmap mismatch");
+        }
+
+        return (bitmap & ForyHeaderFlag.IsNull) != 0;
+    }
+
+    // Shared body of the typed Deserialize overloads: header, then the root value.
+    private T DeserializeFromReader<T>(ByteReader reader)
+    {
+        bool isNone = ReadHead(reader);
+        Serializer<T> binding = _typeResolver.GetSerializer<T>();
+        if (isNone)
+        {
+            return binding.DefaultValue;
+        }
+
+        ReadContext context = new(
+            reader,
+            _typeResolver,
+            Config.TrackRef,
+            Config.Compatible,
+            new CompatibleTypeDefReadState(),
+            new MetaStringReadState());
+        RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly;
+        T value = binding.Read(ref context, refMode, true);
+        context.ResetObjectState();
+        return value;
+    }
+
+    // Shared body of the DeserializeObject overloads: header, then a dynamically typed root value.
+    private object? DeserializeObjectFromReader(ByteReader reader)
+    {
+        bool isNone = ReadHead(reader);
+        if (isNone)
+        {
+            return null;
+        }
+
+        ReadContext context = new(
+            reader,
+            _typeResolver,
+            Config.TrackRef,
+            Config.Compatible,
+            new CompatibleTypeDefReadState(),
+            new MetaStringReadState());
+        RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly;
+        object? value = DynamicAnyCodec.ReadAny(ref context, refMode, true);
+        context.ResetObjectState();
+        return value;
+    }
+}
diff --git a/csharp/src/Fory/Fory.csproj b/csharp/src/Fory/Fory.csproj
new file mode 100644
index 0000000000..2082c7f5e4
--- /dev/null
+++ b/csharp/src/Fory/Fory.csproj
@@ -0,0 +1,8 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <TargetFramework>net8.0</TargetFramework>
+    <LangVersion>12.0</LangVersion>
+    <Nullable>enable</Nullable>
+    <ImplicitUsings>enable</ImplicitUsings>
+  </PropertyGroup>
+</Project>
diff --git a/csharp/src/Fory/ForyException.cs b/csharp/src/Fory/ForyException.cs
new file mode 100644
index 0000000000..d3c4689ff2
--- /dev/null
+++ b/csharp/src/Fory/ForyException.cs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace Apache.Fory;
+
+/// <summary>Base type for every exception declared in this file.</summary>
+public class ForyException(string message) : Exception(message);
+
+/// <summary>Malformed or unexpected input; the message is prefixed with "Invalid data: ".</summary>
+public sealed class InvalidDataException(string message)
+    : ForyException($"Invalid data: {message}");
+
+/// <summary>An expected type id differs from the one actually encountered.</summary>
+public sealed class TypeMismatchException(uint expected, uint actual)
+    : ForyException($"Type mismatch: expected {expected}, got {actual}");
+
+/// <summary>A type was used without being registered; the message is prefixed with "Type not registered: ".</summary>
+public sealed class TypeNotRegisteredException(string message)
+    : ForyException($"Type not registered: {message}");
+
+/// <summary>Reference-handling failure; the message is prefixed with "Reference error: ".</summary>
+public sealed class RefException(string message)
+    : ForyException($"Reference error: {message}");
+
+/// <summary>Encoding failure; the message is prefixed with "Encoding error: ".</summary>
+public sealed class EncodingException(string message)
+    : ForyException($"Encoding error: {message}");
+
+/// <summary>A buffer access of <c>need</c> bytes at <c>cursor</c> does not fit in <c>length</c>.</summary>
+public sealed class OutOfBoundsException(int cursor, int need, int length)
+    : ForyException($"Buffer out of bounds: cursor={cursor}, need={need}, length={length}");
+
diff --git a/csharp/src/Fory/ForyFlags.cs b/csharp/src/Fory/ForyFlags.cs
new file mode 100644
index 0000000000..d5b1819579
--- /dev/null
+++ b/csharp/src/Fory/ForyFlags.cs
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+namespace Apache.Fory;
+
+/// <summary>One-byte marker values, stored as signed bytes.</summary>
+public enum RefFlag : sbyte
+{
+    Null = -3,
+    Ref = -2,
+    NotNullValue = -1,
+    RefValue = 0,
+}
+
+/// <summary>The three modes selectable from a field's nullable / track-ref settings.</summary>
+public enum RefMode : byte
+{
+    None = 0,
+    NullOnly = 1,
+    Tracking = 2,
+}
+
+internal static class RefModeExtensions
+{
+    /// <summary>
+    /// Maps the two flags to a mode: tracking wins over nullability; otherwise
+    /// nullable selects <see cref="RefMode.NullOnly"/> and non-nullable <see cref="RefMode.None"/>.
+    /// </summary>
+    public static RefMode From(bool nullable, bool trackRef) =>
+        (trackRef, nullable) switch
+        {
+            (true, _) => RefMode.Tracking,
+            (false, true) => RefMode.NullOnly,
+            (false, false) => RefMode.None,
+        };
+}
+
+/// <summary>Bit masks of the leading header byte.</summary>
+public static class ForyHeaderFlag
+{
+    public const byte IsNull = 1 << 0;
+    public const byte IsXlang = 1 << 1;
+    public const byte IsOutOfBand = 1 << 2;
+}
+
diff --git a/csharp/src/Fory/MetaString.cs b/csharp/src/Fory/MetaString.cs
new file mode 100644
index 0000000000..f67152e2d2
--- /dev/null
+++ b/csharp/src/Fory/MetaString.cs
@@ -0,0 +1,533 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System.Text;
+
+namespace Apache.Fory;
+
+/// <summary>Encodings a meta string can be stored in.</summary>
+public enum MetaStringEncoding : byte
+{
+    Utf8 = 0,
+    LowerSpecial = 1,
+    LowerUpperDigitSpecial = 2,
+    FirstToLowerSpecial = 3,
+    AllToLowerSpecial = 4,
+}
+
+/// <summary>
+/// A string together with its encoded bytes, the encoding used, and the two special
+/// characters of the encoder/decoder that produced it.
+/// </summary>
+public readonly struct MetaString : IEquatable<MetaString>
+{
+    private const int MaxMetaStringLength = 32_767;
+
+    /// <exception cref="EncodingException">
+    /// <paramref name="value"/> is too long, or a non-UTF-8 encoding was given empty bytes.
+    /// </exception>
+    public MetaString(
+        string value,
+        MetaStringEncoding encoding,
+        char specialChar1,
+        char specialChar2,
+        byte[] bytes)
+    {
+        if (value.Length >= MaxMetaStringLength)
+        {
+            throw new EncodingException("meta string too long");
+        }
+
+        if (encoding != MetaStringEncoding.Utf8 && bytes.Length == 0)
+        {
+            throw new EncodingException("encoded meta string cannot be empty");
+        }
+
+        Value = value;
+        Encoding = encoding;
+        SpecialChar1 = specialChar1;
+        SpecialChar2 = specialChar2;
+        Bytes = bytes;
+        // For packed encodings the top bit of the first byte is the strip-last-char flag;
+        // bytes[0] is safe here because empty packed bytes were rejected above.
+        StripLastChar = encoding != MetaStringEncoding.Utf8 && (bytes[0] & 0x80) != 0;
+    }
+
+    public string Value { get; }
+
+    public MetaStringEncoding Encoding { get; }
+
+    public char SpecialChar1 { get; }
+
+    public char SpecialChar2 { get; }
+
+    public byte[] Bytes { get; }
+
+    public bool StripLastChar { get; }
+
+    /// <summary>Returns the empty meta string (UTF-8, zero bytes) for the given special characters.</summary>
+    public static MetaString Empty(char specialChar1, char specialChar2)
+    {
+        return new MetaString(string.Empty, MetaStringEncoding.Utf8, specialChar1, specialChar2, []);
+    }
+
+    /// <summary>Value equality over the string, encoding, special characters and byte content.</summary>
+    public bool Equals(MetaString other)
+    {
+        return Value == other.Value &&
+               Encoding == other.Encoding &&
+               SpecialChar1 == other.SpecialChar1 &&
+               SpecialChar2 == other.SpecialChar2 &&
+               Bytes.AsSpan().SequenceEqual(other.Bytes);
+    }
+
+    public override bool Equals(object? obj)
+    {
+        return obj is MetaString other && Equals(other);
+    }
+
+    public override int GetHashCode()
+    {
+        HashCode hc = new();
+        hc.Add(Value);
+        hc.Add(Encoding);
+        hc.Add(SpecialChar1);
+        hc.Add(SpecialChar2);
+        // AsSpan() yields an empty span for a null array, so default(MetaString)
+        // (whose Bytes is null) hashes instead of throwing, matching Equals above.
+        foreach (byte b in Bytes.AsSpan())
+        {
+            hc.Add(b);
+        }
+
+        return hc.ToHashCode();
+    }
+}
+
+/// <summary>
+/// Encodes strings into <see cref="MetaString"/> values, either with an explicit encoding or by
+/// choosing one from the characters present in the input.
+/// </summary>
+public sealed class MetaStringEncoder
+{
+    private const int MaxMetaStringLength = 32_767;
+
+    public MetaStringEncoder(char specialChar1, char specialChar2)
+    {
+        SpecialChar1 = specialChar1;
+        SpecialChar2 = specialChar2;
+    }
+
+    public char SpecialChar1 { get; }
+
+    public char SpecialChar2 { get; }
+
+    public static MetaStringEncoder Namespace { get; } = new('.', '_');
+
+    public static MetaStringEncoder TypeName { get; } = new('$', '_');
+
+    public static MetaStringEncoder FieldName { get; } = new('$', '_');
+
+    /// <summary>Encodes <paramref name="input"/> with an automatically chosen encoding.</summary>
+    public MetaString Encode(string input)
+    {
+        return EncodeAuto(input, null);
+    }
+
+    /// <summary>Encodes <paramref name="input"/>, choosing only among <paramref name="allowedEncodings"/> (UTF-8 is the fallback).</summary>
+    public MetaString Encode(string input, IReadOnlyList<MetaStringEncoding> allowedEncodings)
+    {
+        return EncodeAuto(input, allowedEncodings);
+    }
+
+    /// <summary>Encodes <paramref name="input"/> with the given <paramref name="encoding"/>.</summary>
+    /// <exception cref="EncodingException">
+    /// The input is too long, contains a character the encoding cannot represent, or the encoding is unknown.
+    /// </exception>
+    public MetaString Encode(string input, MetaStringEncoding encoding)
+    {
+        if (input.Length >= MaxMetaStringLength)
+        {
+            throw new EncodingException("meta string too long");
+        }
+
+        if (input.Length == 0)
+        {
+            return MetaString.Empty(SpecialChar1, SpecialChar2);
+        }
+
+        if (encoding != MetaStringEncoding.Utf8 && !IsLatin(input))
+        {
+            throw new EncodingException("non-ASCII characters are not allowed for packed meta string");
+        }
+
+        return encoding switch
+        {
+            MetaStringEncoding.Utf8 => new MetaString(
+                input,
+                MetaStringEncoding.Utf8,
+                SpecialChar1,
+                SpecialChar2,
+                Encoding.UTF8.GetBytes(input)),
+            MetaStringEncoding.LowerSpecial => new MetaString(
+                input,
+                MetaStringEncoding.LowerSpecial,
+                SpecialChar1,
+                SpecialChar2,
+                EncodeGeneric(input, 5, MapLowerSpecial)),
+            MetaStringEncoding.LowerUpperDigitSpecial => new MetaString(
+                input,
+                MetaStringEncoding.LowerUpperDigitSpecial,
+                SpecialChar1,
+                SpecialChar2,
+                EncodeGeneric(input, 6, MapLowerUpperDigitSpecial)),
+            MetaStringEncoding.FirstToLowerSpecial => new MetaString(
+                input,
+                MetaStringEncoding.FirstToLowerSpecial,
+                SpecialChar1,
+                SpecialChar2,
+                EncodeGeneric(LowerFirstAscii(input), 5, MapLowerSpecial)),
+            MetaStringEncoding.AllToLowerSpecial => new MetaString(
+                input,
+                MetaStringEncoding.AllToLowerSpecial,
+                SpecialChar1,
+                SpecialChar2,
+                EncodeGeneric(EscapeAllUpper(input), 5, MapLowerSpecial)),
+            _ => throw new EncodingException($"unsupported meta string encoding: {encoding}"),
+        };
+    }
+
+    // Shared implementation of the two automatic Encode overloads; a null list allows every encoding.
+    private MetaString EncodeAuto(string input, IReadOnlyList<MetaStringEncoding>? allowedEncodings)
+    {
+        if (input.Length >= MaxMetaStringLength)
+        {
+            throw new EncodingException("meta string too long");
+        }
+
+        if (input.Length == 0)
+        {
+            return MetaString.Empty(SpecialChar1, SpecialChar2);
+        }
+
+        if (!IsLatin(input))
+        {
+            return new MetaString(input, MetaStringEncoding.Utf8, SpecialChar1, SpecialChar2, Encoding.UTF8.GetBytes(input));
+        }
+
+        MetaStringEncoding encoding = ChooseEncoding(input, allowedEncodings);
+        return Encode(input, encoding);
+    }
+
+    // Picks an allowed encoding from the character classes present in the input (5-bit forms
+    // are tried before the 6-bit form); falls back to UTF-8.
+    private MetaStringEncoding ChooseEncoding(string input, IReadOnlyList<MetaStringEncoding>? allowedEncodings)
+    {
+        bool Allow(MetaStringEncoding encoding)
+        {
+            return allowedEncodings is null || allowedEncodings.Contains(encoding);
+        }
+
+        int digitCount = 0;
+        int upperCount = 0;
+        bool canLowerSpecial = true;
+        bool canLowerUpperDigitSpecial = true;
+
+        foreach (char c in input)
+        {
+            if (canLowerSpecial)
+            {
+                bool isValid = c is >= 'a' and <= 'z' || c is '.' or '_' or '$' or '|';
+                if (!isValid)
+                {
+                    canLowerSpecial = false;
+                }
+            }
+
+            if (canLowerUpperDigitSpecial)
+            {
+                bool isLower = c is >= 'a' and <= 'z';
+                bool isUpper = c is >= 'A' and <= 'Z';
+                bool isDigit = c is >= '0' and <= '9';
+                bool isSpecial = c == SpecialChar1 || c == SpecialChar2;
+                if (!(isLower || isUpper || isDigit || isSpecial))
+                {
+                    canLowerUpperDigitSpecial = false;
+                }
+            }
+
+            if (c is >= '0' and <= '9')
+            {
+                digitCount++;
+            }
+
+            if (c is >= 'A' and <= 'Z')
+            {
+                upperCount++;
+            }
+        }
+
+        if (canLowerSpecial && Allow(MetaStringEncoding.LowerSpecial))
+        {
+            return MetaStringEncoding.LowerSpecial;
+        }
+
+        if (canLowerUpperDigitSpecial)
+        {
+            if (digitCount != 0 && Allow(MetaStringEncoding.LowerUpperDigitSpecial))
+            {
+                return MetaStringEncoding.LowerUpperDigitSpecial;
+            }
+
+            if (upperCount == 1 &&
+                char.IsUpper(input[0]) &&
+                Allow(MetaStringEncoding.FirstToLowerSpecial))
+            {
+                return MetaStringEncoding.FirstToLowerSpecial;
+            }
+
+            // Escaping costs one extra 5-bit symbol per upper-case letter; use it only while
+            // that is still smaller than 6 bits per character.
+            if ((input.Length + upperCount) * 5 < input.Length * 6 && Allow(MetaStringEncoding.AllToLowerSpecial))
+            {
+                return MetaStringEncoding.AllToLowerSpecial;
+            }
+
+            if (Allow(MetaStringEncoding.LowerUpperDigitSpecial))
+            {
+                return MetaStringEncoding.LowerUpperDigitSpecial;
+            }
+        }
+
+        return MetaStringEncoding.Utf8;
+    }
+
+    // Packs each character's code MSB-first into a bit stream starting at bit 1; bit 0 of the
+    // first byte is set when the padding is wide enough to hold one more (absent) character.
+    private byte[] EncodeGeneric(string input, int bitsPerChar, Func<char, byte> mapper)
+    {
+        int totalBits = input.Length * bitsPerChar + 1;
+        int byteLength = (totalBits + 7) / 8;
+        byte[] bytes = new byte[byteLength];
+        int currentBit = 1;
+
+        foreach (char c in input)
+        {
+            byte value = mapper(c);
+            for (int i = bitsPerChar - 1; i >= 0; i--)
+            {
+                if (((value >> i) & 0x01) != 0)
+                {
+                    int bytePos = currentBit / 8;
+                    int bitPos = currentBit % 8;
+                    bytes[bytePos] |= (byte)(1 << (7 - bitPos));
+                }
+
+                currentBit++;
+            }
+        }
+
+        if (byteLength * 8 >= totalBits + bitsPerChar)
+        {
+            bytes[0] |= 0x80;
+        }
+
+        return bytes;
+    }
+
+    private static byte MapLowerSpecial(char c)
+    {
+        if (c is >= 'a' and <= 'z')
+        {
+            return (byte)(c - 'a');
+        }
+
+        return c switch
+        {
+            '.' => 26,
+            '_' => 27,
+            '$' => 28,
+            '|' => 29,
+            _ => throw new EncodingException("unsupported character in LOWER_SPECIAL"),
+        };
+    }
+
+    private byte MapLowerUpperDigitSpecial(char c)
+    {
+        if (c is >= 'a' and <= 'z')
+        {
+            return (byte)(c - 'a');
+        }
+
+        if (c is >= 'A' and <= 'Z')
+        {
+            return (byte)(26 + c - 'A');
+        }
+
+        if (c is >= '0' and <= '9')
+        {
+            return (byte)(52 + c - '0');
+        }
+
+        if (c == SpecialChar1)
+        {
+            return 62;
+        }
+
+        if (c == SpecialChar2)
+        {
+            return 63;
+        }
+
+        throw new EncodingException("unsupported character in LOWER_UPPER_DIGIT_SPECIAL");
+    }
+
+    private static string LowerFirstAscii(string input)
+    {
+        if (input.Length == 0)
+        {
+            return input;
+        }
+
+        return char.ToLowerInvariant(input[0]) + input[1..];
+    }
+
+    // Rewrites every upper-case letter as '|' followed by its lower-case form.
+    private static string EscapeAllUpper(string input)
+    {
+        StringBuilder sb = new(input.Length * 2);
+        foreach (char c in input)
+        {
+            if (char.IsUpper(c))
+            {
+                sb.Append('|');
+                sb.Append(char.ToLowerInvariant(c));
+            }
+            else
+            {
+                sb.Append(c);
+            }
+        }
+
+        return sb.ToString();
+    }
+
+    // True when every character is at most U+00FF.
+    private static bool IsLatin(string input)
+    {
+        foreach (char c in input)
+        {
+            if (c > 255)
+            {
+                return false;
+            }
+        }
+
+        return true;
+    }
+}
+
+/// <summary>Decodes bytes produced by <see cref="MetaStringEncoder"/> back into <see cref="MetaString"/> values.</summary>
+public sealed class MetaStringDecoder
+{
+    public MetaStringDecoder(char specialChar1, char specialChar2)
+    {
+        SpecialChar1 = specialChar1;
+        SpecialChar2 = specialChar2;
+    }
+
+    public char SpecialChar1 { get; }
+
+    public char SpecialChar2 { get; }
+
+    public static MetaStringDecoder Namespace { get; } = new('.', '_');
+
+    public static MetaStringDecoder TypeName { get; } = new('$', '_');
+
+    public static MetaStringDecoder FieldName { get; } = new('$', '_');
+
+    /// <summary>Decodes <paramref name="bytes"/> using <paramref name="encoding"/>.</summary>
+    /// <exception cref="EncodingException">The encoding is unknown or the bytes contain an invalid symbol.</exception>
+    public MetaString Decode(byte[] bytes, MetaStringEncoding encoding)
+    {
+        string value = encoding switch
+        {
+            MetaStringEncoding.Utf8 => Encoding.UTF8.GetString(bytes),
+            MetaStringEncoding.LowerSpecial => DecodeGeneric(bytes, 5, UnmapLowerSpecial),
+            MetaStringEncoding.LowerUpperDigitSpecial => DecodeGeneric(bytes, 6, UnmapLowerUpperDigitSpecial),
+            MetaStringEncoding.FirstToLowerSpecial =>
+                DecodeFirstToLowerSpecial(bytes),
+            MetaStringEncoding.AllToLowerSpecial =>
+                UnescapeAllUpper(DecodeGeneric(bytes, 5, UnmapLowerSpecial)),
+            _ => throw new EncodingException($"unsupported meta string encoding: {encoding}"),
+        };
+
+        return new MetaString(value, encoding, SpecialChar1, SpecialChar2, bytes);
+    }
+
+    private string DecodeFirstToLowerSpecial(byte[] bytes)
+    {
+        string decoded = DecodeGeneric(bytes, 5, UnmapLowerSpecial);
+        if (decoded.Length == 0)
+        {
+            return decoded;
+        }
+
+        return char.ToUpperInvariant(decoded[0]) + decoded[1..];
+    }
+
+    // Reads bitsPerChar-wide symbols MSB-first starting at bit 1; when bit 0 of the first byte
+    // is set, the final symbol slot is padding and is not decoded.
+    private string DecodeGeneric(byte[] bytes, int bitsPerChar, Func<byte, char> mapper)
+    {
+        if (bytes.Length == 0)
+        {
+            return string.Empty;
+        }
+
+        bool stripLast = (bytes[0] & 0x80) != 0;
+        int totalBits = bytes.Length * 8;
+        int bitIndex = 1;
+        StringBuilder sb = new(bytes.Length);
+        while (bitIndex + bitsPerChar <= totalBits &&
+               !(stripLast && (bitIndex + 2 * bitsPerChar > totalBits)))
+        {
+            byte value = 0;
+            for (var i = 0; i < bitsPerChar; i++)
+            {
+                int byteIndex = bitIndex / 8;
+                int intra = bitIndex % 8;
+                byte bit = (byte)((bytes[byteIndex] >> (7 - intra)) & 0x01);
+                value = (byte)((value << 1) | bit);
+                bitIndex++;
+            }
+
+            sb.Append(mapper(value));
+        }
+
+        return sb.ToString();
+    }
+
+    private static char UnmapLowerSpecial(byte value)
+    {
+        return value switch
+        {
+            <= 25 => (char)('a' + value),
+            26 => '.',
+            27 => '_',
+            28 => '$',
+            29 => '|',
+            _ => throw new EncodingException("invalid LOWER_SPECIAL value"),
+        };
+    }
+
+    private char UnmapLowerUpperDigitSpecial(byte value)
+    {
+        return value switch
+        {
+            <= 25 => (char)('a' + value),
+            <= 51 => (char)('A' + value - 26),
+            <= 61 => (char)('0' + value - 52),
+            62 => SpecialChar1,
+            63 => SpecialChar2,
+            _ => throw new EncodingException("invalid LOWER_UPPER_DIGIT_SPECIAL value"),
+        };
+    }
+
+    // Inverse of the '|' escaping: "|x" becomes 'X'; a trailing lone '|' is kept as-is.
+    private static string UnescapeAllUpper(string input)
+    {
+        StringBuilder sb = new(input.Length);
+        for (int i = 0; i < input.Length; i++)
+        {
+            char c = input[i];
+            if (c == '|' && i + 1 < input.Length)
+            {
+                i++;
+                sb.Append(char.ToUpperInvariant(input[i]));
+            }
+            else
+            {
+                sb.Append(c);
+            }
+        }
+
+        return sb.ToString();
+    }
+}
+
diff --git a/csharp/src/Fory/MurmurHash3.cs b/csharp/src/Fory/MurmurHash3.cs
new file mode 100644
index 0000000000..d36c7f0a2d
--- /dev/null
+++ b/csharp/src/Fory/MurmurHash3.cs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
/// <summary>
/// MurmurHash3, x64 128-bit variant. Input is consumed as little-endian 64-bit
/// words, 16 bytes per block, followed by a 0..15 byte tail.
/// </summary>
public static class MurmurHash3
{
    private const ulong C1 = 0x87c37b91114253d5;
    private const ulong C2 = 0x4cf5ad432745937f;

    /// <summary>
    /// Hashes <paramref name="bytes"/> and returns both 64-bit halves of the result.
    /// </summary>
    /// <param name="bytes">Data to hash; may be empty.</param>
    /// <param name="seed">Initial value of both state words (defaults to 47).</param>
    /// <returns>The two 64-bit halves of the 128-bit hash.</returns>
    public static (ulong H1, ulong H2) X64_128(ReadOnlySpan<byte> bytes, ulong seed = 47)
    {
        // Every multiply/add below relies on 64-bit wraparound. Stating that
        // explicitly keeps the hash working when the project is compiled with
        // overflow checking turned on.
        unchecked
        {
            ulong h1 = seed;
            ulong h2 = seed;

            int length = bytes.Length;
            int bodyLength = length & ~15; // largest multiple of 16 that fits

            for (int offset = 0; offset < bodyLength; offset += 16)
            {
                ulong k1 = BinaryPrimitives.ReadUInt64LittleEndian(bytes.Slice(offset, 8));
                ulong k2 = BinaryPrimitives.ReadUInt64LittleEndian(bytes.Slice(offset + 8, 8));

                h1 ^= MixK1(k1);
                h1 = RotateLeft(h1, 27);
                h1 += h2;
                h1 = h1 * 5 + 0x52dce729;

                h2 ^= MixK2(k2);
                h2 = RotateLeft(h2, 31);
                h2 += h1;
                h2 = h2 * 5 + 0x38495ab5;
            }

            ReadOnlySpan<byte> tail = bytes.Slice(bodyLength);

            // Tail bytes 8..14 form the second lane (only present when more than 8 remain).
            if (tail.Length > 8)
            {
                ulong k2 = 0;
                for (int i = tail.Length - 1; i >= 8; i--)
                {
                    k2 ^= (ulong)tail[i] << ((i - 8) * 8);
                }

                h2 ^= MixK2(k2);
            }

            // Tail bytes 0..7 form the first lane.
            if (tail.Length > 0)
            {
                ulong k1 = 0;
                for (int i = Math.Min(tail.Length, 8) - 1; i >= 0; i--)
                {
                    k1 ^= (ulong)tail[i] << (i * 8);
                }

                h1 ^= MixK1(k1);
            }

            // Finalization: fold in the length, cross-mix, avalanche, cross-mix again.
            h1 ^= (ulong)length;
            h2 ^= (ulong)length;
            h1 += h2;
            h2 += h1;
            h1 = Fmix64(h1);
            h2 = Fmix64(h2);
            h1 += h2;
            h2 += h1;
            return (h1, h2);
        }
    }

    // Scrambles a word destined for the first state half.
    private static ulong MixK1(ulong k)
    {
        unchecked
        {
            k *= C1;
            k = RotateLeft(k, 31);
            k *= C2;
            return k;
        }
    }

    // Scrambles a word destined for the second state half.
    private static ulong MixK2(ulong k)
    {
        unchecked
        {
            k *= C2;
            k = RotateLeft(k, 33);
            k *= C1;
            return k;
        }
    }

    // Bitwise left rotation of a 64-bit value; delegates to the framework primitive.
    private static ulong RotateLeft(ulong x, int r)
    {
        return System.Numerics.BitOperations.RotateLeft(x, r);
    }

    // 64-bit avalanche step applied to each half during finalization.
    private static ulong Fmix64(ulong x)
    {
        unchecked
        {
            x ^= x >> 33;
            x *= 0xff51afd7ed558ccd;
            x ^= x >> 33;
            x *= 0xc4ceb9fe1a85ec53;
            x ^= x >> 33;
            return x;
        }
    }
}
comparer) + : this(0, comparer) + { + } + + public NullableKeyDictionary(int capacity, IEqualityComparer? comparer) + { + _nonNullEntries = comparer is null + ? new Dictionary(capacity) + : new Dictionary(capacity, comparer); + } + + public NullableKeyDictionary(IDictionary dictionary) + : this(dictionary, null) + { + } + + public NullableKeyDictionary(IDictionary dictionary, IEqualityComparer? comparer) + : this(dictionary?.Count ?? 0, comparer) + { + ArgumentNullException.ThrowIfNull(dictionary); + foreach (KeyValuePair entry in dictionary) + { + this[entry.Key] = entry.Value; + } + } + + public int Count => _nonNullEntries.Count + (_hasNullKey ? 1 : 0); + + public bool HasNullKey => _hasNullKey; + + public TValue NullKeyValue => _nullValue; + + public IEqualityComparer Comparer => _nonNullEntries.Comparer; + + public IEnumerable> NonNullEntries => _nonNullEntries; + + public ICollection Keys => _keys ??= new KeyCollection(this); + + IEnumerable IReadOnlyDictionary.Keys => Keys; + + public ICollection Values => _values ??= new ValueCollection(this); + + IEnumerable IReadOnlyDictionary.Values => Values; + + public TValue this[TKey key] + { + get + { + if (TryGetValue(key, out TValue value)) + { + return value; + } + + throw new KeyNotFoundException(); + } + set => SetValue(key, value); + } + + public bool IsReadOnly => false; + + public void Add(TKey key, TValue value) + { + if (key is null) + { + if (_hasNullKey) + { + throw new ArgumentException("An item with the same key has already been added.", nameof(key)); + } + + SetNullKeyValue(value); + return; + } + + _nonNullEntries.Add(key, value); + } + + public bool ContainsKey(TKey key) + { + if (key is null) + { + return _hasNullKey; + } + + return _nonNullEntries.ContainsKey(key); + } + + public bool Remove(TKey key) + { + if (key is null) + { + if (!_hasNullKey) + { + return false; + } + + _hasNullKey = false; + _nullValue = default!; + return true; + } + + return _nonNullEntries.Remove(key); + } + + public bool 
TryGetValue(TKey key, out TValue value) + { + if (key is null) + { + if (_hasNullKey) + { + value = _nullValue; + return true; + } + + value = default!; + return false; + } + + return _nonNullEntries.TryGetValue(key, out value!); + } + + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + public bool Contains(KeyValuePair item) + { + return TryGetValue(item.Key, out TValue value) && + EqualityComparer.Default.Equals(value, item.Value); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + if (arrayIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + if (array.Length - arrayIndex < Count) + { + throw new ArgumentException("The destination array is too small.", nameof(array)); + } + + if (_hasNullKey) + { + array[arrayIndex++] = new KeyValuePair(default!, _nullValue); + } + + foreach (KeyValuePair entry in _nonNullEntries) + { + array[arrayIndex++] = entry; + } + } + + public bool Remove(KeyValuePair item) + { + if (!Contains(item)) + { + return false; + } + + return Remove(item.Key); + } + + public void Clear() + { + _nonNullEntries.Clear(); + _hasNullKey = false; + _nullValue = default!; + } + + internal void SetNullKeyValue(TValue value) + { + _hasNullKey = true; + _nullValue = value; + } + + private void SetValue(TKey key, TValue value) + { + if (key is null) + { + SetNullKeyValue(value); + return; + } + + _nonNullEntries[key] = value; + } + + public IEnumerator> GetEnumerator() + { + if (_hasNullKey) + { + yield return new KeyValuePair(default!, _nullValue); + } + + foreach (KeyValuePair entry in _nonNullEntries) + { + yield return entry; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + private sealed class KeyCollection(NullableKeyDictionary map) : ICollection + { + private readonly NullableKeyDictionary _map = map; + + public int Count => _map.Count; + + public bool IsReadOnly => true; + + public void 
Add(TKey item) + { + throw new NotSupportedException("Collection is read-only."); + } + + public void Clear() + { + throw new NotSupportedException("Collection is read-only."); + } + + public bool Contains(TKey item) + { + return _map.ContainsKey(item); + } + + public void CopyTo(TKey[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + if (arrayIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + if (array.Length - arrayIndex < Count) + { + throw new ArgumentException("The destination array is too small.", nameof(array)); + } + + if (_map._hasNullKey) + { + array[arrayIndex++] = default!; + } + + _map._nonNullEntries.Keys.CopyTo(array, arrayIndex); + } + + public bool Remove(TKey item) + { + throw new NotSupportedException("Collection is read-only."); + } + + public IEnumerator GetEnumerator() + { + if (_map._hasNullKey) + { + yield return default!; + } + + foreach (TKey key in _map._nonNullEntries.Keys) + { + yield return key; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + private sealed class ValueCollection(NullableKeyDictionary map) : ICollection + { + private readonly NullableKeyDictionary _map = map; + + public int Count => _map.Count; + + public bool IsReadOnly => true; + + public void Add(TValue item) + { + throw new NotSupportedException("Collection is read-only."); + } + + public void Clear() + { + throw new NotSupportedException("Collection is read-only."); + } + + public bool Contains(TValue item) + { + if (_map._hasNullKey && EqualityComparer.Default.Equals(_map._nullValue, item)) + { + return true; + } + + return _map._nonNullEntries.Values.Contains(item); + } + + public void CopyTo(TValue[] array, int arrayIndex) + { + ArgumentNullException.ThrowIfNull(array); + if (arrayIndex < 0) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + if (array.Length - arrayIndex < Count) + { + throw new ArgumentException("The destination array is 
// Serializer for NullableKeyDictionary, reported on the wire as TypeId.Map.
// Layout written by WriteData: a varuint32 entry count, then for every entry a
// one-byte header of DictionaryBits flags followed by whatever payload the
// header implies (nothing, value only, key only, or a chunk of size 1 holding
// key and value).
//
// NOTE(review): generic type arguments appear to have been stripped from this
// listing (e.g. `Serializer>`, `GetSerializer()`); the code tokens below are
// left exactly as found - confirm against the original file.
public sealed class NullableKeyDictionarySerializer : Serializer>
{
    public override TypeId StaticTypeId => TypeId.Map;
    public override bool IsNullableType => true;
    public override bool IsReferenceTrackableType => true;
    public override NullableKeyDictionary DefaultValue => null!;
    public override bool IsNone(in NullableKeyDictionary value) => value is null;

    // Writes the entry count followed by one header (+ payload) per entry.
    // A null `value` is written as an empty map.
    public override void WriteData(ref WriteContext context, in NullableKeyDictionary value, bool hasGenerics)
    {
        Serializer keySerializer = context.TypeResolver.GetSerializer();
        Serializer valueSerializer = context.TypeResolver.GetSerializer();
        NullableKeyDictionary map = value ?? new NullableKeyDictionary();
        context.Writer.WriteVarUInt32((uint)map.Count);
        if (map.Count == 0)
        {
            return;
        }

        // Per-map settings, computed once and reused for every entry header.
        bool trackKeyRef = context.TrackRef && keySerializer.IsReferenceTrackableType;
        bool trackValueRef = context.TrackRef && valueSerializer.IsReferenceTrackableType;
        bool keyDeclared = hasGenerics && !keySerializer.StaticTypeId.NeedsTypeInfoForField();
        bool valueDeclared = hasGenerics && !valueSerializer.StaticTypeId.NeedsTypeInfoForField();
        bool keyDynamicType = keySerializer.StaticTypeId == TypeId.Unknown;
        bool valueDynamicType = valueSerializer.StaticTypeId == TypeId.Unknown;
        // Materialize the entries into an array before writing.
        KeyValuePair[] pairs = [.. map];
        if (keyDynamicType || valueDynamicType)
        {
            // At least one side has TypeId.Unknown: use the per-entry dynamic writer.
            WriteDynamicMapPairs(
                pairs,
                ref context,
                hasGenerics,
                trackKeyRef,
                trackValueRef,
                keyDeclared,
                valueDeclared,
                keyDynamicType,
                valueDynamicType,
                keySerializer,
                valueSerializer);
            return;
        }

        foreach (KeyValuePair entry in pairs)
        {
            bool keyIsNull = entry.Key is null || keySerializer.IsNoneObject(entry.Key);
            bool valueIsNull = valueSerializer.IsNoneObject(entry.Value);
            byte header = 0;
            if (trackKeyRef)
            {
                header |= DictionaryBits.TrackingKeyRef;
            }

            if (trackValueRef)
            {
                header |= DictionaryBits.TrackingValueRef;
            }

            // The "declared type" bit is only set for a side that is not null.
            if (keyIsNull)
            {
                header |= DictionaryBits.KeyNull;
            }
            else if (keyDeclared)
            {
                header |= DictionaryBits.DeclaredKeyType;
            }

            if (valueIsNull)
            {
                header |= DictionaryBits.ValueNull;
            }
            else if (valueDeclared)
            {
                header |= DictionaryBits.DeclaredValueType;
            }

            context.Writer.WriteUInt8(header);
            if (keyIsNull && valueIsNull)
            {
                // Header alone describes the entry; no payload follows.
                continue;
            }

            if (keyIsNull)
            {
                // Null key: only the value is written (no chunk-size byte).
                // NOTE(review): type info is emitted here before Write() runs, while
                // ReadData passes `!valueDeclared` as readTypeInfo into Read(); confirm
                // both sides agree on the order of ref flag and type info when
                // trackValueRef is on.
                if (!valueDeclared)
                {
                    valueSerializer.WriteTypeInfo(ref context);
                }

                valueSerializer.Write(
                    ref context,
                    entry.Value,
                    trackValueRef ? RefMode.Tracking : RefMode.None,
                    false,
                    hasGenerics);
                continue;
            }

            if (valueIsNull)
            {
                // Null value: only the key is written (no chunk-size byte).
                if (!keyDeclared)
                {
                    keySerializer.WriteTypeInfo(ref context);
                }

                keySerializer.Write(
                    ref context,
                    entry.Key!,
                    trackKeyRef ? RefMode.Tracking : RefMode.None,
                    false,
                    hasGenerics);
                continue;
            }

            // Regular entry: written as a chunk of exactly one key/value pair.
            context.Writer.WriteUInt8(1);
            if (!keyDeclared)
            {
                keySerializer.WriteTypeInfo(ref context);
            }

            if (!valueDeclared)
            {
                valueSerializer.WriteTypeInfo(ref context);
            }

            keySerializer.Write(
                ref context,
                entry.Key!,
                trackKeyRef ? RefMode.Tracking : RefMode.None,
                false,
                hasGenerics);
            valueSerializer.Write(
                ref context,
                entry.Value,
                trackValueRef ? RefMode.Tracking : RefMode.None,
                false,
                hasGenerics);
        }
    }

    // Reads the entry count, then consumes headers/chunks until that many
    // entries have been read. Unlike the writer, chunks of any size are accepted.
    public override NullableKeyDictionary ReadData(ref ReadContext context)
    {
        Serializer keySerializer = context.TypeResolver.GetSerializer();
        Serializer valueSerializer = context.TypeResolver.GetSerializer();
        // checked: a count above int.MaxValue raises OverflowException.
        int totalLength = checked((int)context.Reader.ReadVarUInt32());
        if (totalLength == 0)
        {
            return new NullableKeyDictionary();
        }

        NullableKeyDictionary map = new();
        bool keyDynamicType = keySerializer.StaticTypeId == TypeId.Unknown;
        bool valueDynamicType = valueSerializer.StaticTypeId == TypeId.Unknown;
        bool canonicalizeValues = context.TrackRef && valueSerializer.IsReferenceTrackableType;

        int readCount = 0;
        while (readCount < totalLength)
        {
            // Flags come from the stream, not from this reader's own configuration.
            byte header = context.Reader.ReadUInt8();
            bool trackKeyRef = (header & DictionaryBits.TrackingKeyRef) != 0;
            bool keyNull = (header & DictionaryBits.KeyNull) != 0;
            bool keyDeclared = (header & DictionaryBits.DeclaredKeyType) != 0;
            bool trackValueRef = (header & DictionaryBits.TrackingValueRef) != 0;
            bool valueNull = (header & DictionaryBits.ValueNull) != 0;
            bool valueDeclared = (header & DictionaryBits.DeclaredValueType) != 0;

            if (keyNull && valueNull)
            {
                // No payload: null key mapped to the value serializer's default object.
                map.SetNullKeyValue((TValue)valueSerializer.DefaultObject!);
                readCount += 1;
                continue;
            }

            if (keyNull)
            {
                // Value-only entry stored under the null key.
                TValue valueRead = ReadValueElement(
                    ref context,
                    trackValueRef,
                    !valueDeclared,
                    canonicalizeValues,
                    valueSerializer);

                map.SetNullKeyValue(valueRead);
                readCount += 1;
                continue;
            }

            if (valueNull)
            {
                // Key-only entry; the value becomes the value serializer's default object.
                TKey key = keySerializer.Read(
                    ref context,
                    trackKeyRef ? RefMode.Tracking : RefMode.None,
                    !keyDeclared);

                map[key] = (TValue)valueSerializer.DefaultObject!;
                readCount += 1;
                continue;
            }

            // NOTE(review): a chunk size of 0 adds nothing to readCount, so the loop
            // simply goes on to read the next header - confirm that is acceptable
            // for malformed input.
            int chunkSize = context.Reader.ReadUInt8();
            if (keyDynamicType || valueDynamicType)
            {
                // Dynamic path: type info is read again for every pair in the chunk.
                for (int i = 0; i < chunkSize; i++)
                {
                    DynamicTypeInfo? keyDynamicInfo = null;
                    DynamicTypeInfo? valueDynamicInfo = null;

                    if (!keyDeclared)
                    {
                        if (keyDynamicType)
                        {
                            keyDynamicInfo = context.TypeResolver.ReadDynamicTypeInfo(ref context);
                        }
                        else
                        {
                            keySerializer.ReadTypeInfo(ref context);
                        }
                    }

                    if (!valueDeclared)
                    {
                        if (valueDynamicType)
                        {
                            valueDynamicInfo = context.TypeResolver.ReadDynamicTypeInfo(ref context);
                        }
                        else
                        {
                            valueSerializer.ReadTypeInfo(ref context);
                        }
                    }

                    // The dynamic type info is registered on the context only for the
                    // duration of the corresponding Read call.
                    if (keyDynamicInfo is not null)
                    {
                        context.SetDynamicTypeInfo(typeof(TKey), keyDynamicInfo);
                    }

                    TKey key = keySerializer.Read(ref context, trackKeyRef ? RefMode.Tracking : RefMode.None, false);
                    if (keyDynamicInfo is not null)
                    {
                        context.ClearDynamicTypeInfo(typeof(TKey));
                    }

                    if (valueDynamicInfo is not null)
                    {
                        context.SetDynamicTypeInfo(typeof(TValue), valueDynamicInfo);
                    }

                    TValue valueRead = ReadValueElement(
                        ref context,
                        trackValueRef,
                        false,
                        canonicalizeValues,
                        valueSerializer);
                    if (valueDynamicInfo is not null)
                    {
                        context.ClearDynamicTypeInfo(typeof(TValue));
                    }

                    map[key] = valueRead;
                }

                readCount += chunkSize;
                continue;
            }

            // Static path: type info is read once per chunk, ahead of the pairs.
            if (!keyDeclared)
            {
                keySerializer.ReadTypeInfo(ref context);
            }

            if (!valueDeclared)
            {
                valueSerializer.ReadTypeInfo(ref context);
            }

            for (int i = 0; i < chunkSize; i++)
            {
                TKey key = keySerializer.Read(ref context, trackKeyRef ? RefMode.Tracking : RefMode.None, false);
                TValue valueRead = ReadValueElement(ref context, trackValueRef, false, canonicalizeValues, valueSerializer);
                map[key] = valueRead;
            }

            // Drop any dynamic type info left on the context for this chunk.
            if (!keyDeclared)
            {
                context.ClearDynamicTypeInfo(typeof(TKey));
            }

            if (!valueDeclared)
            {
                context.ClearDynamicTypeInfo(typeof(TValue));
            }

            readCount += chunkSize;
        }

        return map;
    }

    // Writer used when the key or value serializer reports TypeId.Unknown.
    // Same header/payload structure as the static path in WriteData, except that
    // the declared-type bits are never set for a dynamic side and per-entry type
    // info for a dynamic side comes from DynamicAnyCodec.WriteAnyTypeInfo.
    private static void WriteDynamicMapPairs(
        KeyValuePair[] pairs,
        ref WriteContext context,
        bool hasGenerics,
        bool trackKeyRef,
        bool trackValueRef,
        bool keyDeclared,
        bool valueDeclared,
        bool keyDynamicType,
        bool valueDynamicType,
        Serializer keySerializer,
        Serializer valueSerializer)
    {
        foreach (KeyValuePair pair in pairs)
        {
            bool keyIsNull = pair.Key is null || keySerializer.IsNoneObject(pair.Key);
            bool valueIsNull = valueSerializer.IsNoneObject(pair.Value);
            byte header = 0;
            if (trackKeyRef)
            {
                header |= DictionaryBits.TrackingKeyRef;
            }

            if (trackValueRef)
            {
                header |= DictionaryBits.TrackingValueRef;
            }

            if (keyIsNull)
            {
                header |= DictionaryBits.KeyNull;
            }
            else if (!keyDynamicType && keyDeclared)
            {
                header |= DictionaryBits.DeclaredKeyType;
            }

            if (valueIsNull)
            {
                header |= DictionaryBits.ValueNull;
            }
            else if (!valueDynamicType && valueDeclared)
            {
                header |= DictionaryBits.DeclaredValueType;
            }

            context.Writer.WriteUInt8(header);
            if (keyIsNull && valueIsNull)
            {
                continue;
            }

            if (keyIsNull)
            {
                // Here the serializer's own Write is asked to emit type info
                // (writeTypeInfo = !valueDeclared) instead of a separate call.
                valueSerializer.Write(
                    ref context,
                    pair.Value,
                    trackValueRef ? RefMode.Tracking : RefMode.None,
                    !valueDeclared,
                    hasGenerics);
                continue;
            }

            if (valueIsNull)
            {
                keySerializer.Write(
                    ref context,
                    pair.Key!,
                    trackKeyRef ? RefMode.Tracking : RefMode.None,
                    !keyDeclared,
                    hasGenerics);
                continue;
            }

            // Regular entry: chunk of size 1, type info for both sides, then key and value.
            context.Writer.WriteUInt8(1);
            if (!keyDeclared)
            {
                if (keyDynamicType)
                {
                    DynamicAnyCodec.WriteAnyTypeInfo(pair.Key!, ref context);
                }
                else
                {
                    keySerializer.WriteTypeInfo(ref context);
                }
            }

            if (!valueDeclared)
            {
                if (valueDynamicType)
                {
                    DynamicAnyCodec.WriteAnyTypeInfo(pair.Value!, ref context);
                }
                else
                {
                    valueSerializer.WriteTypeInfo(ref context);
                }
            }

            keySerializer.Write(
                ref context,
                pair.Key!,
                trackKeyRef ? RefMode.Tracking : RefMode.None,
                false,
                hasGenerics);
            valueSerializer.Write(
                ref context,
                pair.Value,
                trackValueRef ? RefMode.Tracking : RefMode.None,
                false,
                hasGenerics);
        }
    }

    // Reads one value. When the stream did not track value references but this
    // reader wants canonical values, the reader cursor is captured before and
    // after the read and handed to CanonicalizeNonTrackingReference.
    private static TValue ReadValueElement(
        ref ReadContext context,
        bool trackValueRef,
        bool readTypeInfo,
        bool canonicalizeValues,
        Serializer valueSerializer)
    {
        if (trackValueRef || !canonicalizeValues)
        {
            return valueSerializer.Read(ref context, trackValueRef ? RefMode.Tracking : RefMode.None, readTypeInfo);
        }

        int start = context.Reader.Cursor;
        TValue value = valueSerializer.Read(ref context, RefMode.None, readTypeInfo);
        int end = context.Reader.Cursor;
        return context.CanonicalizeNonTrackingReference(value, start, end);
    }
}
// Serializer for a nullable value type (T? where T : struct). All payload and
// type-info work is forwarded to the serializer of the wrapped type T; this
// class only adds the handling of the "no value" case.
//
// NOTE(review): generic type arguments appear to have been stripped from this
// listing (e.g. `Serializer where T : struct`, `IReadOnlyList`); the code
// tokens below are left exactly as found - confirm against the original file.
public sealed class NullableSerializer : Serializer where T : struct
{
    // Obtained from a freshly constructed TypeResolver; used only by the
    // context-free members below (StaticTypeId, IsReferenceTrackableType,
    // CompatibleTypeMetaFields). Members that receive a context resolve the
    // wrapped serializer through context.TypeResolver instead.
    private readonly Serializer _defaultWrappedSerializer = new TypeResolver().GetSerializer();

    public override TypeId StaticTypeId => _defaultWrappedSerializer.StaticTypeId;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => _defaultWrappedSerializer.IsReferenceTrackableType;

    public override T? DefaultValue => null;

    // "None" for a nullable is simply the absence of a value.
    public override bool IsNone(in T? value)
    {
        return !value.HasValue;
    }

    // Writes the raw payload of the wrapped value.
    // Throws InvalidDataException when there is no value to write.
    public override void WriteData(ref WriteContext context, in T? value, bool hasGenerics)
    {
        if (!value.HasValue)
        {
            throw new InvalidDataException("Nullable.None cannot write raw payload");
        }

        T wrapped = value.Value;
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        wrappedSerializer.WriteData(ref context, wrapped, hasGenerics);
    }

    // Reads the raw payload through the wrapped serializer.
    public override T? ReadData(ref ReadContext context)
    {
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        return wrappedSerializer.ReadData(ref context);
    }

    // Type info on the wire is that of the wrapped type.
    public override void WriteTypeInfo(ref WriteContext context)
    {
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        wrappedSerializer.WriteTypeInfo(ref context);
    }

    public override void ReadTypeInfo(ref ReadContext context)
    {
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        wrappedSerializer.ReadTypeInfo(ref context);
    }

    public override IReadOnlyList CompatibleTypeMetaFields(bool trackRef)
    {
        return _defaultWrappedSerializer.CompatibleTypeMetaFields(trackRef);
    }

    // Writes the value according to refMode:
    //   None     - value must be present; forwarded unchanged.
    //   NullOnly - one flag byte (Null or NotNullValue), then the value with RefMode.None.
    //   Tracking - Null flag byte when absent; otherwise the wrapped serializer
    //              writes with RefMode.Tracking and no flag is written here.
    // Throws InvalidDataException for a missing value under RefMode.None or for
    // an unrecognized ref mode.
    public override void Write(ref WriteContext context, in T? value, RefMode refMode, bool writeTypeInfo, bool hasGenerics)
    {
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        switch (refMode)
        {
            case RefMode.None:
                if (!value.HasValue)
                {
                    throw new InvalidDataException("Nullable.None with RefMode.None");
                }

                T wrapped = value.Value;
                wrappedSerializer.Write(ref context, wrapped, RefMode.None, writeTypeInfo, hasGenerics);
                break;
            case RefMode.NullOnly:
                if (!value.HasValue)
                {
                    context.Writer.WriteInt8((sbyte)RefFlag.Null);
                    return;
                }

                context.Writer.WriteInt8((sbyte)RefFlag.NotNullValue);
                wrappedSerializer.Write(ref context, value.Value, RefMode.None, writeTypeInfo, hasGenerics);
                break;
            case RefMode.Tracking:
                if (!value.HasValue)
                {
                    context.Writer.WriteInt8((sbyte)RefFlag.Null);
                    return;
                }

                wrappedSerializer.Write(ref context, value.Value, RefMode.Tracking, writeTypeInfo, hasGenerics);
                break;
            default:
                throw new InvalidDataException($"unsupported ref mode {refMode}");
        }
    }

    // Counterpart of Write. Returns null when the flag byte read is RefFlag.Null.
    public override T? Read(ref ReadContext context, RefMode refMode, bool readTypeInfo)
    {
        Serializer wrappedSerializer = context.TypeResolver.GetSerializer();
        switch (refMode)
        {
            case RefMode.None:
                return wrappedSerializer.Read(ref context, RefMode.None, readTypeInfo);
            case RefMode.NullOnly:
            {
                // The flag byte is consumed here; any non-Null flag means a value follows.
                sbyte refFlag = context.Reader.ReadInt8();
                if (refFlag == (sbyte)RefFlag.Null)
                {
                    return null;
                }

                return wrappedSerializer.Read(ref context, RefMode.None, readTypeInfo);
            }
            case RefMode.Tracking:
            {
                // Peek at the flag: if it is not Null, step back one byte so the
                // wrapped serializer reads the same flag itself.
                sbyte refFlag = context.Reader.ReadInt8();
                if (refFlag == (sbyte)RefFlag.Null)
                {
                    return null;
                }

                context.Reader.MoveBack(1);
                return wrappedSerializer.Read(ref context, RefMode.Tracking, readTypeInfo);
            }
            default:
                throw new InvalidDataException($"unsupported ref mode {refMode}");
        }
    }
}
/// <summary>
/// Serializer for <c>bool[]</c>: a varuint32 element count followed by one byte
/// per element (1 for true, 0 for false).
/// </summary>
internal sealed class BoolArraySerializer : Serializer<bool[]>
{
    public override TypeId StaticTypeId
    {
        get { return TypeId.BoolArray; }
    }

    public override bool IsNullableType
    {
        get { return true; }
    }

    public override bool IsReferenceTrackableType
    {
        get { return true; }
    }

    public override bool[] DefaultValue
    {
        get { return null!; }
    }

    public override bool IsNone(in bool[] value)
    {
        return value is null;
    }

    /// <summary>Writes the element count and then each flag as a single byte.
    /// A null array is written as an empty one.</summary>
    public override void WriteData(ref WriteContext context, in bool[] value, bool hasGenerics)
    {
        _ = hasGenerics;
        bool[] items = value is null ? Array.Empty<bool>() : value;
        context.Writer.WriteVarUInt32((uint)items.Length);
        foreach (bool flag in items)
        {
            byte encoded = flag ? (byte)1 : (byte)0;
            context.Writer.WriteUInt8(encoded);
        }
    }

    /// <summary>Reads the element count and then one byte per element; any
    /// non-zero byte is read as true.</summary>
    public override bool[] ReadData(ref ReadContext context)
    {
        uint rawCount = context.Reader.ReadVarUInt32();
        int count = checked((int)rawCount);
        bool[] result = new bool[count];
        int index = 0;
        while (index < count)
        {
            byte encoded = context.Reader.ReadUInt8();
            result[index] = encoded != 0;
            index++;
        }

        return result;
    }
}
/// <summary>
/// Serializer for <c>short[]</c>: a varuint32 payload size in bytes followed by
/// the elements, each written with <c>WriteInt16</c>.
/// </summary>
internal sealed class Int16ArraySerializer : Serializer<short[]>
{
    private const int ElementSize = sizeof(short);

    public override TypeId StaticTypeId => TypeId.Int16Array;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override short[] DefaultValue => null!;

    public override bool IsNone(in short[] value) => value is null;

    /// <summary>Writes the payload size in bytes and then every element.
    /// A null array is written as an empty one.</summary>
    /// <exception cref="OverflowException">The payload would exceed
    /// <see cref="int.MaxValue"/> bytes, which <see cref="ReadData"/> cannot read back.</exception>
    public override void WriteData(ref WriteContext context, in short[] value, bool hasGenerics)
    {
        _ = hasGenerics;
        short[] safe = value ?? [];

        // checked: fail here rather than emit a byte count that the reader's
        // checked int conversion would reject.
        int payloadSize = checked(safe.Length * ElementSize);
        context.Writer.WriteVarUInt32((uint)payloadSize);
        for (int i = 0; i < safe.Length; i++)
        {
            context.Writer.WriteInt16(safe[i]);
        }
    }

    /// <summary>Reads the payload size in bytes and then that many bytes' worth of elements.</summary>
    /// <exception cref="InvalidDataException">The payload size is not a multiple of the element size.</exception>
    /// <exception cref="OverflowException">The payload size does not fit in an int.</exception>
    public override short[] ReadData(ref ReadContext context)
    {
        int payloadSize = checked((int)context.Reader.ReadVarUInt32());
        if (payloadSize % ElementSize != 0)
        {
            throw new InvalidDataException("int16 array payload size mismatch");
        }

        short[] values = new short[payloadSize / ElementSize];
        for (int i = 0; i < values.Length; i++)
        {
            values[i] = context.Reader.ReadInt16();
        }

        return values;
    }
}
/// <summary>
/// Serializer for <c>long[]</c>: a varuint32 payload size in bytes followed by
/// the elements, each written with <c>WriteInt64</c>.
/// </summary>
internal sealed class Int64ArraySerializer : Serializer<long[]>
{
    private const int ElementSize = sizeof(long);

    public override TypeId StaticTypeId => TypeId.Int64Array;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override long[] DefaultValue => null!;

    public override bool IsNone(in long[] value) => value is null;

    /// <summary>Writes the payload size in bytes and then every element.
    /// A null array is written as an empty one.</summary>
    /// <exception cref="OverflowException">The payload would exceed
    /// <see cref="int.MaxValue"/> bytes, which <see cref="ReadData"/> cannot read back.</exception>
    public override void WriteData(ref WriteContext context, in long[] value, bool hasGenerics)
    {
        _ = hasGenerics;
        long[] safe = value ?? [];

        // checked: Length * 8 can wrap in 32-bit arithmetic for very large arrays,
        // which previously wrote a truncated byte count ahead of the full element data.
        int payloadSize = checked(safe.Length * ElementSize);
        context.Writer.WriteVarUInt32((uint)payloadSize);
        for (int i = 0; i < safe.Length; i++)
        {
            context.Writer.WriteInt64(safe[i]);
        }
    }

    /// <summary>Reads the payload size in bytes and then that many bytes' worth of elements.</summary>
    /// <exception cref="InvalidDataException">The payload size is not a multiple of the element size.</exception>
    /// <exception cref="OverflowException">The payload size does not fit in an int.</exception>
    public override long[] ReadData(ref ReadContext context)
    {
        int payloadSize = checked((int)context.Reader.ReadVarUInt32());
        if (payloadSize % ElementSize != 0)
        {
            throw new InvalidDataException("int64 array payload size mismatch");
        }

        long[] values = new long[payloadSize / ElementSize];
        for (int i = 0; i < values.Length; i++)
        {
            values[i] = context.Reader.ReadInt64();
        }

        return values;
    }
}
[]; + context.Writer.WriteVarUInt32((uint)(safe.Length * 2)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteUInt16(safe[i]); + } + } + + public override ushort[] ReadData(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 1) != 0) + { + throw new InvalidDataException("uint16 array payload size mismatch"); + } + + ushort[] values = new ushort[payloadSize / 2]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt16(); + } + + return values; + } +} + +internal sealed class UInt32ArraySerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.UInt32Array; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override uint[] DefaultValue => null!; + + public override bool IsNone(in uint[] value) => value is null; + + public override void WriteData(ref WriteContext context, in uint[] value, bool hasGenerics) + { + _ = hasGenerics; + uint[] safe = value ?? 
[]; + context.Writer.WriteVarUInt32((uint)(safe.Length * 4)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteUInt32(safe[i]); + } + } + + public override uint[] ReadData(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 3) != 0) + { + throw new InvalidDataException("uint32 array payload size mismatch"); + } + + uint[] values = new uint[payloadSize / 4]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt32(); + } + + return values; + } +} + +internal sealed class UInt64ArraySerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.UInt64Array; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override ulong[] DefaultValue => null!; + + public override bool IsNone(in ulong[] value) => value is null; + + public override void WriteData(ref WriteContext context, in ulong[] value, bool hasGenerics) + { + _ = hasGenerics; + ulong[] safe = value ?? 
[]; + context.Writer.WriteVarUInt32((uint)(safe.Length * 8)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteUInt64(safe[i]); + } + } + + public override ulong[] ReadData(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 7) != 0) + { + throw new InvalidDataException("uint64 array payload size mismatch"); + } + + ulong[] values = new ulong[payloadSize / 8]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt64(); + } + + return values; + } +} + +internal sealed class Float32ArraySerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Float32Array; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override float[] DefaultValue => null!; + + public override bool IsNone(in float[] value) => value is null; + + public override void WriteData(ref WriteContext context, in float[] value, bool hasGenerics) + { + _ = hasGenerics; + float[] safe = value ?? 
[]; + context.Writer.WriteVarUInt32((uint)(safe.Length * 4)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteFloat32(safe[i]); + } + } + + public override float[] ReadData(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 3) != 0) + { + throw new InvalidDataException("float32 array payload size mismatch"); + } + + float[] values = new float[payloadSize / 4]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadFloat32(); + } + + return values; + } +} + +internal sealed class Float64ArraySerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Float64Array; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override double[] DefaultValue => null!; + + public override bool IsNone(in double[] value) => value is null; + + public override void WriteData(ref WriteContext context, in double[] value, bool hasGenerics) + { + _ = hasGenerics; + double[] safe = value ?? []; + context.Writer.WriteVarUInt32((uint)(safe.Length * 8)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteFloat64(safe[i]); + } + } + + public override double[] ReadData(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 7) != 0) + { + throw new InvalidDataException("float64 array payload size mismatch"); + } + + double[] values = new double[payloadSize / 8]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadFloat64(); + } + + return values; + } +} diff --git a/csharp/src/Fory/PrimitiveCollectionSerializers.cs b/csharp/src/Fory/PrimitiveCollectionSerializers.cs new file mode 100644 index 0000000000..28fefcad86 --- /dev/null +++ b/csharp/src/Fory/PrimitiveCollectionSerializers.cs @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +/* Shared header writer used by the primitive list and set serializers in this file. */ internal static class PrimitiveCollectionHeader +{ + /* Writes the element count as a var-uint32 and stops there for an empty collection. Otherwise writes one flag byte (SameType always; HasNull when hasNull; DeclaredElementType when hasGenerics is set and the element type does not need per-field type info) and, only when the type is not declared, the element type id as a single byte. 
*/ public static void WriteListHeader(ref WriteContext context, int count, bool hasGenerics, TypeId elementTypeId, bool hasNull) + { + context.Writer.WriteVarUInt32((uint)count); + if (count == 0) + { + return; + } + + bool declared = hasGenerics && !elementTypeId.NeedsTypeInfoForField(); + byte header = CollectionBits.SameType; + if (hasNull) + { + header |= CollectionBits.HasNull; + } + + if (declared) + { + header |= CollectionBits.DeclaredElementType; + } + + context.Writer.WriteUInt8(header); + if (!declared) + { + context.Writer.WriteUInt8((byte)elementTypeId); + } + } +} + +/* Serializer for a list written with element type TypeId.Bool: shared list header, then one byte per element (1 for true, 0 for false); a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListBoolSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Bool, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteUInt8(list[i] ? (byte)1 : (byte)0); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.Int8: shared list header, then one WriteInt8 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListInt8Serializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Int8, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteInt8(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.Int16: shared list header, then one WriteInt16 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListInt16Serializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Int16, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteInt16(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.VarInt32: shared list header, then one WriteVarInt32 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListIntSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.VarInt32, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteVarInt32(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.VarInt64: shared list header, then one WriteVarInt64 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListLongSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.VarInt64, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteVarInt64(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.UInt8: shared list header, then one WriteUInt8 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListUInt8Serializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.UInt8, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteUInt8(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.UInt16: shared list header, then one WriteUInt16 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListUInt16Serializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.UInt16, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteUInt16(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.VarUInt32: shared list header, then one WriteVarUInt32 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListUIntSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.VarUInt32, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteVarUInt32(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.VarUInt64: shared list header, then one WriteVarUInt64 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListULongSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.VarUInt64, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteVarUInt64(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.Float32: shared list header, then one WriteFloat32 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListFloatSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Float32, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteFloat32(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.Float64: shared list header, then one WriteFloat64 per element; a null list is written as empty. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListDoubleSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Float64, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteFloat64(list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a list written with element type TypeId.String. Unlike the numeric list serializers it may set the HasNull header flag. ReadData delegates to the ListSerializer Fallback. */ internal sealed class ListStringSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.List; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override List DefaultValue => null!; + + public override bool IsNone(in List value) => value is null; + + /* First scans for a null element so the header can carry hasNull. With nulls, each element is preceded by a RefFlag byte (Null alone, or NotNullValue followed by the string); without nulls the strings are written back to back. A null list is written as empty. */ public override void WriteData(ref WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + bool hasNull = false; + for (int i = 0; i < list.Count; i++) + { + if (list[i] is null) + { + hasNull = true; + break; + } + } + + PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.String, hasNull); + if (hasNull) + { + for (int i = 0; i < list.Count; i++) + { + string? 
item = list[i]; + if (item is null) + { + context.Writer.WriteInt8((sbyte)RefFlag.Null); + continue; + } + + context.Writer.WriteInt8((sbyte)RefFlag.NotNullValue); + StringSerializer.WriteString(ref context, item); + } + + return; + } + + for (int i = 0; i < list.Count; i++) + { + StringSerializer.WriteString(ref context, list[i]); + } + } + + public override List ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.Int8: same header as lists (WriteListHeader), then one WriteInt8 per item in enumeration order; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetInt8Serializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.Int8, false); + foreach (sbyte item in set) + { + context.Writer.WriteInt8(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.Int16: list-style header, then one WriteInt16 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. 
*/ internal sealed class SetInt16Serializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.Int16, false); + foreach (short item in set) + { + context.Writer.WriteInt16(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.VarInt32: list-style header, then one WriteVarInt32 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetIntSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.VarInt32, false); + foreach (int item in set) + { + context.Writer.WriteVarInt32(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.VarInt64: list-style header, then one WriteVarInt64 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetLongSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.VarInt64, false); + foreach (long item in set) + { + context.Writer.WriteVarInt64(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.UInt8: list-style header, then one WriteUInt8 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetUInt8Serializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.UInt8, false); + foreach (byte item in set) + { + context.Writer.WriteUInt8(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.UInt16: list-style header, then one WriteUInt16 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetUInt16Serializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.UInt16, false); + foreach (ushort item in set) + { + context.Writer.WriteUInt16(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.VarUInt32: list-style header, then one WriteVarUInt32 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetUIntSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.VarUInt32, false); + foreach (uint item in set) + { + context.Writer.WriteVarUInt32(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.VarUInt64: list-style header, then one WriteVarUInt64 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetULongSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.VarUInt64, false); + foreach (ulong item in set) + { + context.Writer.WriteVarUInt64(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.Float32: list-style header, then one WriteFloat32 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetFloatSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.Float32, false); + foreach (float item in set) + { + context.Writer.WriteFloat32(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} + +/* Serializer for a HashSet written with element type TypeId.Float64: list-style header, then one WriteFloat64 per item; a null set is written as empty. ReadData delegates to the SetSerializer Fallback. */ internal sealed class SetDoubleSerializer : Serializer> +{ + private static readonly SetSerializer Fallback = new(); + + public override TypeId StaticTypeId => TypeId.Set; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override HashSet DefaultValue => null!; + + public override bool IsNone(in HashSet value) => value is null; + + public override void WriteData(ref WriteContext context, in HashSet value, bool hasGenerics) + { + HashSet set = value ?? 
[]; + PrimitiveCollectionHeader.WriteListHeader(ref context, set.Count, hasGenerics, TypeId.Float64, false); + foreach (double item in set) + { + context.Writer.WriteFloat64(item); + } + } + + public override HashSet ReadData(ref ReadContext context) + { + return Fallback.ReadData(ref context); + } +} diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs new file mode 100644 index 0000000000..c32c729a52 --- /dev/null +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -0,0 +1,579 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +namespace Apache.Fory; + +/* Helper for the per-chunk type ids written by the primitive dictionary writer. */ internal static class PrimitiveDictionaryHeader +{ + /* Writes the key type id byte when the key type is not declared, then the value type id byte when the value type is not declared; writes nothing when both are declared. */ public static void WriteMapChunkTypeInfo( + ref WriteContext context, + bool keyDeclared, + bool valueDeclared, + TypeId keyTypeId, + TypeId valueTypeId) + { + if (!keyDeclared) + { + context.Writer.WriteUInt8((byte)keyTypeId); + } + + if (!valueDeclared) + { + context.Writer.WriteUInt8((byte)valueTypeId); + } + } +} + +/* Static-abstract contract for a key or value codec used by the primitive dictionary writer/reader: wire type id, nullability, default value, null test, and raw Write/Read of one value. */ internal interface IPrimitiveDictionaryCodec +{ + static abstract TypeId WireTypeId { get; } + + static abstract bool IsNullable { get; } + + static abstract T DefaultValue { get; } + + static abstract bool IsNone(T value); + + static abstract void Write(ref WriteContext context, T value); + + static abstract T Read(ref ReadContext context); +} + +/* Codec for string entries (TypeId.String, nullable). Write substitutes string.Empty for a null value. */ internal readonly struct StringPrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.String; + + public static bool IsNullable => true; + + public static string DefaultValue => null!; + + public static bool IsNone(string value) => value is null; + + public static void Write(ref WriteContext context, string value) + { + StringSerializer.WriteString(ref context, value ?? string.Empty); + } + + public static string Read(ref ReadContext context) + { + return StringSerializer.ReadString(ref context); + } +} + +/* Codec for bool entries (TypeId.Bool, never null): one byte, 1 for true and 0 for false; Read treats any non-zero byte as true. 
*/ internal readonly struct BoolPrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Bool; + + public static bool IsNullable => false; + + public static bool DefaultValue => false; + + public static bool IsNone(bool value) => false; + + public static void Write(ref WriteContext context, bool value) + { + context.Writer.WriteUInt8(value ? 
(byte)1 : (byte)0); + } + + public static bool Read(ref ReadContext context) + { + return context.Reader.ReadUInt8() != 0; + } +} + +/* Codec for sbyte entries (TypeId.Int8, never null) using WriteInt8/ReadInt8. */ internal readonly struct Int8PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Int8; + + public static bool IsNullable => false; + + public static sbyte DefaultValue => 0; + + public static bool IsNone(sbyte value) => false; + + public static void Write(ref WriteContext context, sbyte value) + { + context.Writer.WriteInt8(value); + } + + public static sbyte Read(ref ReadContext context) + { + return context.Reader.ReadInt8(); + } +} + +/* Codec for short entries (TypeId.Int16, never null) using WriteInt16/ReadInt16. */ internal readonly struct Int16PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Int16; + + public static bool IsNullable => false; + + public static short DefaultValue => 0; + + public static bool IsNone(short value) => false; + + public static void Write(ref WriteContext context, short value) + { + context.Writer.WriteInt16(value); + } + + public static short Read(ref ReadContext context) + { + return context.Reader.ReadInt16(); + } +} + +/* Codec for int entries (TypeId.VarInt32, never null) using WriteVarInt32/ReadVarInt32. */ internal readonly struct Int32PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.VarInt32; + + public static bool IsNullable => false; + + public static int DefaultValue => 0; + + public static bool IsNone(int value) => false; + + public static void Write(ref WriteContext context, int value) + { + context.Writer.WriteVarInt32(value); + } + + public static int Read(ref ReadContext context) + { + return context.Reader.ReadVarInt32(); + } +} + +/* Codec for long entries (TypeId.VarInt64, never null) using WriteVarInt64/ReadVarInt64. 
*/ internal readonly struct Int64PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.VarInt64; + + public static bool IsNullable => false; + + public static long DefaultValue => 0; + + public static bool IsNone(long value) => false; + + public static void Write(ref WriteContext context, long value) + { + context.Writer.WriteVarInt64(value); 
+ } + + public static long Read(ref ReadContext context) + { + return context.Reader.ReadVarInt64(); + } +} + +/* Codec for ushort entries (TypeId.UInt16, never null) using WriteUInt16/ReadUInt16. */ internal readonly struct UInt16PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.UInt16; + + public static bool IsNullable => false; + + public static ushort DefaultValue => 0; + + public static bool IsNone(ushort value) => false; + + public static void Write(ref WriteContext context, ushort value) + { + context.Writer.WriteUInt16(value); + } + + public static ushort Read(ref ReadContext context) + { + return context.Reader.ReadUInt16(); + } +} + +/* Codec for uint entries (TypeId.VarUInt32, never null) using WriteVarUInt32/ReadVarUInt32. */ internal readonly struct UInt32PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.VarUInt32; + + public static bool IsNullable => false; + + public static uint DefaultValue => 0; + + public static bool IsNone(uint value) => false; + + public static void Write(ref WriteContext context, uint value) + { + context.Writer.WriteVarUInt32(value); + } + + public static uint Read(ref ReadContext context) + { + return context.Reader.ReadVarUInt32(); + } +} + +/* Codec for ulong entries (TypeId.VarUInt64, never null) using WriteVarUInt64/ReadVarUInt64. */ internal readonly struct UInt64PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.VarUInt64; + + public static bool IsNullable => false; + + public static ulong DefaultValue => 0; + + public static bool IsNone(ulong value) => false; + + public static void Write(ref WriteContext context, ulong value) + { + context.Writer.WriteVarUInt64(value); + } + + public static ulong Read(ref ReadContext context) + { + return context.Reader.ReadVarUInt64(); + } +} + +/* Codec for float entries (TypeId.Float32, never null) using WriteFloat32/ReadFloat32. 
*/ internal readonly struct Float32PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Float32; + + public static bool IsNullable => false; + + public static float DefaultValue => 0; + + public static bool IsNone(float value) => false; + + public static void Write(ref WriteContext context, float value) + { + 
context.Writer.WriteFloat32(value); + } + + public static float Read(ref ReadContext context) + { + return context.Reader.ReadFloat32(); + } +} + +/* Codec for double entries (TypeId.Float64, never null) using WriteFloat64/ReadFloat64. */ internal readonly struct Float64PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Float64; + + public static bool IsNullable => false; + + public static double DefaultValue => 0; + + public static bool IsNone(double value) => false; + + public static void Write(ref WriteContext context, double value) + { + context.Writer.WriteFloat64(value); + } + + public static double Read(ref ReadContext context) + { + return context.Reader.ReadFloat64(); + } +} + +/* Codec-driven dictionary writer shared by the primitive dictionary serializers. */ internal static class PrimitiveDictionaryCodecWriter +{ + /* Writes a dictionary as a var-uint32 entry count followed by blocks. An entry whose key or value is null (per its codec) is written alone with KeyNull/ValueNull flags. Runs of non-null entries are written as a chunk: header byte, a one-byte chunk size written as 0 and patched with SetByte once the run ends, optional key/value type ids, then the key/value pairs. A chunk holds at most byte.MaxValue entries. A type counts as declared when hasGenerics is set and its type id does not need per-field type info. */ public static void WriteMap( + ref WriteContext context, + Dictionary map, + bool hasGenerics) + where TKey : notnull + where TKeyCodec : struct, IPrimitiveDictionaryCodec + where TValueCodec : struct, IPrimitiveDictionaryCodec + { + /* copy the entries into an array so they can be walked by index */ KeyValuePair[] pairs = [.. 
map]; + context.Writer.WriteVarUInt32((uint)pairs.Length); + if (pairs.Length == 0) + { + return; + } + + TypeId keyTypeId = TKeyCodec.WireTypeId; + TypeId valueTypeId = TValueCodec.WireTypeId; + bool keyDeclared = hasGenerics && !keyTypeId.NeedsTypeInfoForField(); + bool valueDeclared = hasGenerics && !valueTypeId.NeedsTypeInfoForField(); + bool keyNullable = TKeyCodec.IsNullable; + bool valueNullable = TValueCodec.IsNullable; + + int index = 0; + while (index < pairs.Length) + { + KeyValuePair pair = pairs[index]; + bool keyNull = keyNullable && TKeyCodec.IsNone(pair.Key); + bool valueNull = valueNullable && TValueCodec.IsNone(pair.Value); + /* single-entry block for a null key and/or null value */ if (keyNull || valueNull) + { + byte header = 0; + if (keyNull) + { + header |= DictionaryBits.KeyNull; + } + else if (keyDeclared) + { + header |= DictionaryBits.DeclaredKeyType; + } + + if (valueNull) + { + header |= DictionaryBits.ValueNull; + } + else if (valueDeclared) + { + header |= DictionaryBits.DeclaredValueType; + } + + context.Writer.WriteUInt8(header); + if 
(!keyNull) + { + if (!keyDeclared) + { + context.Writer.WriteUInt8((byte)keyTypeId); + } + + TKeyCodec.Write(ref context, pair.Key); + } + + if (!valueNull) + { + if (!valueDeclared) + { + context.Writer.WriteUInt8((byte)valueTypeId); + } + + TValueCodec.Write(ref context, pair.Value); + } + + index += 1; + continue; + } + + byte blockHeader = 0; + if (keyDeclared) + { + blockHeader |= DictionaryBits.DeclaredKeyType; + } + + if (valueDeclared) + { + blockHeader |= DictionaryBits.DeclaredValueType; + } + + context.Writer.WriteUInt8(blockHeader); + /* remember where the size byte lives; 0 is a placeholder patched after the run */ int chunkSizeOffset = context.Writer.Count; + context.Writer.WriteUInt8(0); + PrimitiveDictionaryHeader.WriteMapChunkTypeInfo(ref context, keyDeclared, valueDeclared, keyTypeId, valueTypeId); + + byte chunkSize = 0; + /* the run ends at the first null entry, at byte.MaxValue entries, or at the end of the array */ while (index < pairs.Length && chunkSize < byte.MaxValue) + { + pair = pairs[index]; + keyNull = keyNullable && TKeyCodec.IsNone(pair.Key); + valueNull = valueNullable && TValueCodec.IsNone(pair.Value); + if (keyNull || valueNull) + { + break; + } + + TKeyCodec.Write(ref context, pair.Key); + TValueCodec.Write(ref context, pair.Value); + index += 1; + chunkSize += 1; + } + + context.Writer.SetByte(chunkSizeOffset, chunkSize); + } + } +} + +internal static class PrimitiveDictionaryCodecReader +{ + public static Dictionary ReadMap(ref ReadContext context) + where TKey : notnull + where TKeyCodec : struct, IPrimitiveDictionaryCodec + where TValueCodec : struct, IPrimitiveDictionaryCodec + { + int totalLength = checked((int)context.Reader.ReadVarUInt32()); + if (totalLength == 0) + { + return []; + } + + TypeId keyTypeId = TKeyCodec.WireTypeId; + TypeId valueTypeId = TValueCodec.WireTypeId; + bool keyNullable = TKeyCodec.IsNullable; + Dictionary map = new(totalLength); + + int readCount = 0; + while (readCount < totalLength) + { + byte header = context.Reader.ReadUInt8(); + bool trackKeyRef = (header & DictionaryBits.TrackingKeyRef) != 0; + bool keyNull = (header & DictionaryBits.KeyNull) != 0; + bool keyDeclared 
= (header & DictionaryBits.DeclaredKeyType) != 0; + bool trackValueRef = (header & DictionaryBits.TrackingValueRef) != 0; + bool valueNull = (header & DictionaryBits.ValueNull) != 0; + bool valueDeclared = (header & DictionaryBits.DeclaredValueType) != 0; + if (trackKeyRef || trackValueRef) + { + throw new InvalidDataException("primitive dictionary codecs do not support reference-tracking flags"); + } + + if (keyNull && !keyNullable) + { + throw new InvalidDataException("non-nullable primitive dictionary key cannot be null"); + } + + if (keyNull && valueNull) + { + readCount += 1; + continue; + } + + if (keyNull) + { + if (!valueDeclared) + { + ReadAndValidateTypeInfo(ref context, valueTypeId); + } + + _ = TValueCodec.Read(ref context); + readCount += 1; + continue; + } + + if (valueNull) + { + if (!keyDeclared) + { + ReadAndValidateTypeInfo(ref context, keyTypeId); + } + + TKey key = TKeyCodec.Read(ref context); + map[key] = TValueCodec.DefaultValue; + readCount += 1; + continue; + } + + int chunkSize = context.Reader.ReadUInt8(); + if (chunkSize == 0) + { + throw new InvalidDataException("invalid primitive map chunk size 0"); + } + + if (!keyDeclared) + { + ReadAndValidateTypeInfo(ref context, keyTypeId); + } + + if (!valueDeclared) + { + ReadAndValidateTypeInfo(ref context, valueTypeId); + } + + for (int i = 0; i < chunkSize; i++) + { + TKey key = TKeyCodec.Read(ref context); + TValue value = TValueCodec.Read(ref context); + map[key] = value; + } + + readCount += chunkSize; + } + + return map; + } + + private static void ReadAndValidateTypeInfo(ref ReadContext context, TypeId expectedTypeId) + { + uint actualTypeId = context.Reader.ReadVarUInt32(); + if (actualTypeId != (uint)expectedTypeId) + { + throw new TypeMismatchException((uint)expectedTypeId, actualTypeId); + } + } +} + +internal class PrimitiveDictionarySerializer : Serializer> + where TKey : notnull + where TKeyCodec : struct, IPrimitiveDictionaryCodec + where TValueCodec : struct, 
IPrimitiveDictionaryCodec +{ + public override TypeId StaticTypeId => TypeId.Map; + + public override bool IsNullableType => true; + + public override bool IsReferenceTrackableType => true; + + public override Dictionary DefaultValue => null!; + + public override bool IsNone(in Dictionary value) => value is null; + + public override void WriteData(ref WriteContext context, in Dictionary value, bool hasGenerics) + { + Dictionary map = value ?? []; + PrimitiveDictionaryCodecWriter.WriteMap(ref context, map, hasGenerics); + } + + public override Dictionary ReadData(ref ReadContext context) + { + return PrimitiveDictionaryCodecReader.ReadMap(ref context); + } +} + +internal class PrimitiveStringKeyDictionarySerializer + : PrimitiveDictionarySerializer + where TValueCodec : struct, IPrimitiveDictionaryCodec +{ +} + +internal class PrimitiveSameTypeDictionarySerializer + : PrimitiveDictionarySerializer + where TValue : notnull + where TValueCodec : struct, IPrimitiveDictionaryCodec +{ +} + +// String-key primitive dictionary serializers. 
+internal sealed class DictionaryStringBoolSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringDoubleSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringFloatSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringInt8Serializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringInt16Serializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringIntSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringLongSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringStringSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringUInt16Serializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringUIntSerializer : PrimitiveStringKeyDictionarySerializer { } + +internal sealed class DictionaryStringULongSerializer : PrimitiveStringKeyDictionarySerializer { } + +// Same-type primitive dictionary serializers. +internal sealed class DictionaryIntIntSerializer : PrimitiveSameTypeDictionarySerializer { } + +internal sealed class DictionaryLongLongSerializer : PrimitiveSameTypeDictionarySerializer { } + +internal sealed class DictionaryUIntUIntSerializer : PrimitiveSameTypeDictionarySerializer { } + +internal sealed class DictionaryULongULongSerializer : PrimitiveSameTypeDictionarySerializer { } diff --git a/csharp/src/Fory/PrimitiveSerializers.cs b/csharp/src/Fory/PrimitiveSerializers.cs new file mode 100644 index 0000000000..9ca2b1c723 --- /dev/null +++ b/csharp/src/Fory/PrimitiveSerializers.cs @@ -0,0 +1,248 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +internal enum ForyStringEncoding : ulong +{ + Latin1 = 0, + Utf16 = 1, + Utf8 = 2, +} + +public sealed class BoolSerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Bool; + + public override bool DefaultValue => false; + + public override void WriteData(ref WriteContext context, in bool value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteUInt8(value ? 
(byte)1 : (byte)0); + } + + public override bool ReadData(ref ReadContext context) + { + return context.Reader.ReadUInt8() != 0; + } +} + +public sealed class Int8Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Int8; + + public override sbyte DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in sbyte value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteInt8(value); + } + + public override sbyte ReadData(ref ReadContext context) + { + return context.Reader.ReadInt8(); + } +} + +public sealed class Int16Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Int16; + + public override short DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in short value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteInt16(value); + } + + public override short ReadData(ref ReadContext context) + { + return context.Reader.ReadInt16(); + } +} + +public sealed class Int32Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.VarInt32; + + public override int DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in int value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteVarInt32(value); + } + + public override int ReadData(ref ReadContext context) + { + return context.Reader.ReadVarInt32(); + } +} + +public sealed class Int64Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.VarInt64; + + public override long DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in long value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteVarInt64(value); + } + + public override long ReadData(ref ReadContext context) + { + return context.Reader.ReadVarInt64(); + } +} + +public sealed class UInt8Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.UInt8; + + public override byte DefaultValue => 0; + + public override 
void WriteData(ref WriteContext context, in byte value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteUInt8(value); + } + + public override byte ReadData(ref ReadContext context) + { + return context.Reader.ReadUInt8(); + } +} + +public sealed class UInt16Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.UInt16; + + public override ushort DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in ushort value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteUInt16(value); + } + + public override ushort ReadData(ref ReadContext context) + { + return context.Reader.ReadUInt16(); + } +} + +public sealed class UInt32Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.VarUInt32; + + public override uint DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in uint value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteVarUInt32(value); + } + + public override uint ReadData(ref ReadContext context) + { + return context.Reader.ReadVarUInt32(); + } +} + +public sealed class UInt64Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.VarUInt64; + + public override ulong DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in ulong value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteVarUInt64(value); + } + + public override ulong ReadData(ref ReadContext context) + { + return context.Reader.ReadVarUInt64(); + } +} + +public sealed class Float32Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Float32; + + public override float DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in float value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteFloat32(value); + } + + public override float ReadData(ref ReadContext context) + { + return context.Reader.ReadFloat32(); + } +} + +public sealed class 
Float64Serializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Float64; + + public override double DefaultValue => 0; + + public override void WriteData(ref WriteContext context, in double value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteFloat64(value); + } + + public override double ReadData(ref ReadContext context) + { + return context.Reader.ReadFloat64(); + } +} + +public sealed class BinarySerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Binary; + + public override bool IsNullableType => true; + + public override byte[] DefaultValue => null!; + + public override bool IsNone(in byte[] value) => value is null; + + public override void WriteData(ref WriteContext context, in byte[] value, bool hasGenerics) + { + _ = hasGenerics; + byte[] safe = value ?? []; + context.Writer.WriteVarUInt32((uint)safe.Length); + context.Writer.WriteBytes(safe); + } + + public override byte[] ReadData(ref ReadContext context) + { + uint length = context.Reader.ReadVarUInt32(); + return context.Reader.ReadBytes(checked((int)length)); + } +} diff --git a/csharp/src/Fory/RefResolver.cs b/csharp/src/Fory/RefResolver.cs new file mode 100644 index 0000000000..98c12baa74 --- /dev/null +++ b/csharp/src/Fory/RefResolver.cs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +public sealed class RefWriter +{ + private readonly Dictionary _refs = new(ReferenceEqualityComparer.Instance); + private uint _nextRefId; + + public bool TryWriteReference(ByteWriter writer, object obj) + { + if (_refs.TryGetValue(obj, out uint refId)) + { + writer.WriteInt8((sbyte)RefFlag.Ref); + writer.WriteVarUInt32(refId); + return true; + } + + _refs[obj] = _nextRefId; + _nextRefId += 1; + writer.WriteInt8((sbyte)RefFlag.RefValue); + return false; + } + + public uint ReserveRefId() + { + uint id = _nextRefId; + _nextRefId += 1; + return id; + } + + public void Reset() + { + _refs.Clear(); + _nextRefId = 0; + } +} + +public sealed class RefReader +{ + private readonly List _refs = []; + + public uint ReserveRefId() + { + uint id = (uint)_refs.Count; + _refs.Add(null); + return id; + } + + public void StoreRef(object? value, uint refId) + { + int index = checked((int)refId); + _refs[index] = value; + } + + public T ReadRef(uint refId) + { + int index = checked((int)refId); + if (index < 0 || index >= _refs.Count) + { + throw new RefException($"ref_id out of range: {refId}"); + } + + if (_refs[index] is T typed) + { + return typed; + } + + throw new RefException($"ref_id {refId} has unexpected runtime type"); + } + + public object? 
ReadRefValue(uint refId) + { + int index = checked((int)refId); + if (index < 0 || index >= _refs.Count) + { + throw new RefException($"ref_id out of range: {refId}"); + } + + return _refs[index]; + } + + public void Reset() + { + _refs.Clear(); + } +} + diff --git a/csharp/src/Fory/SchemaHash.cs b/csharp/src/Fory/SchemaHash.cs new file mode 100644 index 0000000000..039e99024c --- /dev/null +++ b/csharp/src/Fory/SchemaHash.cs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Text; + +namespace Apache.Fory; + +public static class SchemaHash +{ + public static uint StructHash32(string fingerprint) + { + byte[] bytes = Encoding.UTF8.GetBytes(fingerprint); + (ulong h1, _) = MurmurHash3.X64_128(bytes, 47); + return unchecked((uint)h1); + } +} + diff --git a/csharp/src/Fory/Serializer.cs b/csharp/src/Fory/Serializer.cs new file mode 100644 index 0000000000..94b94a5d5b --- /dev/null +++ b/csharp/src/Fory/Serializer.cs @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +public abstract class Serializer +{ + public abstract Type Type { get; } + + public abstract TypeId StaticTypeId { get; } + + public abstract bool IsNullableType { get; } + + public abstract bool IsReferenceTrackableType { get; } + + public abstract object? DefaultObject { get; } + + public abstract bool IsNoneObject(object? value); + + public abstract void WriteDataObject(ref WriteContext context, object? value, bool hasGenerics); + + public abstract object? ReadDataObject(ref ReadContext context); + + public abstract void WriteObject(ref WriteContext context, object? value, RefMode refMode, bool writeTypeInfo, bool hasGenerics); + + public abstract object? ReadObject(ref ReadContext context, RefMode refMode, bool readTypeInfo); + + public abstract void WriteTypeInfo(ref WriteContext context); + + public abstract void ReadTypeInfo(ref ReadContext context); + + public abstract IReadOnlyList CompatibleTypeMetaFields(bool trackRef); + + public abstract Serializer RequireSerializer(); +} + +public abstract class Serializer : Serializer +{ + public override Type Type => typeof(T); + + public abstract override TypeId StaticTypeId { get; } + + public override bool IsNullableType => false; + + public override bool IsReferenceTrackableType => false; + + public virtual T DefaultValue => default!; + + public override object? 
DefaultObject => DefaultValue; + + public virtual bool IsNone(in T value) + { + _ = value; + return false; + } + + public abstract void WriteData(ref WriteContext context, in T value, bool hasGenerics); + + public abstract T ReadData(ref ReadContext context); + + public virtual void Write(ref WriteContext context, in T value, RefMode refMode, bool writeTypeInfo, bool hasGenerics) + { + if (refMode != RefMode.None) + { + bool wroteTrackingRefFlag = false; + if (refMode == RefMode.Tracking && + IsReferenceTrackableType && + value is object obj) + { + if (context.RefWriter.TryWriteReference(context.Writer, obj)) + { + return; + } + + wroteTrackingRefFlag = true; + } + + if (!wroteTrackingRefFlag) + { + if (IsNullableType && IsNone(value)) + { + context.Writer.WriteInt8((sbyte)RefFlag.Null); + return; + } + + context.Writer.WriteInt8((sbyte)RefFlag.NotNullValue); + } + } + + if (writeTypeInfo) + { + WriteTypeInfo(ref context); + } + + WriteData(ref context, value, hasGenerics); + } + + public virtual T Read(ref ReadContext context, RefMode refMode, bool readTypeInfo) + { + if (refMode != RefMode.None) + { + sbyte rawFlag = context.Reader.ReadInt8(); + RefFlag flag = (RefFlag)rawFlag; + switch (flag) + { + case RefFlag.Null: + return DefaultValue; + case RefFlag.Ref: + { + uint refId = context.Reader.ReadVarUInt32(); + return context.RefReader.ReadRef(refId); + } + case RefFlag.RefValue: + { + uint reservedRefId = context.RefReader.ReserveRefId(); + context.PushPendingReference(reservedRefId); + if (readTypeInfo) + { + ReadTypeInfo(ref context); + } + + T value = ReadData(ref context); + context.FinishPendingReferenceIfNeeded(value); + context.PopPendingReference(); + return value; + } + case RefFlag.NotNullValue: + break; + default: + throw new RefException($"invalid ref flag {rawFlag}"); + } + } + + if (readTypeInfo) + { + ReadTypeInfo(ref context); + } + + return ReadData(ref context); + } + + public override void WriteTypeInfo(ref WriteContext context) + { + 
context.TypeResolver.WriteTypeInfo(Type, this, ref context); + } + + public override void ReadTypeInfo(ref ReadContext context) + { + context.TypeResolver.ReadTypeInfo(Type, this, ref context); + } + + public override IReadOnlyList CompatibleTypeMetaFields(bool trackRef) + { + _ = trackRef; + return []; + } + + public override bool IsNoneObject(object? value) + { + if (value is null) + { + return IsNullableType; + } + + return value is T typed && IsNone(typed); + } + + public override void WriteDataObject(ref WriteContext context, object? value, bool hasGenerics) + { + WriteData(ref context, CoerceValue(value), hasGenerics); + } + + public override object? ReadDataObject(ref ReadContext context) + { + return ReadData(ref context); + } + + public override void WriteObject(ref WriteContext context, object? value, RefMode refMode, bool writeTypeInfo, bool hasGenerics) + { + Write(ref context, CoerceValue(value), refMode, writeTypeInfo, hasGenerics); + } + + public override object? ReadObject(ref ReadContext context, RefMode refMode, bool readTypeInfo) + { + return Read(ref context, refMode, readTypeInfo); + } + + public override Serializer RequireSerializer() + { + if (typeof(TCast) == typeof(T)) + { + return (Serializer)(object)this; + } + + throw new InvalidDataException($"serializer type mismatch for {typeof(TCast)}"); + } + + protected virtual T CoerceValue(object? value) + { + if (value is T typed) + { + return typed; + } + + if (value is null && IsNullableType) + { + return DefaultValue; + } + + throw new InvalidDataException( + $"serializer {GetType().Name} expected value of type {typeof(T)}, got {value?.GetType()}"); + } +} diff --git a/csharp/src/Fory/StringSerializer.cs b/csharp/src/Fory/StringSerializer.cs new file mode 100644 index 0000000000..289de7c0ce --- /dev/null +++ b/csharp/src/Fory/StringSerializer.cs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Text; + +namespace Apache.Fory; + +public sealed class StringSerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.String; + + public override bool IsNullableType => true; + + public override string DefaultValue => null!; + + public override bool IsNone(in string value) => value is null; + + public override void WriteData(ref WriteContext context, in string value, bool hasGenerics) + { + _ = hasGenerics; + WriteString(ref context, value ?? string.Empty); + } + + public override string ReadData(ref ReadContext context) + { + return ReadString(ref context); + } + + public static void WriteString(ref WriteContext context, string value) + { + string safe = value ?? 
string.Empty; + ForyStringEncoding encoding = SelectEncoding(safe); + switch (encoding) + { + case ForyStringEncoding.Latin1: + WriteLatin1(ref context, safe); + break; + case ForyStringEncoding.Utf8: + WriteUtf8(ref context, safe); + break; + case ForyStringEncoding.Utf16: + WriteUtf16(ref context, safe); + break; + default: + throw new EncodingException($"unsupported string encoding {encoding}"); + } + } + + public static string ReadString(ref ReadContext context) + { + ulong header = context.Reader.ReadVarUInt36Small(); + ulong encoding = header & 0x03; + int byteLength = checked((int)(header >> 2)); + byte[] bytes = context.Reader.ReadBytes(byteLength); + return encoding switch + { + (ulong)ForyStringEncoding.Utf8 => Encoding.UTF8.GetString(bytes), + (ulong)ForyStringEncoding.Latin1 => DecodeLatin1(bytes), + (ulong)ForyStringEncoding.Utf16 => DecodeUtf16(bytes), + _ => throw new EncodingException($"unsupported string encoding {encoding}"), + }; + } + + private static string DecodeLatin1(byte[] bytes) + { + return string.Create(bytes.Length, bytes, static (span, b) => + { + for (int i = 0; i < b.Length; i++) + { + span[i] = (char)b[i]; + } + }); + } + + private static string DecodeUtf16(byte[] bytes) + { + if ((bytes.Length & 1) != 0) + { + throw new EncodingException("utf16 byte length is not even"); + } + + return Encoding.Unicode.GetString(bytes); + } + + private static ForyStringEncoding SelectEncoding(string value) + { + int numChars = value.Length; + int sampleNum = Math.Min(64, numChars); + int asciiCount = 0; + int latin1Count = 0; + for (int i = 0; i < sampleNum; i++) + { + char c = value[i]; + if (c < 0x80) + { + asciiCount++; + latin1Count++; + } + else if (c <= 0xFF) + { + latin1Count++; + } + } + + if (latin1Count == numChars || (latin1Count == sampleNum && IsLatin(value, sampleNum))) + { + return ForyStringEncoding.Latin1; + } + + return asciiCount * 2 >= sampleNum ? 
ForyStringEncoding.Utf8 : ForyStringEncoding.Utf16; + } + + private static bool IsLatin(string value, int start) + { + for (int i = start; i < value.Length; i++) + { + if (value[i] > 0xFF) + { + return false; + } + } + + return true; + } + + private static void WriteLatin1(ref WriteContext context, string value) + { + byte[] latin1 = new byte[value.Length]; + for (int i = 0; i < value.Length; i++) + { + latin1[i] = unchecked((byte)value[i]); + } + + WriteEncodedBytes(ref context, latin1, ForyStringEncoding.Latin1); + } + + private static void WriteUtf8(ref WriteContext context, string value) + { + byte[] utf8 = Encoding.UTF8.GetBytes(value); + WriteEncodedBytes(ref context, utf8, ForyStringEncoding.Utf8); + } + + private static void WriteUtf16(ref WriteContext context, string value) + { + byte[] utf16 = Encoding.Unicode.GetBytes(value); + WriteEncodedBytes(ref context, utf16, ForyStringEncoding.Utf16); + } + + private static void WriteEncodedBytes(ref WriteContext context, byte[] bytes, ForyStringEncoding encoding) + { + ulong header = ((ulong)bytes.Length << 2) | (ulong)encoding; + context.Writer.WriteVarUInt36Small(header); + context.Writer.WriteBytes(bytes); + } +} diff --git a/csharp/src/Fory/TimeSerializers.cs b/csharp/src/Fory/TimeSerializers.cs new file mode 100644 index 0000000000..382d6bf2f0 --- /dev/null +++ b/csharp/src/Fory/TimeSerializers.cs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +

namespace Apache.Fory;

/// <summary>
/// Shared read/write helpers for the date, timestamp and duration payloads used by the
/// serializers in this file.
/// </summary>
internal static class TimeCodec
{
    /// <summary>Number of nanoseconds represented by one .NET tick.</summary>
    private const long NanosPerTick = 100;

    /// <summary>Number of nanoseconds in one second.</summary>
    private const long NanosPerSecond = 1_000_000_000L;

    /// <summary>Day zero of the date encoding (1970-01-01).</summary>
    private static readonly DateOnly Epoch = new(1970, 1, 1);

    /// <summary>Writes <paramref name="value"/> as a 32-bit count of days since 1970-01-01.</summary>
    public static void WriteDate(ref WriteContext context, in DateOnly value)
    {
        context.Writer.WriteInt32(value.DayNumber - Epoch.DayNumber);
    }

    /// <summary>Reads a 32-bit day offset from 1970-01-01 and converts it back to a <see cref="DateOnly"/>.</summary>
    public static DateOnly ReadDate(ref ReadContext context)
    {
        int days = context.Reader.ReadInt32();
        return DateOnly.FromDayNumber(Epoch.DayNumber + days);
    }

    /// <summary>
    /// Converts a <see cref="DateTime"/> to a <see cref="DateTimeOffset"/> according to its kind:
    /// UTC values get a zero offset, local values use the implicit conversion, and values of any
    /// other kind are re-labelled as UTC without shifting the clock reading.
    /// </summary>
    public static DateTimeOffset ToDateTimeOffset(in DateTime value)
    {
        return value.Kind switch
        {
            DateTimeKind.Utc => new DateTimeOffset(value, TimeSpan.Zero),
            DateTimeKind.Local => value,
            _ => new DateTimeOffset(DateTime.SpecifyKind(value, DateTimeKind.Utc)),
        };
    }

    /// <summary>Writes a timestamp as a 64-bit Unix second count followed by a 32-bit unsigned nanosecond part.</summary>
    public static void WriteTimestamp(ref WriteContext context, in DateTimeOffset value)
    {
        (long seconds, uint nanos) = ToTimestampParts(value);
        context.Writer.WriteInt64(seconds);
        context.Writer.WriteUInt32(nanos);
    }

    /// <summary>
    /// Reads the seconds/nanoseconds pair written by <see cref="WriteTimestamp"/>.
    /// The nanosecond part is truncated to whole ticks (100 ns).
    /// </summary>
    public static DateTimeOffset ReadTimestamp(ref ReadContext context)
    {
        long seconds = context.Reader.ReadInt64();
        uint nanos = context.Reader.ReadUInt32();
        return DateTimeOffset.FromUnixTimeSeconds(seconds).AddTicks(nanos / NanosPerTick);
    }

    /// <summary>
    /// Writes a duration as a 64-bit whole-second count followed by a 32-bit signed nanosecond
    /// remainder. Both parts carry the sign of <paramref name="value"/>.
    /// </summary>
    public static void WriteDuration(ref WriteContext context, in TimeSpan value)
    {
        long seconds = value.Ticks / TimeSpan.TicksPerSecond;
        int nanos = checked((int)((value.Ticks % TimeSpan.TicksPerSecond) * NanosPerTick));
        context.Writer.WriteInt64(seconds);
        context.Writer.WriteInt32(nanos);
    }

    /// <summary>
    /// Reads the seconds/nanoseconds pair written by <see cref="WriteDuration"/>.
    /// The nanosecond part is truncated to whole ticks (100 ns).
    /// </summary>
    /// <exception cref="OverflowException">The encoded duration does not fit in a <see cref="TimeSpan"/>.</exception>
    public static TimeSpan ReadDuration(ref ReadContext context)
    {
        long seconds = context.Reader.ReadInt64();
        int nanos = context.Reader.ReadInt32();

        // Combine the parts with checked integer tick arithmetic so the result is exact;
        // routing the second count through floating point can be off by ticks for very
        // large durations.
        long ticks = checked((seconds * TimeSpan.TicksPerSecond) + (nanos / NanosPerTick));
        return new TimeSpan(ticks);
    }

    /// <summary>
    /// Splits <paramref name="value"/> into Unix seconds and a nanosecond part in the range
    /// [0, 1_000_000_000).
    /// </summary>
    private static (long Seconds, uint Nanos) ToTimestampParts(DateTimeOffset value)
    {
        long seconds = value.ToUnixTimeSeconds();
        long nanos = (value.Ticks % TimeSpan.TicksPerSecond) * NanosPerTick;
        long normalizedSeconds = seconds + (nanos / NanosPerSecond);
        long normalizedNanos = nanos % NanosPerSecond;
        if (normalizedNanos < 0)
        {
            // Borrow one second so the nanosecond part is never negative.
            normalizedNanos += NanosPerSecond;
            normalizedSeconds -= 1;
        }

        return (normalizedSeconds, unchecked((uint)normalizedNanos));
    }
}

/// <summary>Serializer for <see cref="DateOnly"/> using <see cref="TypeId.Date"/>.</summary>
public sealed class DateOnlySerializer : Serializer<DateOnly>
{
    public override TypeId StaticTypeId => TypeId.Date;

    public override DateOnly DefaultValue => new(1970, 1, 1);

    public override void WriteData(ref WriteContext context, in DateOnly value, bool hasGenerics)
    {
        _ = hasGenerics;
        TimeCodec.WriteDate(ref context, value);
    }

    public override DateOnly ReadData(ref ReadContext context)
    {
        return TimeCodec.ReadDate(ref context);
    }
}

/// <summary>Serializer for <see cref="DateTimeOffset"/> using <see cref="TypeId.Timestamp"/>.</summary>
public sealed class DateTimeOffsetSerializer : Serializer<DateTimeOffset>
{
    public override TypeId StaticTypeId => TypeId.Timestamp;

    public override DateTimeOffset DefaultValue => DateTimeOffset.UnixEpoch;

    public override void WriteData(ref WriteContext context, in DateTimeOffset value, bool hasGenerics)
    {
        _ = hasGenerics;
        TimeCodec.WriteTimestamp(ref context, value);
    }

    public override DateTimeOffset ReadData(ref ReadContext context)
    {
        return TimeCodec.ReadTimestamp(ref context);
    }
}

/// <summary>
/// Serializer for <see cref="DateTime"/> using <see cref="TypeId.Timestamp"/>.
/// Values are converted with <see cref="TimeCodec.ToDateTimeOffset"/> on write and are
/// returned as UTC <see cref="DateTime"/> values on read.
/// </summary>
public sealed class DateTimeSerializer : Serializer<DateTime>
{
    public override TypeId StaticTypeId => TypeId.Timestamp;

    public override DateTime DefaultValue => DateTime.UnixEpoch;

    public override void WriteData(ref WriteContext context, in DateTime value, bool hasGenerics)
    {
        _ = hasGenerics;
        DateTimeOffset dto = TimeCodec.ToDateTimeOffset(value);
        TimeCodec.WriteTimestamp(ref context, dto);
    }

    public override DateTime ReadData(ref ReadContext context)
    {
        return TimeCodec.ReadTimestamp(ref context).UtcDateTime;
    }
}

/// <summary>Serializer for <see cref="TimeSpan"/> using <see cref="TypeId.Duration"/>.</summary>
public sealed class TimeSpanSerializer : Serializer<TimeSpan>
{
    public override TypeId StaticTypeId => TypeId.Duration;

    public override TimeSpan DefaultValue => TimeSpan.Zero;

    public override void WriteData(ref WriteContext context, in TimeSpan value, bool hasGenerics)
    {
        _ = hasGenerics;
        TimeCodec.WriteDuration(ref context, value);
    }

    public override TimeSpan ReadData(ref ReadContext context)
    {
        return TimeCodec.ReadDuration(ref context);
    }
}

/// <summary>
/// List serializer specialised for <see cref="DateOnly"/> elements: writes the list header with
/// element type <see cref="TypeId.Date"/> and then each element's payload. Reading is delegated
/// to the general <see cref="ListSerializer{T}"/>.
/// </summary>
internal sealed class ListDateOnlySerializer : Serializer<List<DateOnly>>
{
    private static readonly ListSerializer<DateOnly> Fallback = new();

    public override TypeId StaticTypeId => TypeId.List;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override List<DateOnly> DefaultValue => null!;

    public override bool IsNone(in List<DateOnly> value) => value is null;

    public override void WriteData(ref WriteContext context, in List<DateOnly> value, bool hasGenerics)
    {
        // A null list is written as an empty one.
        List<DateOnly> list = value ?? [];
        PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Date, false);
        for (int i = 0; i < list.Count; i++)
        {
            TimeCodec.WriteDate(ref context, list[i]);
        }
    }

    public override List<DateOnly> ReadData(ref ReadContext context)
    {
        return Fallback.ReadData(ref context);
    }
}

/// <summary>
/// List serializer specialised for <see cref="DateTimeOffset"/> elements: writes the list header
/// with element type <see cref="TypeId.Timestamp"/> and then each element's payload. Reading is
/// delegated to the general <see cref="ListSerializer{T}"/>.
/// </summary>
internal sealed class ListDateTimeOffsetSerializer : Serializer<List<DateTimeOffset>>
{
    private static readonly ListSerializer<DateTimeOffset> Fallback = new();

    public override TypeId StaticTypeId => TypeId.List;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override List<DateTimeOffset> DefaultValue => null!;

    public override bool IsNone(in List<DateTimeOffset> value) => value is null;

    public override void WriteData(ref WriteContext context, in List<DateTimeOffset> value, bool hasGenerics)
    {
        // A null list is written as an empty one.
        List<DateTimeOffset> list = value ?? [];
        PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Timestamp, false);
        for (int i = 0; i < list.Count; i++)
        {
            TimeCodec.WriteTimestamp(ref context, list[i]);
        }
    }

    public override List<DateTimeOffset> ReadData(ref ReadContext context)
    {
        return Fallback.ReadData(ref context);
    }
}

/// <summary>
/// List serializer specialised for <see cref="DateTime"/> elements: writes the list header with
/// element type <see cref="TypeId.Timestamp"/> and then each element converted through
/// <see cref="TimeCodec.ToDateTimeOffset"/>. Reading is delegated to the general
/// <see cref="ListSerializer{T}"/>.
/// </summary>
internal sealed class ListDateTimeSerializer : Serializer<List<DateTime>>
{
    private static readonly ListSerializer<DateTime> Fallback = new();

    public override TypeId StaticTypeId => TypeId.List;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override List<DateTime> DefaultValue => null!;

    public override bool IsNone(in List<DateTime> value) => value is null;

    public override void WriteData(ref WriteContext context, in List<DateTime> value, bool hasGenerics)
    {
        // A null list is written as an empty one.
        List<DateTime> list = value ?? [];
        PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Timestamp, false);
        for (int i = 0; i < list.Count; i++)
        {
            DateTimeOffset dto = TimeCodec.ToDateTimeOffset(list[i]);
            TimeCodec.WriteTimestamp(ref context, dto);
        }
    }

    public override List<DateTime> ReadData(ref ReadContext context)
    {
        return Fallback.ReadData(ref context);
    }
}

/// <summary>
/// List serializer specialised for <see cref="TimeSpan"/> elements: writes the list header with
/// element type <see cref="TypeId.Duration"/> and then each element's payload. Reading is
/// delegated to the general <see cref="ListSerializer{T}"/>.
/// </summary>
internal sealed class ListTimeSpanSerializer : Serializer<List<TimeSpan>>
{
    private static readonly ListSerializer<TimeSpan> Fallback = new();

    public override TypeId StaticTypeId => TypeId.List;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override List<TimeSpan> DefaultValue => null!;

    public override bool IsNone(in List<TimeSpan> value) => value is null;

    public override void WriteData(ref WriteContext context, in List<TimeSpan> value, bool hasGenerics)
    {
        // A null list is written as an empty one.
        List<TimeSpan> list = value ?? [];
        PrimitiveCollectionHeader.WriteListHeader(ref context, list.Count, hasGenerics, TypeId.Duration, false);
        for (int i = 0; i < list.Count; i++)
        {
            TimeCodec.WriteDuration(ref context, list[i]);
        }
    }

    public override List<TimeSpan> ReadData(ref ReadContext context)
    {
        return Fallback.ReadData(ref context);
    }
}
diff --git a/csharp/src/Fory/TypeId.cs b/csharp/src/Fory/TypeId.cs new file mode 100644 index 0000000000..47857e0915 --- /dev/null +++ b/csharp/src/Fory/TypeId.cs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
namespace Apache.Fory;

/// <summary>
/// Numeric type identifiers. The explicit values are part of the encoded format and must not
/// be renumbered.
/// </summary>
public enum TypeId : uint
{
    Unknown = 0,
    Bool = 1,
    Int8 = 2,
    Int16 = 3,
    Int32 = 4,
    VarInt32 = 5,
    Int64 = 6,
    VarInt64 = 7,
    TaggedInt64 = 8,
    UInt8 = 9,
    UInt16 = 10,
    UInt32 = 11,
    VarUInt32 = 12,
    UInt64 = 13,
    VarUInt64 = 14,
    TaggedUInt64 = 15,
    Float8 = 16,
    Float16 = 17,
    BFloat16 = 18,
    Float32 = 19,
    Float64 = 20,
    String = 21,
    List = 22,
    Set = 23,
    Map = 24,
    Enum = 25,
    NamedEnum = 26,
    Struct = 27,
    CompatibleStruct = 28,
    NamedStruct = 29,
    NamedCompatibleStruct = 30,
    Ext = 31,
    NamedExt = 32,
    Union = 33,
    TypedUnion = 34,
    NamedUnion = 35,
    None = 36,
    Duration = 37,
    Timestamp = 38,
    Date = 39,
    Decimal = 40,
    Binary = 41,
    Array = 42,
    BoolArray = 43,
    Int8Array = 44,
    Int16Array = 45,
    Int32Array = 46,
    Int64Array = 47,
    UInt8Array = 48,
    UInt16Array = 49,
    UInt32Array = 50,
    UInt64Array = 51,
    Float8Array = 52,
    Float16Array = 53,
    BFloat16Array = 54,
    Float32Array = 55,
    Float64Array = 56,
}

/// <summary>Classification helpers for <see cref="TypeId"/> values.</summary>
internal static class TypeIdExtensions
{
    /// <summary>
    /// Returns <c>true</c> for the enum, struct, ext and typed/named union kinds listed below,
    /// and <c>false</c> for every other id (including the plain <see cref="TypeId.Union"/>).
    /// </summary>
    public static bool IsUserTypeKind(this TypeId typeId) =>
        typeId is TypeId.Enum
            or TypeId.NamedEnum
            or TypeId.Struct
            or TypeId.CompatibleStruct
            or TypeId.NamedStruct
            or TypeId.NamedCompatibleStruct
            or TypeId.Ext
            or TypeId.NamedExt
            or TypeId.TypedUnion
            or TypeId.NamedUnion;

    /// <summary>
    /// Returns <c>true</c> for the struct and ext kinds and for <see cref="TypeId.Unknown"/>,
    /// and <c>false</c> for every other id.
    /// </summary>
    public static bool NeedsTypeInfoForField(this TypeId typeId) =>
        typeId is TypeId.Struct
            or TypeId.CompatibleStruct
            or TypeId.NamedStruct
            or TypeId.NamedCompatibleStruct
            or TypeId.Ext
            or TypeId.NamedExt
            or TypeId.Unknown;
}

diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs new file mode 100644 index 0000000000..9630db2163 --- /dev/null +++ b/csharp/src/Fory/TypeInfo.cs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +

namespace Apache.Fory;

/// <summary>
/// Pairs a CLR type with its serializer and a snapshot of the serializer's static traits,
/// taken once at construction time.
/// </summary>
public sealed class TypeInfo
{
    /// <summary>The CLR type this entry describes.</summary>
    public Type Type { get; }

    /// <summary>Value of the serializer's <c>StaticTypeId</c> when this entry was created.</summary>
    public TypeId StaticTypeId { get; }

    /// <summary>Value of the serializer's <c>IsNullableType</c> when this entry was created.</summary>
    public bool IsNullableType { get; }

    /// <summary>Value of the serializer's <c>IsReferenceTrackableType</c> when this entry was created.</summary>
    public bool IsReferenceTrackableType { get; }

    /// <summary>The serializer bound to <see cref="Type"/>.</summary>
    internal Serializer Serializer { get; }

    /// <summary>Registration details for the type; <c>null</c> until something assigns them.</summary>
    internal RegisteredTypeInfo? RegisteredTypeInfo { get; set; }

    /// <summary>
    /// Creates the entry and copies the serializer's type id, nullability and
    /// reference-tracking traits into the corresponding properties.
    /// </summary>
    internal TypeInfo(Type type, Serializer serializer)
    {
        (Type, Serializer) = (type, serializer);
        StaticTypeId = serializer.StaticTypeId;
        IsNullableType = serializer.IsNullableType;
        IsReferenceTrackableType = serializer.IsReferenceTrackableType;
    }
}
diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs new file mode 100644 index 0000000000..d51b7a2e02 --- /dev/null +++ b/csharp/src/Fory/TypeMeta.cs @@ -0,0 +1,690 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +

namespace Apache.Fory;

/// <summary>Bit masks, thresholds and sentinels used by the TypeMeta encoding in this file.</summary>
internal static class TypeMetaConstants
{
    /// <summary>Field counts at or above this value are continued in a following var-uint.</summary>
    public const int SmallNumFieldsThreshold = 0b1_1111;

    /// <summary>Bit of the body's first byte that marks a name-registered type.</summary>
    public const byte RegisterByNameFlag = 0b10_0000;

    /// <summary>Field name sizes / field ids at or above this value are continued in a following var-uint.</summary>
    public const int FieldNameSizeThreshold = 0b1111;

    /// <summary>Namespace / type name lengths at or above this value are continued in a following var-uint.</summary>
    public const int BigNameThreshold = 0b11_1111;

    public const ulong TypeMetaHasFieldsMetaFlag = 1UL << 8;
    public const ulong TypeMetaCompressedFlag = 1UL << 9;

    /// <summary>Low header bits holding the body size; the all-ones value means "continued in a var-uint".</summary>
    public const ulong TypeMetaSizeMask = 0xFF;

    /// <summary>Number of high header bits that carry the body hash.</summary>
    public const ulong TypeMetaNumHashBits = 50;

    public const ulong TypeMetaHashSeed = 47;

    /// <summary>Sentinel rejected as a user type id in register-by-id mode.</summary>
    public const uint NoUserTypeId = uint.MaxValue;
}

/// <summary>
/// The meta-string encodings allowed for each kind of name. The position of an encoding inside
/// its array is the index written to the stream, so the order must not change.
/// </summary>
public static class TypeMetaEncodings
{
    public static readonly MetaStringEncoding[] NamespaceMetaStringEncodings =
    [
        MetaStringEncoding.Utf8,
        MetaStringEncoding.AllToLowerSpecial,
        MetaStringEncoding.LowerUpperDigitSpecial,
    ];

    public static readonly MetaStringEncoding[] TypeNameMetaStringEncodings =
    [
        MetaStringEncoding.Utf8,
        MetaStringEncoding.AllToLowerSpecial,
        MetaStringEncoding.LowerUpperDigitSpecial,
        MetaStringEncoding.FirstToLowerSpecial,
    ];

    public static readonly MetaStringEncoding[] FieldNameMetaStringEncodings =
    [
        MetaStringEncoding.Utf8,
        MetaStringEncoding.AllToLowerSpecial,
        MetaStringEncoding.LowerUpperDigitSpecial,
    ];
}

/// <summary>Small helpers shared by the TypeMeta writers and readers.</summary>
internal static class TypeMetaUtils
{
    /// <summary>
    /// Returns the position of <paramref name="encoding"/> in <paramref name="encodings"/>,
    /// or -1 when it is not present.
    /// </summary>
    public static int EncodingIndexOf(IReadOnlyList<MetaStringEncoding> encodings, MetaStringEncoding encoding)
    {
        for (int i = 0; i < encodings.Count; i++)
        {
            if (encodings[i] == encoding)
            {
                return i;
            }
        }

        return -1;
    }

    /// <summary>
    /// Lower-cases every upper-case character of <paramref name="name"/> and inserts an
    /// underscore before it unless it is the first character or sits inside/at the end of a run
    /// of upper-case characters (so <c>userID</c> becomes <c>user_id</c> and
    /// <c>HTTPServer</c> becomes <c>http_server</c>).
    /// </summary>
    public static string LowerCamelToLowerUnderscore(string name)
    {
        if (name.Length == 0)
        {
            return name;
        }

        var sb = new System.Text.StringBuilder(name.Length + 4);
        for (int i = 0; i < name.Length; i++)
        {
            char c = name[i];
            if (!char.IsUpper(c))
            {
                sb.Append(c);
                continue;
            }

            if (i > 0)
            {
                bool prevUpper = char.IsUpper(name[i - 1]);
                bool nextUpperOrEnd = i + 1 >= name.Length || char.IsUpper(name[i + 1]);

                // A word boundary is either lower->UPPER or the last capital of an
                // acronym that is followed by a lower-case character.
                if (!prevUpper || !nextUpperOrEnd)
                {
                    sb.Append('_');
                }
            }

            sb.Append(char.ToLowerInvariant(c));
        }

        return sb.ToString();
    }
}

/// <summary>
/// Type descriptor of a field: type id, nullable / track-ref flags and, for list, set and map
/// types, the descriptors of the element or key/value types.
/// </summary>
public sealed class TypeMetaFieldType : IEquatable<TypeMetaFieldType>
{
    private const uint ListTypeId = (uint)global::Apache.Fory.TypeId.List;
    private const uint SetTypeId = (uint)global::Apache.Fory.TypeId.Set;
    private const uint MapTypeId = (uint)global::Apache.Fory.TypeId.Map;

    public TypeMetaFieldType(
        uint typeId,
        bool nullable,
        bool trackRef = false,
        IReadOnlyList<TypeMetaFieldType>? generics = null)
    {
        TypeId = typeId;
        Nullable = nullable;
        TrackRef = trackRef;
        Generics = generics ?? [];
    }

    public uint TypeId { get; }

    public bool Nullable { get; }

    public bool TrackRef { get; }

    public IReadOnlyList<TypeMetaFieldType> Generics { get; }

    /// <summary>
    /// Writes this descriptor. With <paramref name="writeFlags"/> the type id is written as a
    /// var-uint shifted left by two with the nullable (bit 1) and track-ref (bit 0) flags;
    /// without it only the low byte of the type id is written. List/set types are followed by
    /// their element descriptor and map types by their key and value descriptors, always with
    /// flags. A missing generic argument is written as a nullable <c>Unknown</c> type.
    /// </summary>
    internal void Write(ByteWriter writer, bool writeFlags, bool? nullableOverride = null)
    {
        if (writeFlags)
        {
            uint header = TypeId << 2;
            if (nullableOverride ?? Nullable)
            {
                header |= 0b10;
            }

            if (TrackRef)
            {
                header |= 0b1;
            }

            writer.WriteVarUInt32(header);
        }
        else
        {
            writer.WriteUInt8(unchecked((byte)TypeId));
        }

        if (TypeId is ListTypeId or SetTypeId)
        {
            TypeMetaFieldType element = GenericOrUnknown(0);
            element.Write(writer, true, element.Nullable);
        }
        else if (TypeId == MapTypeId)
        {
            TypeMetaFieldType key = GenericOrUnknown(0);
            TypeMetaFieldType value = GenericOrUnknown(1);
            key.Write(writer, true, key.Nullable);
            value.Write(writer, true, value.Nullable);
        }
    }

    /// <summary>
    /// Reads a descriptor written by <see cref="Write"/>. When <paramref name="readFlags"/> is
    /// <c>false</c> the flags are not in the stream and are taken from
    /// <paramref name="nullable"/> / <paramref name="trackRef"/> (defaulting to <c>false</c>).
    /// </summary>
    internal static TypeMetaFieldType Read(
        ByteReader reader,
        bool readFlags,
        bool? nullable = null,
        bool? trackRef = null)
    {
        uint header = readFlags ? reader.ReadVarUInt32() : reader.ReadUInt8();

        uint typeId;
        bool resolvedNullable;
        bool resolvedTrackRef;
        if (readFlags)
        {
            typeId = header >> 2;
            resolvedNullable = (header & 0b10) != 0;
            resolvedTrackRef = (header & 0b1) != 0;
        }
        else
        {
            typeId = header;
            resolvedNullable = nullable ?? false;
            resolvedTrackRef = trackRef ?? false;
        }

        if (typeId is ListTypeId or SetTypeId)
        {
            TypeMetaFieldType element = Read(reader, true);
            return new TypeMetaFieldType(typeId, resolvedNullable, resolvedTrackRef, [element]);
        }

        if (typeId == MapTypeId)
        {
            TypeMetaFieldType key = Read(reader, true);
            TypeMetaFieldType value = Read(reader, true);
            return new TypeMetaFieldType(typeId, resolvedNullable, resolvedTrackRef, [key, value]);
        }

        return new TypeMetaFieldType(typeId, resolvedNullable, resolvedTrackRef);
    }

    public bool Equals(TypeMetaFieldType? other)
    {
        if (other is null)
        {
            return false;
        }

        return TypeId == other.TypeId &&
               Nullable == other.Nullable &&
               TrackRef == other.TrackRef &&
               Generics.SequenceEqual(other.Generics);
    }

    public override bool Equals(object? obj)
    {
        return obj is TypeMetaFieldType other && Equals(other);
    }

    public override int GetHashCode()
    {
        HashCode hc = new();
        hc.Add(TypeId);
        hc.Add(Nullable);
        hc.Add(TrackRef);
        foreach (TypeMetaFieldType t in Generics)
        {
            hc.Add(t);
        }

        return hc.ToHashCode();
    }

    /// <summary>Returns the generic argument at <paramref name="index"/>, or a nullable Unknown type when absent.</summary>
    private TypeMetaFieldType GenericOrUnknown(int index)
    {
        return Generics.Count > index
            ? Generics[index]
            : new TypeMetaFieldType((uint)global::Apache.Fory.TypeId.Unknown, true);
    }
}

/// <summary>
/// One field entry of a <see cref="TypeMeta"/>: an optional numeric field id, the field name
/// and the field's type descriptor.
/// </summary>
public sealed class TypeMetaFieldInfo : IEquatable<TypeMetaFieldInfo>
{
    /// <summary>Value of the two encoding bits that marks "identified by field id, no name bytes".</summary>
    private const int FieldIdEncodingFlag = 0b11;

    /// <summary>Header bits 5..2 all set: the size / id continues in a following var-uint.</summary>
    private const byte ExtendedSizeBits = 0b0011_1100;

    public TypeMetaFieldInfo(short? fieldId, string fieldName, TypeMetaFieldType fieldType)
    {
        FieldId = fieldId;
        FieldName = fieldName;
        FieldType = fieldType;
    }

    public short? FieldId { get; }

    public string FieldName { get; }

    public TypeMetaFieldType FieldType { get; }

    /// <summary>
    /// Writes the field entry. Header byte layout: bits 7..6 name-encoding index (3 means the
    /// field is identified by id), bits 5..2 size (all ones means continued in a var-uint),
    /// bit 1 nullable, bit 0 track-ref. In id mode the size bits carry the field id; in name mode
    /// they carry the encoded name length minus one and the name bytes follow the field type.
    /// </summary>
    /// <exception cref="EncodingException">
    /// The field id is negative, the field name encodes to zero bytes, or the chosen name
    /// encoding is not one of <see cref="TypeMetaEncodings.FieldNameMetaStringEncodings"/>.
    /// </exception>
    internal void Write(ByteWriter writer)
    {
        byte header = 0;
        if (FieldType.TrackRef)
        {
            header |= 0b1;
        }

        if (FieldType.Nullable)
        {
            header |= 0b10;
        }

        if (FieldId.HasValue)
        {
            short fieldId = FieldId.Value;
            if (fieldId < 0)
            {
                throw new EncodingException("negative field id is invalid");
            }

            header |= FieldIdEncodingFlag << 6;
            WriteHeaderWithSize(writer, header, fieldId);
            FieldType.Write(writer, false);
            return;
        }

        string snakeName = TypeMetaUtils.LowerCamelToLowerUnderscore(FieldName);
        MetaString encoded = MetaStringEncoder.FieldName.Encode(snakeName, TypeMetaEncodings.FieldNameMetaStringEncodings);
        int encodingIndex = Array.IndexOf(TypeMetaEncodings.FieldNameMetaStringEncodings, encoded.Encoding);
        if (encodingIndex < 0)
        {
            throw new EncodingException("unsupported field name encoding");
        }

        // The size is stored as (length - 1); a zero-length name would become -1 and
        // spill into the encoding and flag bits of the header.
        if (encoded.Bytes.Length == 0)
        {
            throw new EncodingException("field name must not be empty");
        }

        int encodedSize = encoded.Bytes.Length - 1;
        header |= (byte)(encodingIndex << 6);
        WriteHeaderWithSize(writer, header, encodedSize);
        FieldType.Write(writer, false);
        writer.WriteBytes(encoded.Bytes);
    }

    /// <summary>Reads a field entry written by <see cref="Write"/>.</summary>
    /// <remarks>Fields read in id mode get the synthetic name <c>$tag{id}</c>.</remarks>
    /// <exception cref="InvalidDataException">
    /// The decoded field id does not fit in a non-negative <see cref="short"/>, or the
    /// name-encoding index is not valid for field names.
    /// </exception>
    internal static TypeMetaFieldInfo Read(ByteReader reader)
    {
        byte header = reader.ReadUInt8();
        int encodingFlags = (header >> 6) & 0b11;
        int size = (header >> 2) & 0b1111;
        if (size == TypeMetaConstants.FieldNameSizeThreshold)
        {
            size += (int)reader.ReadVarUInt32();
        }

        size += 1;

        bool nullable = (header & 0b10) != 0;
        bool trackRef = (header & 0b1) != 0;
        TypeMetaFieldType fieldType = TypeMetaFieldType.Read(reader, false, nullable, trackRef);

        if (encodingFlags == FieldIdEncodingFlag)
        {
            // Reject ids that would otherwise wrap silently when narrowed to short.
            int decodedId = size - 1;
            if (decodedId < 0 || decodedId > short.MaxValue)
            {
                throw new InvalidDataException("field id out of range");
            }

            short fieldId = (short)decodedId;
            return new TypeMetaFieldInfo(fieldId, $"$tag{fieldId}", fieldType);
        }

        if (encodingFlags >= TypeMetaEncodings.FieldNameMetaStringEncodings.Length)
        {
            throw new InvalidDataException("invalid field name encoding id");
        }

        byte[] nameBytes = reader.ReadBytes(size);
        string name = MetaStringDecoder.FieldName.Decode(
            nameBytes,
            TypeMetaEncodings.FieldNameMetaStringEncodings[encodingFlags]).Value;
        return new TypeMetaFieldInfo(null, name, fieldType);
    }

    public bool Equals(TypeMetaFieldInfo? other)
    {
        if (other is null)
        {
            return false;
        }

        return FieldId == other.FieldId &&
               FieldName == other.FieldName &&
               FieldType.Equals(other.FieldType);
    }

    public override bool Equals(object? obj)
    {
        return obj is TypeMetaFieldInfo other && Equals(other);
    }

    public override int GetHashCode()
    {
        return HashCode.Combine(FieldId, FieldName, FieldType);
    }

    /// <summary>
    /// Stores <paramref name="size"/> in header bits 5..2 and writes the header, spilling the
    /// part at or above <see cref="TypeMetaConstants.FieldNameSizeThreshold"/> into a var-uint.
    /// </summary>
    private static void WriteHeaderWithSize(ByteWriter writer, byte header, int size)
    {
        if (size >= TypeMetaConstants.FieldNameSizeThreshold)
        {
            header |= ExtendedSizeBits;
            writer.WriteUInt8(header);
            writer.WriteVarUInt32((uint)(size - TypeMetaConstants.FieldNameSizeThreshold));
        }
        else
        {
            header |= (byte)(size << 2);
            writer.WriteUInt8(header);
        }
    }
}

/// <summary>
/// Description of a type as exchanged between peers: how it is registered (by id or by
/// namespace/name) and the list of its fields.
/// </summary>
public sealed class TypeMeta : IEquatable<TypeMeta>
{
    /// <summary>Number of low header bits that do not belong to the hash.</summary>
    private const int HashShift = (int)(64 - TypeMetaConstants.TypeMetaNumHashBits);

    /// <exception cref="EncodingException">
    /// In register-by-name mode the type name is empty; in register-by-id mode the type id or
    /// user type id is missing, or the user type id is <see cref="TypeMetaConstants.NoUserTypeId"/>.
    /// </exception>
    public TypeMeta(
        uint? typeId,
        uint? userTypeId,
        MetaString namespaceName,
        MetaString typeName,
        bool registerByName,
        IReadOnlyList<TypeMetaFieldInfo> fields,
        bool hasFieldsMeta = true,
        bool compressed = false,
        ulong headerHash = 0)
    {
        if (registerByName)
        {
            if (typeName.Value.Length == 0)
            {
                throw new EncodingException("type name is required in register-by-name mode");
            }
        }
        else
        {
            if (!typeId.HasValue)
            {
                throw new EncodingException("type id is required in register-by-id mode");
            }

            if (!userTypeId.HasValue || userTypeId.Value == TypeMetaConstants.NoUserTypeId)
            {
                throw new EncodingException("user type id is required in register-by-id mode");
            }
        }

        TypeId = typeId;
        UserTypeId = userTypeId;
        NamespaceName = namespaceName;
        TypeName = typeName;
        RegisterByName = registerByName;
        Fields = fields;
        HasFieldsMeta = hasFieldsMeta;
        Compressed = compressed;
        HeaderHash = headerHash;
    }

    public uint? TypeId { get; }

    public uint? UserTypeId { get; }

    public MetaString NamespaceName { get; }

    public MetaString TypeName { get; }

    public bool RegisterByName { get; }

    public IReadOnlyList<TypeMetaFieldInfo> Fields { get; }

    public bool HasFieldsMeta { get; }

    public bool Compressed { get; }

    public ulong HeaderHash { get; }

    /// <summary>
    /// Encodes this meta as a 64-bit header followed by the body. The header holds the body
    /// size in its low 8 bits (0xFF means the remainder follows as a var-uint), the
    /// has-fields-meta flag at bit 8 and, in the high bits, the absolute value of the shifted
    /// MurmurHash3 of the body.
    /// </summary>
    /// <exception cref="EncodingException"><see cref="Compressed"/> is set; compression is not supported.</exception>
    public byte[] Encode()
    {
        if (Compressed)
        {
            throw new EncodingException("compressed TypeMeta is not supported yet");
        }

        byte[] body = EncodeBody();
        (ulong bodyHash, _) = MurmurHash3.X64_128(body, TypeMetaConstants.TypeMetaHashSeed);
        ulong shifted = bodyHash << HashShift;
        long signed = unchecked((long)shifted);

        // Math.Abs throws for long.MinValue, so that value is passed through unchanged.
        long absSigned = signed == long.MinValue ? signed : Math.Abs(signed);

        ulong header = unchecked((ulong)absSigned);
        if (HasFieldsMeta)
        {
            header |= TypeMetaConstants.TypeMetaHasFieldsMetaFlag;
        }

        // The compressed flag is never set here: compressed metas were rejected above.
        int sizeMask = (int)TypeMetaConstants.TypeMetaSizeMask;
        header |= (ulong)Math.Min(body.Length, sizeMask);
        ByteWriter writer = new(body.Length + 16);
        writer.WriteUInt64(header);
        if (body.Length >= sizeMask)
        {
            writer.WriteVarUInt32((uint)(body.Length - sizeMask));
        }

        writer.WriteBytes(body);
        return writer.ToArray();
    }

    /// <summary>Decodes a meta from <paramref name="bytes"/>; see <see cref="Decode(ByteReader)"/>.</summary>
    public static TypeMeta Decode(byte[] bytes)
    {
        return Decode(new ByteReader(bytes));
    }

    /// <summary>
    /// Decodes a meta written by <see cref="Encode"/>. <see cref="HeaderHash"/> of the result is
    /// the header shifted right past the size and flag bits.
    /// </summary>
    /// <exception cref="EncodingException">The header has the compressed flag set.</exception>
    /// <exception cref="InvalidDataException">The body contains bytes after the last field.</exception>
    public static TypeMeta Decode(ByteReader reader)
    {
        ulong header = reader.ReadUInt64();
        bool compressed = (header & TypeMetaConstants.TypeMetaCompressedFlag) != 0;
        bool hasFieldsMeta = (header & TypeMetaConstants.TypeMetaHasFieldsMetaFlag) != 0;
        int metaSize = (int)(header & TypeMetaConstants.TypeMetaSizeMask);
        if (metaSize == (int)TypeMetaConstants.TypeMetaSizeMask)
        {
            metaSize += (int)reader.ReadVarUInt32();
        }

        // The body is consumed before the compressed check, as in the encoded layout.
        byte[] encodedBody = reader.ReadBytes(metaSize);
        if (compressed)
        {
            throw new EncodingException("compressed TypeMeta is not supported yet");
        }

        ByteReader bodyReader = new(encodedBody);
        byte metaHeader = bodyReader.ReadUInt8();
        int numFields = metaHeader & TypeMetaConstants.SmallNumFieldsThreshold;
        if (numFields == TypeMetaConstants.SmallNumFieldsThreshold)
        {
            numFields += (int)bodyReader.ReadVarUInt32();
        }

        bool registerByName = (metaHeader & TypeMetaConstants.RegisterByNameFlag) != 0;
        uint? typeId;
        uint? userTypeId;
        MetaString namespaceName;
        MetaString typeName;
        if (registerByName)
        {
            namespaceName = ReadName(bodyReader, MetaStringDecoder.Namespace, TypeMetaEncodings.NamespaceMetaStringEncodings);
            typeName = ReadName(bodyReader, MetaStringDecoder.TypeName, TypeMetaEncodings.TypeNameMetaStringEncodings);
            typeId = null;
            userTypeId = null;
        }
        else
        {
            typeId = bodyReader.ReadUInt8();
            userTypeId = bodyReader.ReadVarUInt32();
            namespaceName = MetaString.Empty('.', '_');
            typeName = MetaString.Empty('$', '_');
        }

        List<TypeMetaFieldInfo> fields = new(numFields);
        for (int i = 0; i < numFields; i++)
        {
            fields.Add(TypeMetaFieldInfo.Read(bodyReader));
        }

        if (bodyReader.Remaining != 0)
        {
            throw new InvalidDataException("unexpected trailing bytes in TypeMeta body");
        }

        return new TypeMeta(
            typeId,
            userTypeId,
            namespaceName,
            typeName,
            registerByName,
            fields,
            hasFieldsMeta,
            compressed,
            header >> HashShift);
    }

    /// <summary>
    /// Encodes the body: a header byte (field count in the low 5 bits, register-by-name flag),
    /// an optional var-uint field-count continuation, then either the namespace and type names
    /// or the type id byte and user type id var-uint, followed by every field entry.
    /// </summary>
    private byte[] EncodeBody()
    {
        ByteWriter writer = new(128);
        byte metaHeader = (byte)Math.Min(Fields.Count, TypeMetaConstants.SmallNumFieldsThreshold);
        if (RegisterByName)
        {
            metaHeader |= TypeMetaConstants.RegisterByNameFlag;
        }

        writer.WriteUInt8(metaHeader);
        if (Fields.Count >= TypeMetaConstants.SmallNumFieldsThreshold)
        {
            writer.WriteVarUInt32((uint)(Fields.Count - TypeMetaConstants.SmallNumFieldsThreshold));
        }

        if (RegisterByName)
        {
            WriteName(writer, NamespaceName, TypeMetaEncodings.NamespaceMetaStringEncodings);
            WriteName(writer, TypeName, TypeMetaEncodings.TypeNameMetaStringEncodings);
        }
        else
        {
            if (!TypeId.HasValue)
            {
                throw new EncodingException("type id is required in register-by-id mode");
            }

            if (!UserTypeId.HasValue || UserTypeId == TypeMetaConstants.NoUserTypeId)
            {
                throw new EncodingException("user type id is required in register-by-id mode");
            }

            writer.WriteUInt8(unchecked((byte)TypeId.Value));
            writer.WriteVarUInt32(UserTypeId.Value);
        }

        foreach (TypeMetaFieldInfo field in Fields)
        {
            field.Write(writer);
        }

        return writer.ToArray();
    }

    /// <summary>
    /// Writes a namespace or type name: one byte holding the length in its upper six bits and
    /// the encoding index in the lower two (lengths at or above
    /// <see cref="TypeMetaConstants.BigNameThreshold"/> continue in a var-uint), then the bytes.
    /// A name whose encoding is not in <paramref name="encodings"/> is re-encoded first.
    /// </summary>
    private static void WriteName(ByteWriter writer, MetaString name, IReadOnlyList<MetaStringEncoding> encodings)
    {
        MetaString normalized = TypeMetaUtils.EncodingIndexOf(encodings, name.Encoding) >= 0
            ? name
            : (encodings.SequenceEqual(TypeMetaEncodings.NamespaceMetaStringEncodings)
                ? MetaStringEncoder.Namespace.Encode(name.Value, encodings)
                : encodings.SequenceEqual(TypeMetaEncodings.TypeNameMetaStringEncodings)
                    ? MetaStringEncoder.TypeName.Encode(name.Value, encodings)
                    : MetaStringEncoder.FieldName.Encode(name.Value, encodings));

        int encodingIndex = TypeMetaUtils.EncodingIndexOf(encodings, normalized.Encoding);
        if (encodingIndex < 0)
        {
            throw new EncodingException("failed to normalize meta string encoding");
        }

        byte[] bytes = normalized.Bytes;
        if (bytes.Length >= TypeMetaConstants.BigNameThreshold)
        {
            writer.WriteUInt8((byte)((TypeMetaConstants.BigNameThreshold << 2) | encodingIndex));
            writer.WriteVarUInt32((uint)(bytes.Length - TypeMetaConstants.BigNameThreshold));
        }
        else
        {
            writer.WriteUInt8((byte)((bytes.Length << 2) | encodingIndex));
        }

        writer.WriteBytes(bytes);
    }

    /// <summary>Reads a name written by <see cref="WriteName"/> and decodes it with <paramref name="decoder"/>.</summary>
    /// <exception cref="InvalidDataException">The encoding index is outside <paramref name="encodings"/>.</exception>
    private static MetaString ReadName(ByteReader reader, MetaStringDecoder decoder, IReadOnlyList<MetaStringEncoding> encodings)
    {
        byte header = reader.ReadUInt8();
        int encodingIndex = header & 0b11;
        if (encodingIndex >= encodings.Count)
        {
            throw new InvalidDataException("invalid meta string encoding index");
        }

        int length = header >> 2;
        if (length >= TypeMetaConstants.BigNameThreshold)
        {
            length = TypeMetaConstants.BigNameThreshold + (int)reader.ReadVarUInt32();
        }

        byte[] bytes = reader.ReadBytes(length);
        return decoder.Decode(bytes, encodings[encodingIndex]);
    }

    public bool Equals(TypeMeta? other)
    {
        if (other is null)
        {
            return false;
        }

        return TypeId == other.TypeId &&
               UserTypeId == other.UserTypeId &&
               NamespaceName.Equals(other.NamespaceName) &&
               TypeName.Equals(other.TypeName) &&
               RegisterByName == other.RegisterByName &&
               Fields.SequenceEqual(other.Fields) &&
               HasFieldsMeta == other.HasFieldsMeta &&
               Compressed == other.Compressed &&
               HeaderHash == other.HeaderHash;
    }

    public override bool Equals(object? obj)
    {
        return obj is TypeMeta other && Equals(other);
    }

    public override int GetHashCode()
    {
        HashCode hc = new();
        hc.Add(TypeId);
        hc.Add(UserTypeId);
        hc.Add(NamespaceName);
        hc.Add(TypeName);
        hc.Add(RegisterByName);
        hc.Add(HasFieldsMeta);
        hc.Add(Compressed);
        hc.Add(HeaderHash);
        foreach (TypeMetaFieldInfo f in Fields)
        {
            hc.Add(f);
        }

        return hc.ToHashCode();
    }
}
diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs new file mode 100644 index 0000000000..82078276b3 --- /dev/null +++ b/csharp/src/Fory/TypeResolver.cs @@ -0,0 +1,1497 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Collections.Concurrent; + +namespace Apache.Fory; + +internal sealed record RegisteredTypeInfo( + uint?
UserTypeId, + TypeId Kind, + bool RegisterByName, + MetaString? NamespaceName, + MetaString TypeName, + Serializer Serializer); + +internal enum DynamicRegistrationMode +{ + IdOnly, + NameOnly, + Mixed, +} + +internal readonly record struct TypeNameKey(string NamespaceName, string TypeName); + +internal sealed class TypeReader +{ + public required TypeId Kind { get; init; } + public required Func Reader { get; init; } + public required Func CompatibleReader { get; init; } +} + +public sealed class TypeResolver +{ + private static readonly ConcurrentDictionary> GeneratedFactories = new(); + + private readonly Dictionary _byUserTypeId = []; + private readonly Dictionary _byTypeName = []; + private readonly Dictionary _registrationModeByKind = []; + + private readonly Dictionary _typeInfos = []; + + public static void RegisterGenerated() + where TSerializer : Serializer, new() + { + Type type = typeof(T); + GeneratedFactories[type] = CreateSerializer; + } + + public Serializer GetSerializer() + { + return GetTypeInfo().Serializer.RequireSerializer(); + } + + public TypeInfo GetTypeInfo(Type type) + { + return GetOrCreateTypeInfo(type, null); + } + + public TypeInfo GetTypeInfo() + { + return GetTypeInfo(typeof(T)); + } + + private TypeInfo GetOrCreateTypeInfo(Type type, Serializer? explicitSerializer) + { + if (_typeInfos.TryGetValue(type, out TypeInfo? existing)) + { + if (explicitSerializer is null || ReferenceEquals(existing.Serializer, explicitSerializer)) + { + return existing; + } + + if (existing.RegisteredTypeInfo is not null) + { + throw new InvalidDataException($"cannot override serializer for registered type {type}"); + } + } + + Serializer serializer = explicitSerializer ?? CreateBindingCore(type); + TypeInfo typeInfo = new(type, serializer); + if (_typeInfos.TryGetValue(type, out TypeInfo? 
previous)) + { + typeInfo.RegisteredTypeInfo = previous.RegisteredTypeInfo; + } + + _typeInfos[type] = typeInfo; + return typeInfo; + } + + internal Serializer GetSerializer(Type type) + { + return GetOrCreateTypeInfo(type, null).Serializer; + } + + internal Serializer RegisterSerializer() + where TSerializer : Serializer, new() + { + Serializer serializerBinding = CreateSerializer(); + RegisterSerializer(typeof(T), serializerBinding); + return serializerBinding; + } + + internal void RegisterSerializer(Type type, Serializer serializerBinding) + { + GetOrCreateTypeInfo(type, serializerBinding); + } + + internal void Register(Type type, uint id, Serializer? explicitSerializer = null) + { + TypeInfo typeInfo = GetOrCreateTypeInfo(type, explicitSerializer); + Serializer serializer = typeInfo.Serializer; + RegisteredTypeInfo info = new( + id, + serializer.StaticTypeId, + false, + null, + MetaString.Empty('$', '_'), + serializer); + typeInfo.RegisteredTypeInfo = info; + MarkRegistrationMode(info.Kind, false); + _byUserTypeId[id] = new TypeReader + { + Kind = serializer.StaticTypeId, + Reader = context => serializer.ReadObject(ref context, RefMode.None, false), + CompatibleReader = (context, typeMeta) => + { + context.PushCompatibleTypeMeta(type, typeMeta); + return serializer.ReadObject(ref context, RefMode.None, false); + }, + }; + } + + internal void Register(Type type, string namespaceName, string typeName, Serializer? 
explicitSerializer = null) + { + TypeInfo typeInfo = GetOrCreateTypeInfo(type, explicitSerializer); + Serializer serializer = typeInfo.Serializer; + MetaString namespaceMeta = MetaStringEncoder.Namespace.Encode(namespaceName, TypeMetaEncodings.NamespaceMetaStringEncodings); + MetaString typeNameMeta = MetaStringEncoder.TypeName.Encode(typeName, TypeMetaEncodings.TypeNameMetaStringEncodings); + RegisteredTypeInfo info = new( + null, + serializer.StaticTypeId, + true, + namespaceMeta, + typeNameMeta, + serializer); + typeInfo.RegisteredTypeInfo = info; + MarkRegistrationMode(info.Kind, true); + _byTypeName[new TypeNameKey(namespaceName, typeName)] = new TypeReader + { + Kind = serializer.StaticTypeId, + Reader = context => serializer.ReadObject(ref context, RefMode.None, false), + CompatibleReader = (context, typeMeta) => + { + context.PushCompatibleTypeMeta(type, typeMeta); + return serializer.ReadObject(ref context, RefMode.None, false); + }, + }; + } + + internal RegisteredTypeInfo? GetRegisteredTypeInfo(Type type) + { + return _typeInfos.TryGetValue(type, out TypeInfo? typeInfo) + ? typeInfo.RegisteredTypeInfo + : null; + } + + internal RegisteredTypeInfo RequireRegisteredTypeInfo(Type type) + { + RegisteredTypeInfo? 
info = GetRegisteredTypeInfo(type); + if (info is not null) + { + return info; + } + + throw new TypeNotRegisteredException($"{type} is not registered"); + } + + internal void WriteTypeInfo(Type type, Serializer serializer, ref WriteContext context) + { + TypeId staticTypeId = serializer.StaticTypeId; + if (!staticTypeId.IsUserTypeKind()) + { + context.Writer.WriteUInt8((byte)staticTypeId); + return; + } + + RegisteredTypeInfo info = RequireRegisteredTypeInfo(type); + TypeId wireTypeId = ResolveWireTypeId(info.Kind, info.RegisterByName, context.Compatible); + context.Writer.WriteUInt8((byte)wireTypeId); + switch (wireTypeId) + { + case TypeId.CompatibleStruct: + case TypeId.NamedCompatibleStruct: + { + TypeMeta typeMeta = BuildCompatibleTypeMeta(info, wireTypeId, context.TrackRef); + context.WriteCompatibleTypeMeta(type, typeMeta); + return; + } + case TypeId.NamedEnum: + case TypeId.NamedStruct: + case TypeId.NamedExt: + case TypeId.NamedUnion: + { + if (context.Compatible) + { + TypeMeta typeMeta = BuildCompatibleTypeMeta(info, wireTypeId, context.TrackRef); + context.WriteCompatibleTypeMeta(type, typeMeta); + } + else + { + if (info.NamespaceName is null) + { + throw new InvalidDataException("missing namespace metadata for name-registered type"); + } + + WriteMetaString( + ref context, + info.NamespaceName.Value, + TypeMetaEncodings.NamespaceMetaStringEncodings, + MetaStringEncoder.Namespace); + WriteMetaString( + ref context, + info.TypeName, + TypeMetaEncodings.TypeNameMetaStringEncodings, + MetaStringEncoder.TypeName); + } + + return; + } + default: + if (!info.RegisterByName && WireTypeNeedsUserTypeId(wireTypeId)) + { + if (!info.UserTypeId.HasValue) + { + throw new InvalidDataException("missing user type id for id-registered type"); + } + + context.Writer.WriteVarUInt32(info.UserTypeId.Value); + } + + return; + } + } + + internal void ReadTypeInfo(Type type, Serializer serializer, ref ReadContext context) + { + uint rawTypeId = 
context.Reader.ReadVarUInt32(); + if (!Enum.IsDefined(typeof(TypeId), rawTypeId)) + { + throw new InvalidDataException($"unknown type id {rawTypeId}"); + } + + TypeId typeId = (TypeId)rawTypeId; + TypeId staticTypeId = serializer.StaticTypeId; + if (!staticTypeId.IsUserTypeKind()) + { + if (typeId != staticTypeId) + { + throw new TypeMismatchException((uint)staticTypeId, rawTypeId); + } + + return; + } + + RegisteredTypeInfo info = RequireRegisteredTypeInfo(type); + HashSet allowed = AllowedWireTypeIds(info.Kind, info.RegisterByName, context.Compatible); + if (!allowed.Contains(typeId)) + { + uint expected = 0; + foreach (TypeId allowedType in allowed) + { + expected = (uint)allowedType; + break; + } + + throw new TypeMismatchException(expected, rawTypeId); + } + + switch (typeId) + { + case TypeId.CompatibleStruct: + case TypeId.NamedCompatibleStruct: + { + TypeMeta remoteTypeMeta = context.ReadCompatibleTypeMeta(); + ValidateCompatibleTypeMeta(remoteTypeMeta, info, allowed, typeId); + context.PushCompatibleTypeMeta(type, remoteTypeMeta); + return; + } + case TypeId.NamedEnum: + case TypeId.NamedStruct: + case TypeId.NamedExt: + case TypeId.NamedUnion: + { + if (context.Compatible) + { + TypeMeta remoteTypeMeta = context.ReadCompatibleTypeMeta(); + ValidateCompatibleTypeMeta(remoteTypeMeta, info, allowed, typeId); + if (typeId == TypeId.NamedStruct) + { + context.PushCompatibleTypeMeta(type, remoteTypeMeta); + } + } + else + { + MetaString namespaceName = ReadMetaString( + ref context, + MetaStringDecoder.Namespace, + TypeMetaEncodings.NamespaceMetaStringEncodings); + MetaString typeName = ReadMetaString( + ref context, + MetaStringDecoder.TypeName, + TypeMetaEncodings.TypeNameMetaStringEncodings); + if (!info.RegisterByName || info.NamespaceName is null) + { + throw new InvalidDataException("received name-registered type info for id-registered local type"); + } + + if (namespaceName.Value != info.NamespaceName.Value.Value || typeName.Value != info.TypeName.Value) 
+ { + throw new InvalidDataException( + $"type name mismatch: expected {info.NamespaceName.Value.Value}::{info.TypeName.Value}, got {namespaceName.Value}::{typeName.Value}"); + } + } + + return; + } + default: + if (!info.RegisterByName && WireTypeNeedsUserTypeId(typeId)) + { + if (!info.UserTypeId.HasValue) + { + throw new InvalidDataException("missing user type id for id-registered local type"); + } + + uint remoteUserTypeId = context.Reader.ReadVarUInt32(); + if (remoteUserTypeId != info.UserTypeId.Value) + { + throw new TypeMismatchException(info.UserTypeId.Value, remoteUserTypeId); + } + } + + return; + } + } + + internal static TypeId ResolveWireTypeId(TypeId declaredKind, bool registerByName, bool compatible) + { + TypeId baseKind = NormalizeBaseKind(declaredKind); + if (registerByName) + { + return NamedKind(baseKind, compatible); + } + + return IdKind(baseKind, compatible); + } + + internal static HashSet AllowedWireTypeIds(TypeId declaredKind, bool registerByName, bool compatible) + { + TypeId baseKind = NormalizeBaseKind(declaredKind); + TypeId expected = ResolveWireTypeId(declaredKind, registerByName, compatible); + HashSet allowed = [expected]; + if (baseKind == TypeId.Struct && compatible) + { + allowed.Add(TypeId.CompatibleStruct); + allowed.Add(TypeId.NamedCompatibleStruct); + allowed.Add(TypeId.Struct); + allowed.Add(TypeId.NamedStruct); + } + + return allowed; + } + + public object? ReadByUserTypeId(uint userTypeId, ref ReadContext context, TypeMeta? compatibleTypeMeta = null) + { + if (!_byUserTypeId.TryGetValue(userTypeId, out TypeReader? entry)) + { + throw new TypeNotRegisteredException($"user_type_id={userTypeId}"); + } + + return compatibleTypeMeta is null + ? entry.Reader(context) + : entry.CompatibleReader(context, compatibleTypeMeta); + } + + public object? ReadByTypeName(string namespaceName, string typeName, ref ReadContext context, TypeMeta? 
compatibleTypeMeta = null) + { + if (!_byTypeName.TryGetValue(new TypeNameKey(namespaceName, typeName), out TypeReader? entry)) + { + throw new TypeNotRegisteredException($"namespace={namespaceName}, type={typeName}"); + } + + return compatibleTypeMeta is null + ? entry.Reader(context) + : entry.CompatibleReader(context, compatibleTypeMeta); + } + + public DynamicTypeInfo ReadDynamicTypeInfo(ref ReadContext context) + { + uint rawTypeId = context.Reader.ReadVarUInt32(); + if (!Enum.IsDefined(typeof(TypeId), rawTypeId)) + { + throw new InvalidDataException($"unknown dynamic type id {rawTypeId}"); + } + + TypeId wireTypeId = (TypeId)rawTypeId; + switch (wireTypeId) + { + case TypeId.CompatibleStruct: + case TypeId.NamedCompatibleStruct: + { + TypeMeta typeMeta = context.ReadCompatibleTypeMeta(); + if (typeMeta.RegisterByName) + { + return new DynamicTypeInfo(wireTypeId, null, typeMeta.NamespaceName, typeMeta.TypeName, typeMeta); + } + + return new DynamicTypeInfo(wireTypeId, typeMeta.UserTypeId, null, null, typeMeta); + } + case TypeId.NamedStruct: + case TypeId.NamedEnum: + case TypeId.NamedExt: + case TypeId.NamedUnion: + { + MetaString namespaceName = ReadMetaString(context.Reader, MetaStringDecoder.Namespace, TypeMetaEncodings.NamespaceMetaStringEncodings); + MetaString typeName = ReadMetaString(context.Reader, MetaStringDecoder.TypeName, TypeMetaEncodings.TypeNameMetaStringEncodings); + return new DynamicTypeInfo(wireTypeId, null, namespaceName, typeName, null); + } + case TypeId.Struct: + case TypeId.Enum: + case TypeId.Ext: + case TypeId.TypedUnion: + { + DynamicRegistrationMode mode = DynamicRegistrationModeFor(wireTypeId); + if (mode == DynamicRegistrationMode.IdOnly) + { + return new DynamicTypeInfo(wireTypeId, context.Reader.ReadVarUInt32(), null, null, null); + } + + if (mode == DynamicRegistrationMode.NameOnly) + { + MetaString namespaceName = ReadMetaString(context.Reader, MetaStringDecoder.Namespace, TypeMetaEncodings.NamespaceMetaStringEncodings); + 
MetaString typeName = ReadMetaString(context.Reader, MetaStringDecoder.TypeName, TypeMetaEncodings.TypeNameMetaStringEncodings); + return new DynamicTypeInfo(wireTypeId, null, namespaceName, typeName, null); + } + + throw new InvalidDataException($"ambiguous dynamic type registration mode for {wireTypeId}"); + } + default: + return new DynamicTypeInfo(wireTypeId, null, null, null, null); + } + } + + public object? ReadDynamicValue(DynamicTypeInfo typeInfo, ref ReadContext context) + { + switch (typeInfo.WireTypeId) + { + case TypeId.Bool: + return context.Reader.ReadUInt8() != 0; + case TypeId.Int8: + return context.Reader.ReadInt8(); + case TypeId.Int16: + return context.Reader.ReadInt16(); + case TypeId.Int32: + return context.Reader.ReadInt32(); + case TypeId.VarInt32: + return context.Reader.ReadVarInt32(); + case TypeId.Int64: + return context.Reader.ReadInt64(); + case TypeId.VarInt64: + return context.Reader.ReadVarInt64(); + case TypeId.TaggedInt64: + return context.Reader.ReadTaggedInt64(); + case TypeId.UInt8: + return context.Reader.ReadUInt8(); + case TypeId.UInt16: + return context.Reader.ReadUInt16(); + case TypeId.UInt32: + return context.Reader.ReadUInt32(); + case TypeId.VarUInt32: + return context.Reader.ReadVarUInt32(); + case TypeId.UInt64: + return context.Reader.ReadUInt64(); + case TypeId.VarUInt64: + return context.Reader.ReadVarUInt64(); + case TypeId.TaggedUInt64: + return context.Reader.ReadTaggedUInt64(); + case TypeId.Float32: + return context.Reader.ReadFloat32(); + case TypeId.Float64: + return context.Reader.ReadFloat64(); + case TypeId.String: + return StringSerializer.ReadString(ref context); + case TypeId.Date: + return TimeCodec.ReadDate(ref context); + case TypeId.Timestamp: + return TimeCodec.ReadTimestamp(ref context); + case TypeId.Duration: + return TimeCodec.ReadDuration(ref context); + case TypeId.Binary: + case TypeId.UInt8Array: + return ReadBinary(ref context); + case TypeId.BoolArray: + return ReadBoolArray(ref 
context); + case TypeId.Int8Array: + return ReadInt8Array(ref context); + case TypeId.Int16Array: + return ReadInt16Array(ref context); + case TypeId.Int32Array: + return ReadInt32Array(ref context); + case TypeId.Int64Array: + return ReadInt64Array(ref context); + case TypeId.UInt16Array: + return ReadUInt16Array(ref context); + case TypeId.UInt32Array: + return ReadUInt32Array(ref context); + case TypeId.UInt64Array: + return ReadUInt64Array(ref context); + case TypeId.Float32Array: + return ReadFloat32Array(ref context); + case TypeId.Float64Array: + return ReadFloat64Array(ref context); + case TypeId.List: + return DynamicContainerCodec.ReadListPayload(ref context); + case TypeId.Set: + return DynamicContainerCodec.ReadSetPayload(ref context); + case TypeId.Map: + return DynamicContainerCodec.ReadMapPayload(ref context); + case TypeId.Union: + return GetSerializer().Read(ref context, RefMode.None, false); + case TypeId.Struct: + case TypeId.Enum: + case TypeId.Ext: + case TypeId.TypedUnion: + { + if (typeInfo.UserTypeId.HasValue) + { + return ReadByUserTypeId(typeInfo.UserTypeId.Value, ref context); + } + + if (typeInfo.NamespaceName.HasValue && typeInfo.TypeName.HasValue) + { + return ReadByTypeName(typeInfo.NamespaceName.Value.Value, typeInfo.TypeName.Value.Value, ref context); + } + + throw new InvalidDataException($"missing dynamic registration info for {typeInfo.WireTypeId}"); + } + case TypeId.NamedStruct: + case TypeId.NamedEnum: + case TypeId.NamedExt: + case TypeId.NamedUnion: + { + if (!typeInfo.NamespaceName.HasValue || !typeInfo.TypeName.HasValue) + { + throw new InvalidDataException($"missing dynamic type name for {typeInfo.WireTypeId}"); + } + + return ReadByTypeName(typeInfo.NamespaceName.Value.Value, typeInfo.TypeName.Value.Value, ref context); + } + case TypeId.CompatibleStruct: + case TypeId.NamedCompatibleStruct: + { + if (typeInfo.CompatibleTypeMeta is null) + { + throw new InvalidDataException($"missing compatible type meta for 
{typeInfo.WireTypeId}"); + } + + TypeMeta compatibleTypeMeta = typeInfo.CompatibleTypeMeta; + if (compatibleTypeMeta.RegisterByName) + { + return ReadByTypeName( + compatibleTypeMeta.NamespaceName.Value, + compatibleTypeMeta.TypeName.Value, + ref context, + compatibleTypeMeta); + } + + if (!compatibleTypeMeta.UserTypeId.HasValue) + { + throw new InvalidDataException("missing user type id in compatible dynamic type meta"); + } + + return ReadByUserTypeId(compatibleTypeMeta.UserTypeId.Value, ref context, compatibleTypeMeta); + } + case TypeId.None: + return null; + default: + throw new InvalidDataException($"unsupported dynamic type id {typeInfo.WireTypeId}"); + } + } + + private static byte[] ReadBinary(ref ReadContext context) + { + uint length = context.Reader.ReadVarUInt32(); + return context.Reader.ReadBytes(checked((int)length)); + } + + private static bool[] ReadBoolArray(ref ReadContext context) + { + int count = checked((int)context.Reader.ReadVarUInt32()); + bool[] values = new bool[count]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt8() != 0; + } + + return values; + } + + private static sbyte[] ReadInt8Array(ref ReadContext context) + { + int count = checked((int)context.Reader.ReadVarUInt32()); + sbyte[] values = new sbyte[count]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadInt8(); + } + + return values; + } + + private static short[] ReadInt16Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 1) != 0) + { + throw new InvalidDataException("int16 array payload size mismatch"); + } + + short[] values = new short[payloadSize / 2]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadInt16(); + } + + return values; + } + + private static int[] ReadInt32Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 3) != 0) + { + 
throw new InvalidDataException("int32 array payload size mismatch"); + } + + int[] values = new int[payloadSize / 4]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadInt32(); + } + + return values; + } + + private static long[] ReadInt64Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 7) != 0) + { + throw new InvalidDataException("int64 array payload size mismatch"); + } + + long[] values = new long[payloadSize / 8]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadInt64(); + } + + return values; + } + + private static ushort[] ReadUInt16Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 1) != 0) + { + throw new InvalidDataException("uint16 array payload size mismatch"); + } + + ushort[] values = new ushort[payloadSize / 2]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt16(); + } + + return values; + } + + private static uint[] ReadUInt32Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 3) != 0) + { + throw new InvalidDataException("uint32 array payload size mismatch"); + } + + uint[] values = new uint[payloadSize / 4]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt32(); + } + + return values; + } + + private static ulong[] ReadUInt64Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 7) != 0) + { + throw new InvalidDataException("uint64 array payload size mismatch"); + } + + ulong[] values = new ulong[payloadSize / 8]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadUInt64(); + } + + return values; + } + + private static float[] ReadFloat32Array(ref ReadContext context) + { + int payloadSize = 
checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 3) != 0) + { + throw new InvalidDataException("float32 array payload size mismatch"); + } + + float[] values = new float[payloadSize / 4]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadFloat32(); + } + + return values; + } + + private static double[] ReadFloat64Array(ref ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 7) != 0) + { + throw new InvalidDataException("float64 array payload size mismatch"); + } + + double[] values = new double[payloadSize / 8]; + for (int i = 0; i < values.Length; i++) + { + values[i] = context.Reader.ReadFloat64(); + } + + return values; + } + + private void MarkRegistrationMode(TypeId kind, bool registerByName) + { + DynamicRegistrationMode mode = registerByName ? DynamicRegistrationMode.NameOnly : DynamicRegistrationMode.IdOnly; + if (!_registrationModeByKind.TryGetValue(kind, out DynamicRegistrationMode existing)) + { + _registrationModeByKind[kind] = mode; + return; + } + + if (existing != mode) + { + _registrationModeByKind[kind] = DynamicRegistrationMode.Mixed; + } + } + + private DynamicRegistrationMode DynamicRegistrationModeFor(TypeId kind) + { + if (_registrationModeByKind.TryGetValue(kind, out DynamicRegistrationMode mode)) + { + return mode; + } + + throw new TypeNotRegisteredException($"no dynamic registration mode for kind {kind}"); + } + + private static TypeId NormalizeBaseKind(TypeId kind) + { + return kind switch + { + TypeId.NamedEnum => TypeId.Enum, + TypeId.CompatibleStruct or TypeId.NamedCompatibleStruct or TypeId.NamedStruct => TypeId.Struct, + TypeId.NamedExt => TypeId.Ext, + TypeId.NamedUnion => TypeId.TypedUnion, + _ => kind, + }; + } + + private static TypeId NamedKind(TypeId baseKind, bool compatible) + { + return baseKind switch + { + TypeId.Struct => compatible ? 
TypeId.NamedCompatibleStruct : TypeId.NamedStruct, + TypeId.Enum => TypeId.NamedEnum, + TypeId.Ext => TypeId.NamedExt, + TypeId.TypedUnion => TypeId.NamedUnion, + _ => baseKind, + }; + } + + private static TypeId IdKind(TypeId baseKind, bool compatible) + { + return baseKind switch + { + TypeId.Struct => compatible ? TypeId.CompatibleStruct : TypeId.Struct, + _ => baseKind, + }; + } + + private static bool WireTypeNeedsUserTypeId(TypeId typeId) + { + return typeId is TypeId.Enum or TypeId.Struct or TypeId.Ext or TypeId.TypedUnion; + } + + private static TypeMeta BuildCompatibleTypeMeta( + RegisteredTypeInfo info, + TypeId wireTypeId, + bool trackRef) + { + IReadOnlyList fields = info.Serializer.CompatibleTypeMetaFields(trackRef); + bool hasFieldsMeta = fields.Count > 0; + if (info.RegisterByName) + { + if (info.NamespaceName is null) + { + throw new InvalidDataException("missing namespace metadata for name-registered type"); + } + + return new TypeMeta( + (uint)wireTypeId, + null, + info.NamespaceName.Value, + info.TypeName, + true, + fields, + hasFieldsMeta); + } + + if (!info.UserTypeId.HasValue) + { + throw new InvalidDataException("missing user type id metadata for id-registered type"); + } + + return new TypeMeta( + (uint)wireTypeId, + info.UserTypeId.Value, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + false, + fields, + hasFieldsMeta); + } + + private static void ValidateCompatibleTypeMeta( + TypeMeta remoteTypeMeta, + RegisteredTypeInfo localInfo, + HashSet expectedWireTypes, + TypeId actualWireTypeId) + { + if (remoteTypeMeta.RegisterByName) + { + if (!localInfo.RegisterByName || localInfo.NamespaceName is null) + { + throw new InvalidDataException( + "received name-registered compatible metadata for id-registered local type"); + } + + if (remoteTypeMeta.NamespaceName.Value != localInfo.NamespaceName.Value.Value) + { + throw new InvalidDataException( + $"namespace mismatch: expected {localInfo.NamespaceName.Value.Value}, got 
{remoteTypeMeta.NamespaceName.Value}"); + } + + if (remoteTypeMeta.TypeName.Value != localInfo.TypeName.Value) + { + throw new InvalidDataException( + $"type name mismatch: expected {localInfo.TypeName.Value}, got {remoteTypeMeta.TypeName.Value}"); + } + } + else + { + if (localInfo.RegisterByName) + { + throw new InvalidDataException( + "received id-registered compatible metadata for name-registered local type"); + } + + if (!remoteTypeMeta.UserTypeId.HasValue) + { + throw new InvalidDataException("missing user type id in compatible type metadata"); + } + + if (!localInfo.UserTypeId.HasValue) + { + throw new InvalidDataException("missing local user type id metadata for id-registered type"); + } + + if (remoteTypeMeta.UserTypeId.Value != localInfo.UserTypeId.Value) + { + throw new TypeMismatchException(localInfo.UserTypeId.Value, remoteTypeMeta.UserTypeId.Value); + } + } + + if (remoteTypeMeta.TypeId.HasValue && + Enum.IsDefined(typeof(TypeId), remoteTypeMeta.TypeId.Value)) + { + TypeId remoteWireTypeId = (TypeId)remoteTypeMeta.TypeId.Value; + if (!expectedWireTypes.Contains(remoteWireTypeId)) + { + throw new TypeMismatchException((uint)actualWireTypeId, remoteTypeMeta.TypeId.Value); + } + } + } + + private static void WriteMetaString( + ref WriteContext context, + MetaString value, + IReadOnlyList encodings, + MetaStringEncoder encoder) + { + MetaString normalized = encodings.Contains(value.Encoding) + ? 
value + : encoder.Encode(value.Value, encodings); + if (!encodings.Contains(normalized.Encoding)) + { + throw new EncodingException("failed to normalize meta string encoding"); + } + + byte[] bytes = normalized.Bytes; + (uint index, bool isNew) = context.MetaStringWriteState.AssignIndexIfAbsent(normalized); + if (isNew) + { + context.Writer.WriteVarUInt32((uint)(bytes.Length << 1)); + if (bytes.Length > 16) + { + context.Writer.WriteInt64(unchecked((long)MetaStringHash(normalized))); + } + else if (bytes.Length > 0) + { + context.Writer.WriteUInt8((byte)normalized.Encoding); + } + + context.Writer.WriteBytes(bytes); + } + else + { + context.Writer.WriteVarUInt32(((index + 1) << 1) | 1); + } + } + + private static MetaString ReadMetaString( + ref ReadContext context, + MetaStringDecoder decoder, + IReadOnlyList encodings) + { + uint header = context.Reader.ReadVarUInt32(); + int length = checked((int)(header >> 1)); + bool isRef = (header & 1) == 1; + if (isRef) + { + int index = length - 1; + MetaString? 
cached = context.MetaStringReadState.ValueAt(index); + if (cached is null) + { + throw new InvalidDataException($"unknown meta string ref index {index}"); + } + + return cached.Value; + } + + MetaString value; + if (length == 0) + { + value = MetaString.Empty(decoder.SpecialChar1, decoder.SpecialChar2); + } + else + { + MetaStringEncoding encoding; + if (length > 16) + { + long hash = context.Reader.ReadInt64(); + byte rawEncoding = unchecked((byte)(hash & 0xFF)); + encoding = (MetaStringEncoding)rawEncoding; + } + else + { + encoding = (MetaStringEncoding)context.Reader.ReadUInt8(); + } + + if (!encodings.Contains(encoding)) + { + throw new InvalidDataException($"meta string encoding {encoding} not allowed in this context"); + } + + byte[] bytes = context.Reader.ReadBytes(length); + value = decoder.Decode(bytes, encoding); + } + + context.MetaStringReadState.Append(value); + return value; + } + + private static ulong MetaStringHash(MetaString metaString) + { + (ulong h1, _) = MurmurHash3.X64_128(metaString.Bytes, 47); + long hash = unchecked((long)h1); + if (hash != long.MinValue) + { + hash = Math.Abs(hash); + } + + ulong result = unchecked((ulong)hash); + if (result == 0) + { + result += 256; + } + + result &= 0xffffffffffffff00; + result |= (byte)metaString.Encoding; + return result; + } + + private Serializer CreateBindingCore(Type type) + { + if (GeneratedFactories.TryGetValue(type, out Func? 
generatedFactory)) + { + return generatedFactory(); + } + + if (type == typeof(bool)) + { + return new BoolSerializer(); + } + + if (type == typeof(sbyte)) + { + return new Int8Serializer(); + } + + if (type == typeof(short)) + { + return new Int16Serializer(); + } + + if (type == typeof(int)) + { + return new Int32Serializer(); + } + + if (type == typeof(long)) + { + return new Int64Serializer(); + } + + if (type == typeof(byte)) + { + return new UInt8Serializer(); + } + + if (type == typeof(ushort)) + { + return new UInt16Serializer(); + } + + if (type == typeof(uint)) + { + return new UInt32Serializer(); + } + + if (type == typeof(ulong)) + { + return new UInt64Serializer(); + } + + if (type == typeof(float)) + { + return new Float32Serializer(); + } + + if (type == typeof(double)) + { + return new Float64Serializer(); + } + + if (type == typeof(string)) + { + return new StringSerializer(); + } + + if (type == typeof(byte[])) + { + return new BinarySerializer(); + } + + if (type == typeof(bool[])) + { + return new BoolArraySerializer(); + } + + if (type == typeof(sbyte[])) + { + return new Int8ArraySerializer(); + } + + if (type == typeof(short[])) + { + return new Int16ArraySerializer(); + } + + if (type == typeof(int[])) + { + return new Int32ArraySerializer(); + } + + if (type == typeof(long[])) + { + return new Int64ArraySerializer(); + } + + if (type == typeof(ushort[])) + { + return new UInt16ArraySerializer(); + } + + if (type == typeof(uint[])) + { + return new UInt32ArraySerializer(); + } + + if (type == typeof(ulong[])) + { + return new UInt64ArraySerializer(); + } + + if (type == typeof(float[])) + { + return new Float32ArraySerializer(); + } + + if (type == typeof(double[])) + { + return new Float64ArraySerializer(); + } + + if (type == typeof(DateOnly)) + { + return new DateOnlySerializer(); + } + + if (type == typeof(DateTimeOffset)) + { + return new DateTimeOffsetSerializer(); + } + + if (type == typeof(DateTime)) + { + return new 
DateTimeSerializer(); + } + + if (type == typeof(TimeSpan)) + { + return new TimeSpanSerializer(); + } + + if (type == typeof(List)) + { + return new ListBoolSerializer(); + } + + if (type == typeof(List)) + { + return new ListInt8Serializer(); + } + + if (type == typeof(List)) + { + return new ListInt16Serializer(); + } + + if (type == typeof(List)) + { + return new ListIntSerializer(); + } + + if (type == typeof(List)) + { + return new ListLongSerializer(); + } + + if (type == typeof(List)) + { + return new ListUInt8Serializer(); + } + + if (type == typeof(List)) + { + return new ListUInt16Serializer(); + } + + if (type == typeof(List)) + { + return new ListUIntSerializer(); + } + + if (type == typeof(List)) + { + return new ListULongSerializer(); + } + + if (type == typeof(List)) + { + return new ListFloatSerializer(); + } + + if (type == typeof(List)) + { + return new ListDoubleSerializer(); + } + + if (type == typeof(List)) + { + return new ListStringSerializer(); + } + + if (type == typeof(List)) + { + return new ListDateOnlySerializer(); + } + + if (type == typeof(List)) + { + return new ListDateTimeOffsetSerializer(); + } + + if (type == typeof(List)) + { + return new ListDateTimeSerializer(); + } + + if (type == typeof(List)) + { + return new ListTimeSpanSerializer(); + } + + if (type == typeof(HashSet)) + { + return new SetInt8Serializer(); + } + + if (type == typeof(HashSet)) + { + return new SetInt16Serializer(); + } + + if (type == typeof(HashSet)) + { + return new SetIntSerializer(); + } + + if (type == typeof(HashSet)) + { + return new SetLongSerializer(); + } + + if (type == typeof(HashSet)) + { + return new SetUInt8Serializer(); + } + + if (type == typeof(HashSet)) + { + return new SetUInt16Serializer(); + } + + if (type == typeof(HashSet)) + { + return new SetUIntSerializer(); + } + + if (type == typeof(HashSet)) + { + return new SetULongSerializer(); + } + + if (type == typeof(HashSet)) + { + return new SetFloatSerializer(); + } + + if (type == 
typeof(HashSet)) + { + return new SetDoubleSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringStringSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringIntSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringLongSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringBoolSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringDoubleSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringFloatSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringUIntSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringULongSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringInt8Serializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringInt16Serializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryStringUInt16Serializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryIntIntSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryLongLongSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryUIntUIntSerializer(); + } + + if (type == typeof(Dictionary)) + { + return new DictionaryULongULongSerializer(); + } + + if (type == typeof(object)) + { + return new DynamicAnyObjectSerializer(); + } + + if (typeof(Union).IsAssignableFrom(type)) + { + Type serializerType = typeof(UnionSerializer<>).MakeGenericType(type); + return CreateSerializer(serializerType); + } + + if (type.IsEnum) + { + Type serializerType = typeof(EnumSerializer<>).MakeGenericType(type); + return CreateSerializer(serializerType); + } + + if (type.IsArray) + { + Type elementType = type.GetElementType()!; + Type serializerType = typeof(ArraySerializer<>).MakeGenericType(elementType); + return 
CreateSerializer(serializerType); + } + + if (type.IsGenericType) + { + Type genericType = type.GetGenericTypeDefinition(); + Type[] genericArgs = type.GetGenericArguments(); + if (genericType == typeof(Nullable<>)) + { + Type serializerType = typeof(NullableSerializer<>).MakeGenericType(genericArgs[0]); + return CreateSerializer(serializerType); + } + + if (genericType == typeof(List<>)) + { + Type serializerType = typeof(ListSerializer<>).MakeGenericType(genericArgs[0]); + return CreateSerializer(serializerType); + } + + if (genericType == typeof(HashSet<>)) + { + Type serializerType = typeof(SetSerializer<>).MakeGenericType(genericArgs[0]); + return CreateSerializer(serializerType); + } + + if (genericType == typeof(Dictionary<,>)) + { + Type serializerType = typeof(DictionarySerializer<,>).MakeGenericType(genericArgs[0], genericArgs[1]); + return CreateSerializer(serializerType); + } + + if (genericType == typeof(NullableKeyDictionary<,>)) + { + Type serializerType = typeof(NullableKeyDictionarySerializer<,>).MakeGenericType(genericArgs[0], genericArgs[1]); + return CreateSerializer(serializerType); + } + } + + throw new TypeNotRegisteredException($"No serializer available for {type}"); + } + + private static Serializer CreateSerializer() + where TSerializer : Serializer, new() + { + return new TSerializer(); + } + + private Serializer CreateSerializer(Type serializerType) + { + if (!typeof(Serializer).IsAssignableFrom(serializerType)) + { + throw new InvalidDataException($"{serializerType} is not a serializer"); + } + + try + { + if (Activator.CreateInstance(serializerType) is Serializer serializer) + { + return serializer; + } + } + catch (Exception ex) + { + throw new InvalidDataException($"failed to create serializer for {serializerType}: {ex.Message}"); + } + + throw new InvalidDataException($"{serializerType} is not a serializer"); + } + + private static MetaString ReadMetaString(ByteReader reader, MetaStringDecoder decoder, IReadOnlyList encodings) + { + 
/// <summary>
/// A tagged value: an integer case <see cref="Index"/> together with an untyped payload.
/// </summary>
/// <remarks>
/// Equality and hashing consider <see cref="Index"/> and <see cref="Value"/> only;
/// <see cref="ValueTypeId"/> and the runtime type of the union itself are not compared.
/// </remarks>
public class Union : IEquatable<Union>
{
    /// <summary>
    /// Creates a union whose <see cref="ValueTypeId"/> is <see cref="TypeId.Unknown"/>.
    /// </summary>
    /// <param name="index">Case index of the active alternative.</param>
    /// <param name="value">Payload of the active alternative; may be null.</param>
    public Union(int index, object? value)
        : this(index, value, (int)TypeId.Unknown)
    {
    }

    /// <summary>
    /// Creates a union with an explicit payload type id.
    /// </summary>
    /// <param name="index">Case index of the active alternative.</param>
    /// <param name="value">Payload of the active alternative; may be null.</param>
    /// <param name="valueTypeId">Type id recorded for the payload.</param>
    public Union(int index, object? value, int valueTypeId)
    {
        Index = index;
        Value = value;
        ValueTypeId = valueTypeId;
    }

    /// <summary>Case index of the active alternative.</summary>
    public int Index { get; }

    /// <summary>Payload of the active alternative, or null.</summary>
    public object? Value { get; }

    /// <summary>Type id recorded for the payload at construction time.</summary>
    public int ValueTypeId { get; }

    /// <summary>True when <see cref="Value"/> is not null.</summary>
    public bool HasValue => Value is not null;

    /// <summary>
    /// Returns the payload as <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Expected payload type.</typeparam>
    /// <returns>
    /// The payload cast to <typeparamref name="T"/>. A null payload is returned as
    /// <c>default</c> when <typeparamref name="T"/> is a reference type or
    /// <see cref="Nullable{T}"/>.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// The payload is not a <typeparamref name="T"/>, or it is null while
    /// <typeparamref name="T"/> is a non-nullable value type.
    /// </exception>
    public T GetValue<T>()
    {
        if (Value is T typed)
        {
            return typed;
        }

        // `null is T` is always false, so a stored null needs its own branch: it is a
        // valid payload whenever T itself can represent null.
        if (Value is null && default(T) is null)
        {
            return default!;
        }

        throw new InvalidOperationException(
            $"union value type mismatch: expected {typeof(T)}, got {Value?.GetType()}");
    }

    /// <summary>
    /// Compares <see cref="Index"/> and <see cref="Value"/> with those of <paramref name="other"/>.
    /// </summary>
    public bool Equals(Union? other)
    {
        return other is not null && Index == other.Index && Equals(Value, other.Value);
    }

    /// <inheritdoc />
    public override bool Equals(object? obj)
    {
        return obj is Union other && Equals(other);
    }

    /// <summary>Hash over the same members that <see cref="Equals(Union?)"/> compares.</summary>
    public override int GetHashCode()
    {
        return HashCode.Combine(Index, Value);
    }

    /// <inheritdoc />
    public override string ToString()
    {
        return $"Union{{index={Index}, value={Value}}}";
    }
}

/// <summary>
/// A union with exactly two alternatives: case 0 holds a <typeparamref name="T1"/>,
/// case 1 holds a <typeparamref name="T2"/>.
/// </summary>
/// <typeparam name="T1">Payload type of case 0.</typeparam>
/// <typeparam name="T2">Payload type of case 1.</typeparam>
public sealed class Union2<T1, T2> : Union
{
    private Union2(int index, object? value)
        : this(index, value, (int)TypeId.Unknown)
    {
    }

    private Union2(int index, object? value, int valueTypeId)
        : base(index, value, valueTypeId)
    {
        if (index is < 0 or > 1)
        {
            throw new ArgumentOutOfRangeException(nameof(index), $"Union2 index must be 0 or 1, got {index}");
        }
    }

    /// <summary>Creates a union holding case 0.</summary>
    public static Union2<T1, T2> OfT1(T1 value)
    {
        return new Union2<T1, T2>(0, value);
    }

    /// <summary>Creates a union holding case 1.</summary>
    public static Union2<T1, T2> OfT2(T2 value)
    {
        return new Union2<T1, T2>(1, value);
    }

    /// <summary>
    /// Creates a union from a raw case index and an untyped payload.
    /// </summary>
    /// <remarks>
    /// Only the index range is validated here; the payload is not checked against
    /// <typeparamref name="T1"/>/<typeparamref name="T2"/> until it is read back.
    /// </remarks>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="index"/> is not 0 or 1.</exception>
    public static Union2<T1, T2> Of(int index, object? value)
    {
        return new Union2<T1, T2>(index, value);
    }

    /// <summary>True when the union holds case 0.</summary>
    public bool IsT1 => Index == 0;

    /// <summary>True when the union holds case 1.</summary>
    public bool IsT2 => Index == 1;

    /// <summary>Returns the case-0 payload.</summary>
    /// <exception cref="InvalidOperationException">
    /// The union holds case 1, or the payload is not a <typeparamref name="T1"/>.
    /// </exception>
    public T1 GetT1()
    {
        if (!IsT1)
        {
            throw new InvalidOperationException($"Union2 currently holds case {Index}, not case 0");
        }

        return GetValue<T1>();
    }

    /// <summary>Returns the case-1 payload.</summary>
    /// <exception cref="InvalidOperationException">
    /// The union holds case 0, or the payload is not a <typeparamref name="T2"/>.
    /// </exception>
    public T2 GetT2()
    {
        if (!IsT2)
        {
            throw new InvalidOperationException($"Union2 currently holds case {Index}, not case 1");
        }

        return GetValue<T2>();
    }

    /// <inheritdoc />
    public override string ToString()
    {
        return $"Union2{{index={Index}, value={Value}}}";
    }
}
/// <summary>
/// Serializer for <see cref="Union"/> and its subclasses. The payload is written as the
/// case index (var-uint32) followed by the case value encoded through
/// <c>DynamicAnyCodec</c>.
/// </summary>
/// <typeparam name="TUnion">
/// Union type to materialize on read. It must be <see cref="Union"/> itself, or declare an
/// <c>(int, object)</c> instance constructor (any visibility), or a public static
/// <c>Of(int, object)</c> factory returning a type assignable to <typeparamref name="TUnion"/>.
/// </typeparam>
public sealed class UnionSerializer<TUnion> : Serializer<TUnion>
    where TUnion : Union
{
    // Resolved once per closed generic type. A failure here is raised from the static
    // initializer, so callers observe it wrapped in TypeInitializationException.
    private static readonly Func<int, object?, TUnion> Factory = BuildFactory();

    public override TypeId StaticTypeId => TypeId.TypedUnion;

    public override bool IsNullableType => true;

    public override bool IsReferenceTrackableType => true;

    public override TUnion DefaultValue => null!;

    public override bool IsNone(in TUnion value)
    {
        return value is null;
    }

    /// <summary>
    /// Writes the case index and then the case value.
    /// </summary>
    /// <exception cref="InvalidDataException">
    /// <paramref name="value"/> is null, or its <see cref="Union.Index"/> is negative and
    /// therefore cannot be encoded as a case id that <see cref="ReadData"/> would accept.
    /// </exception>
    public override void WriteData(ref WriteContext context, in TUnion value, bool hasGenerics)
    {
        _ = hasGenerics;
        if (value is null)
        {
            throw new InvalidDataException("union value is null");
        }

        // The base Union accepts any int; a negative index would wrap to a value above
        // int.MaxValue when cast to uint and be rejected only on the reading side.
        if (value.Index < 0)
        {
            throw new InvalidDataException($"union case id out of range: {value.Index}");
        }

        context.Writer.WriteVarUInt32((uint)value.Index);
        DynamicAnyCodec.WriteAny(ref context, value.Value, RefMode.Tracking, true, false);
    }

    /// <summary>
    /// Reads the case index and case value and builds a <typeparamref name="TUnion"/> from them.
    /// </summary>
    /// <exception cref="InvalidDataException">The encoded case id exceeds <see cref="int.MaxValue"/>.</exception>
    public override TUnion ReadData(ref ReadContext context)
    {
        uint rawCaseId = context.Reader.ReadVarUInt32();
        if (rawCaseId > int.MaxValue)
        {
            throw new InvalidDataException($"union case id out of range: {rawCaseId}");
        }

        object? caseValue = DynamicAnyCodec.ReadAny(ref context, RefMode.Tracking, true);
        return Factory((int)rawCaseId, caseValue);
    }

    /// <summary>
    /// Picks how <typeparamref name="TUnion"/> instances are created, in this order:
    /// direct construction for <see cref="Union"/> itself, an <c>(int, object)</c> instance
    /// constructor, then a public static <c>Of(int, object)</c> factory.
    /// </summary>
    /// <exception cref="InvalidDataException">None of the supported creation paths exists.</exception>
    private static Func<int, object?, TUnion> BuildFactory()
    {
        if (typeof(TUnion) == typeof(Union))
        {
            return (index, value) => (TUnion)(object)new Union(index, value);
        }

        Type[] signature = [typeof(int), typeof(object)];

        ConstructorInfo? ctor = typeof(TUnion).GetConstructor(
            BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
            binder: null,
            signature,
            modifiers: null);
        if (ctor is not null)
        {
            ParameterExpression indexParam = Expression.Parameter(typeof(int), "index");
            ParameterExpression valueParam = Expression.Parameter(typeof(object), "value");
            NewExpression created = Expression.New(ctor, indexParam, valueParam);
            return Expression.Lambda<Func<int, object?, TUnion>>(created, indexParam, valueParam).Compile();
        }

        MethodInfo? ofFactory = typeof(TUnion).GetMethod(
            "Of",
            BindingFlags.Public | BindingFlags.Static,
            binder: null,
            signature,
            modifiers: null);
        if (ofFactory is not null && typeof(TUnion).IsAssignableFrom(ofFactory.ReturnType))
        {
            // Bind once to a delegate instead of calling MethodInfo.Invoke per value: no
            // argument boxing/array allocation, and exceptions thrown by Of(...) reach the
            // caller directly rather than wrapped in TargetInvocationException.
            return (Func<int, object?, TUnion>)ofFactory.CreateDelegate(typeof(Func<int, object?, TUnion>));
        }

        throw new InvalidDataException(
            $"union type {typeof(TUnion)} must define (int, object) constructor or static Of(int, object)");
    }
}
+ +using System.Buffers; +using Apache.Fory; +using ForyRuntime = Apache.Fory.Fory; + +namespace Apache.Fory.Tests; + +[ForyObject] +public enum TestColor +{ + Green, + Red, + Blue, + White, +} + +[ForyObject] +public sealed class Address +{ + public string Street { get; set; } = string.Empty; + public int Zip { get; set; } +} + +[ForyObject] +public sealed class Person +{ + public long Id { get; set; } + public string Name { get; set; } = string.Empty; + public string? Nickname { get; set; } + public List Scores { get; set; } = []; + public HashSet Tags { get; set; } = []; + public List
Addresses { get; set; } = []; + public Dictionary Metadata { get; set; } = []; +} + +[ForyObject] +public sealed class Node +{ + public int Value { get; set; } + public Node? Next { get; set; } +} + +[ForyObject] +public sealed class FieldOrder +{ + public string Z { get; set; } = string.Empty; + public long A { get; set; } + public short B { get; set; } + public int C { get; set; } +} + +[ForyObject] +public sealed class EncodedNumbers +{ + [Field(Encoding = FieldEncoding.Fixed)] + public uint U32Fixed { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong U64Tagged { get; set; } +} + +[ForyObject] +public sealed class OneStringField +{ + public string? F1 { get; set; } +} + +[ForyObject] +public sealed class TwoStringField +{ + public string F1 { get; set; } = string.Empty; + public string F2 { get; set; } = string.Empty; +} + +[ForyObject] +public sealed class StructWithEnum +{ + public string Name { get; set; } = string.Empty; + public TestColor Color { get; set; } + public int Value { get; set; } +} + +[ForyObject] +public sealed class StructWithNullableMap +{ + public NullableKeyDictionary Data { get; set; } = new(); +} + +[ForyObject] +public sealed class StructWithUnion2 +{ + public Union2 Union { get; set; } = Union2.OfT1(string.Empty); +} + +[ForyObject] +public sealed class DynamicAnyHolder +{ + public object? 
AnyValue { get; set; } + public HashSet AnySet { get; set; } = []; + public Dictionary AnyMap { get; set; } = []; +} + +public sealed class ForyRuntimeTests +{ + private const ulong StringEncodingLatin1 = 0; + private const ulong StringEncodingUtf16 = 1; + private const ulong StringEncodingUtf8 = 2; + + [Fact] + public void PrimitiveRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + Assert.True(fory.Deserialize(fory.Serialize(true))); + Assert.Equal(-123_456, fory.Deserialize(fory.Serialize(-123_456))); + Assert.Equal(9_223_372_036_854_775_000L, fory.Deserialize(fory.Serialize(9_223_372_036_854_775_000L))); + Assert.Equal(123_456u, fory.Deserialize(fory.Serialize(123_456u))); + Assert.Equal(9_223_372_036_854_775_000UL, fory.Deserialize(fory.Serialize(9_223_372_036_854_775_000UL))); + Assert.Equal(3.25f, fory.Deserialize(fory.Serialize(3.25f))); + Assert.Equal(3.1415926, fory.Deserialize(fory.Serialize(3.1415926))); + Assert.Equal("hello_fory", fory.Deserialize(fory.Serialize("hello_fory"))); + + byte[] binary = [0x01, 0x02, 0x03, 0xFF]; + Assert.Equal(binary, fory.Deserialize(fory.Serialize(binary))); + } + + [Fact] + public void OptionalRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + string? present = "present"; + string? 
absent = null; + Assert.Equal("present", fory.Deserialize(fory.Serialize(present))); + Assert.Null(fory.Deserialize(fory.Serialize(absent))); + } + + [Fact] + public void CollectionsRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + List list = ["a", null, "b"]; + Assert.Equal(list, fory.Deserialize>(fory.Serialize(list))); + + int[] intArray = [1, 2, 3, 4]; + Assert.Equal(intArray, fory.Deserialize(fory.Serialize(intArray))); + + byte[] bytes = [1, 2, 3, 250]; + Assert.Equal(bytes, fory.Deserialize(fory.Serialize(bytes))); + + HashSet set = [1, 5, 8]; + Assert.Equal(set, fory.Deserialize>(fory.Serialize(set))); + + Dictionary map = new() { [1] = 100, [2] = null, [3] = -7 }; + Dictionary decoded = fory.Deserialize>(fory.Serialize(map)); + Assert.Equal(map.Count, decoded.Count); + foreach ((sbyte key, int? value) in map) + { + Assert.Equal(value, decoded[key]); + } + } + + [Fact] + public void NumericListSetRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + Assert.Equal((List)[-5, 0, 12], fory.Deserialize>(fory.Serialize((List)[-5, 0, 12]))); + Assert.Equal((List)[-1200, 0, 32000], fory.Deserialize>(fory.Serialize((List)[-1200, 0, 32000]))); + Assert.Equal((List)[-200_000, 0, 500_000], fory.Deserialize>(fory.Serialize((List)[-200_000, 0, 500_000]))); + Assert.Equal((List)[-9_000_000_000, 0, 9_000_000_000], fory.Deserialize>(fory.Serialize((List)[-9_000_000_000, 0, 9_000_000_000]))); + Assert.Equal((List)[0, 1, 200], fory.Deserialize>(fory.Serialize((List)[0, 1, 200]))); + Assert.Equal((List)[0, 1, 65000], fory.Deserialize>(fory.Serialize((List)[0, 1, 65000]))); + Assert.Equal((List)[0, 1, 4_000_000_000], fory.Deserialize>(fory.Serialize((List)[0, 1, 4_000_000_000]))); + Assert.Equal((List)[0, 1, 12_000_000_000], fory.Deserialize>(fory.Serialize((List)[0, 1, 12_000_000_000]))); + Assert.Equal((List)[-2.5f, 0f, 7.25f], fory.Deserialize>(fory.Serialize((List)[-2.5f, 0f, 7.25f]))); + Assert.Equal((List)[-2.5, 0d, 7.25], 
fory.Deserialize>(fory.Serialize((List)[-2.5, 0d, 7.25]))); + + Assert.Equal((HashSet)[-5, 0, 12], fory.Deserialize>(fory.Serialize((HashSet)[-5, 0, 12]))); + Assert.Equal((HashSet)[-1200, 0, 32000], fory.Deserialize>(fory.Serialize((HashSet)[-1200, 0, 32000]))); + Assert.Equal((HashSet)[-200_000, 0, 500_000], fory.Deserialize>(fory.Serialize((HashSet)[-200_000, 0, 500_000]))); + Assert.Equal((HashSet)[-9_000_000_000, 0, 9_000_000_000], fory.Deserialize>(fory.Serialize((HashSet)[-9_000_000_000, 0, 9_000_000_000]))); + Assert.Equal((HashSet)[0, 1, 200], fory.Deserialize>(fory.Serialize((HashSet)[0, 1, 200]))); + Assert.Equal((HashSet)[0, 1, 65000], fory.Deserialize>(fory.Serialize((HashSet)[0, 1, 65000]))); + Assert.Equal((HashSet)[0, 1, 4_000_000_000], fory.Deserialize>(fory.Serialize((HashSet)[0, 1, 4_000_000_000]))); + Assert.Equal((HashSet)[0, 1, 12_000_000_000], fory.Deserialize>(fory.Serialize((HashSet)[0, 1, 12_000_000_000]))); + Assert.Equal((HashSet)[-2.5f, 0f, 7.25f], fory.Deserialize>(fory.Serialize((HashSet)[-2.5f, 0f, 7.25f]))); + Assert.Equal((HashSet)[-2.5, 0d, 7.25], fory.Deserialize>(fory.Serialize((HashSet)[-2.5, 0d, 7.25]))); + } + + [Fact] + public void PrimitiveStringDictionaryRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + static void AssertMapRoundTrip(ForyRuntime runtime, Dictionary source) + { + Dictionary decoded = runtime.Deserialize>(runtime.Serialize(source)); + Assert.Equal(source.Count, decoded.Count); + foreach ((string key, T value) in source) + { + Assert.Equal(value, decoded[key]); + } + } + + AssertMapRoundTrip(fory, new Dictionary { ["a"] = -1.25f, ["b"] = 7.5f }); + AssertMapRoundTrip(fory, new Dictionary { ["a"] = 1, ["b"] = 4_000_000_000 }); + AssertMapRoundTrip(fory, new Dictionary { ["a"] = 1, ["b"] = 12_000_000_000 }); + AssertMapRoundTrip(fory, new Dictionary { ["a"] = -7, ["b"] = 120 }); + AssertMapRoundTrip(fory, new Dictionary { ["a"] = -32000, ["b"] = 12345 }); + AssertMapRoundTrip(fory, new 
Dictionary { ["a"] = 1, ["b"] = 65000 }); + } + + [Fact] + public void PrimitiveUnsignedDictionaryRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + static void AssertMapRoundTrip(ForyRuntime runtime, Dictionary source) + where TKey : notnull + { + Dictionary decoded = runtime.Deserialize>(runtime.Serialize(source)); + Assert.Equal(source.Count, decoded.Count); + foreach ((TKey key, TValue value) in source) + { + Assert.Equal(value, decoded[key]); + } + } + + AssertMapRoundTrip(fory, new Dictionary { [1] = 7, [2] = 4_000_000_000 }); + AssertMapRoundTrip(fory, new Dictionary { [1] = 7, [2] = 12_000_000_000 }); + } + + [Fact] + public void StreamDeserializeConsumesSingleFrame() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + byte[] p1 = fory.Serialize(11); + byte[] p2 = fory.Serialize(22); + byte[] joined = new byte[p1.Length + p2.Length]; + Buffer.BlockCopy(p1, 0, joined, 0, p1.Length); + Buffer.BlockCopy(p2, 0, joined, p1.Length, p2.Length); + + ReadOnlySequence sequence = new(joined); + int first = fory.Deserialize(ref sequence); + int second = fory.Deserialize(ref sequence); + + Assert.Equal(11, first); + Assert.Equal(22, second); + Assert.Equal(0, sequence.Length); + } + + [Fact] + public void StreamDeserializeObjectConsumesSingleFrame() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + byte[] p1 = fory.SerializeObject("first"); + byte[] p2 = fory.SerializeObject(99); + byte[] joined = new byte[p1.Length + p2.Length]; + Buffer.BlockCopy(p1, 0, joined, 0, p1.Length); + Buffer.BlockCopy(p2, 0, joined, p1.Length, p2.Length); + + ReadOnlySequence sequence = new(joined); + object? first = fory.DeserializeObject(ref sequence); + object? second = fory.DeserializeObject(ref sequence); + + Assert.Equal("first", first); + Assert.Equal(99, second); + Assert.Equal(0, sequence.Length); + } + + [Fact] + public void MacroStructRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + fory.Register
(100); + fory.Register(101); + + Person person = new() + { + Id = 42, + Name = "Alice", + Nickname = null, + Scores = [10, 20, 30], + Tags = ["swift", "xlang"], + Addresses = [new Address { Street = "Main", Zip = 94107 }], + Metadata = new Dictionary { [1] = 100, [2] = null }, + }; + + Person decoded = fory.Deserialize(fory.Serialize(person)); + Assert.Equal(person.Id, decoded.Id); + Assert.Equal(person.Name, decoded.Name); + Assert.Equal(person.Nickname, decoded.Nickname); + Assert.Equal(person.Scores, decoded.Scores); + Assert.Equal(person.Tags, decoded.Tags); + Assert.Single(decoded.Addresses); + Assert.Equal(person.Addresses[0].Street, decoded.Addresses[0].Street); + Assert.Equal(person.Addresses[0].Zip, decoded.Addresses[0].Zip); + Assert.Equal(person.Metadata.Count, decoded.Metadata.Count); + foreach ((sbyte key, int? value) in person.Metadata) + { + Assert.Equal(value, decoded.Metadata[key]); + } + } + + [Fact] + public void MacroClassReferenceTracking() + { + ForyRuntime fory = ForyRuntime.Builder().TrackRef(true).Build(); + fory.Register(200); + + Node node = new() { Value = 7 }; + node.Next = node; + + Node decoded = fory.Deserialize(fory.Serialize(node)); + Assert.Equal(7, decoded.Value); + Assert.Same(decoded, decoded.Next); + } + + [Fact] + public void NullableKeyDictionarySupportsNullKeyRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Compatible(true).Build(); + + NullableKeyDictionary map = new(); + map.Add("k1", "v1"); + map.Add((string)null!, "v2"); + map.Add("k3", null); + map.Add("k4", "v4"); + + NullableKeyDictionary decoded = fory.Deserialize>(fory.Serialize(map)); + Assert.True(decoded.HasNullKey); + Assert.Equal("v2", decoded.NullKeyValue); + Assert.True(decoded.TryGetValue("k1", out string? v1)); + Assert.Equal("v1", v1); + Assert.True(decoded.TryGetValue("k3", out string? 
v3)); + Assert.Null(v3); + } + + [Fact] + public void NullableKeyDictionarySupportsDropInDictionaryBehavior() + { + IDictionary map = new NullableKeyDictionary(); + map.Add("k1", "v1"); + map.Add(null!, "v2"); + + Assert.Throws(() => map.Add("k1", "dup")); + Assert.Throws(() => map.Add(null!, "dup")); + + map["k1"] = "v1-updated"; + map[null!] = "v2-updated"; + + Assert.True(map.ContainsKey("k1")); + Assert.True(map.ContainsKey(null!)); + Assert.Equal("v1-updated", map["k1"]); + Assert.Equal("v2-updated", map[null!]); + Assert.True(map.TryGetValue(null!, out string? nullValue)); + Assert.Equal("v2-updated", nullValue); + Assert.True(map.Remove(null!)); + Assert.False(map.ContainsKey(null!)); + } + + [Fact] + public void DictionarySerializerSkipsNullKeyEntries() + { + ForyRuntime fory = ForyRuntime.Builder().Compatible(true).Build(); + + NullableKeyDictionary source = new(); + source.Add("k1", "v1"); + source.Add((string)null!, "v-null"); + source.Add("k2", "v2"); + + Dictionary decoded = fory.Deserialize>(fory.Serialize(source)); + Assert.Equal(2, decoded.Count); + Assert.Equal("v1", decoded["k1"]); + Assert.Equal("v2", decoded["k2"]); + } + + [Fact] + public void StructWithNullableMapRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Compatible(true).Build(); + fory.Register(202); + + StructWithNullableMap value = new(); + value.Data.Add("key1", "value1"); + value.Data.Add((string)null!, "value2"); + value.Data.Add("key3", null); + + StructWithNullableMap decoded = fory.Deserialize(fory.Serialize(value)); + Assert.True(decoded.Data.HasNullKey); + Assert.Equal("value2", decoded.Data.NullKeyValue); + Assert.True(decoded.Data.TryGetValue("key1", out string? key1)); + Assert.Equal("value1", key1); + Assert.True(decoded.Data.TryGetValue("key3", out string? 
key3)); + Assert.Null(key3); + } + + [Fact] + public void MacroFieldOrderFollowsForyRules() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + fory.Register(300); + + FieldOrder value = new() { Z = "tail", A = 123_456_789, B = 17, C = 99 }; + byte[] data = fory.Serialize(value); + + ByteReader reader = new(data); + _ = fory.ReadHead(reader); + _ = reader.ReadInt8(); + _ = reader.ReadVarUInt32(); + _ = reader.ReadVarUInt32(); + _ = reader.ReadInt32(); + + short first = reader.ReadInt16(); + long second = reader.ReadVarInt64(); + int third = reader.ReadVarInt32(); + ReadContext tailContext = new(reader, new TypeResolver(), false, false); + string fourth = tailContext.TypeResolver.GetSerializer().ReadData(ref tailContext); + + Assert.Equal(value.B, first); + Assert.Equal(value.A, second); + Assert.Equal(value.C, third); + Assert.Equal(value.Z, fourth); + } + + [Fact] + public void MacroFieldEncodingOverridesForUnsignedTypes() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + fory.Register(301); + + EncodedNumbers value = new() + { + U32Fixed = 0x11223344u, + U64Tagged = (ulong)int.MaxValue + 99UL, + }; + + EncodedNumbers decoded = fory.Deserialize(fory.Serialize(value)); + Assert.Equal(value.U32Fixed, decoded.U32Fixed); + Assert.Equal(value.U64Tagged, decoded.U64Tagged); + } + + [Fact] + public void CompatibleSchemaEvolutionRoundTrip() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).Build(); + writer.Register(200); + + ForyRuntime reader = ForyRuntime.Builder().Compatible(true).Build(); + reader.Register(200); + + OneStringField source = new() { F1 = "hello" }; + byte[] payload = writer.Serialize(source); + TwoStringField evolved = reader.Deserialize(payload); + + Assert.Equal("hello", evolved.F1); + Assert.Equal(string.Empty, evolved.F2); + } + + [Fact] + public void SchemaVersionMismatchThrows() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(false).Build(); + writer.Register(200); + + ForyRuntime reader = 
ForyRuntime.Builder().Compatible(false).Build(); + reader.Register(200); + + byte[] payload = writer.Serialize(new OneStringField { F1 = "hello" }); + Assert.Throws(() => { _ = reader.Deserialize(payload); }); + } + + [Fact] + public void UnionFieldRoundTripCompatible() + { + ForyRuntime fory = ForyRuntime.Builder().Compatible(true).Build(); + fory.Register(301); + + StructWithUnion2 first = new() { Union = Union2.OfT1("hello") }; + StructWithUnion2 second = new() { Union = Union2.OfT2(42L) }; + + StructWithUnion2 firstDecoded = fory.Deserialize(fory.Serialize(first)); + StructWithUnion2 secondDecoded = fory.Deserialize(fory.Serialize(second)); + + Assert.Equal(0, firstDecoded.Union.Index); + Assert.Equal("hello", firstDecoded.Union.GetT1()); + Assert.Equal(1, secondDecoded.Union.Index); + Assert.Equal(42L, secondDecoded.Union.GetT2()); + } + + [Fact] + public void EnumRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + fory.Register(100); + fory.Register(101); + + StructWithEnum value = new() { Name = "enum", Color = TestColor.Blue, Value = 42 }; + StructWithEnum decoded = fory.Deserialize(fory.Serialize(value)); + Assert.Equal(value.Name, decoded.Name); + Assert.Equal(value.Color, decoded.Color); + Assert.Equal(value.Value, decoded.Value); + } + + [Fact] + public void DynamicObjectSupportsObjectKeyMapAndSet() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + + Dictionary map = new() + { + ["k1"] = 7, + [2] = "v2", + [true] = null, + }; + Dictionary mapDecoded = + Assert.IsType>(fory.DeserializeObject(fory.SerializeObject(map))); + Assert.Equal(3, mapDecoded.Count); + Assert.Equal(7, mapDecoded["k1"]); + Assert.Equal("v2", mapDecoded[2]); + Assert.True(mapDecoded.ContainsKey(true)); + Assert.Null(mapDecoded[true]); + + HashSet set = ["a", 7, false]; + HashSet setDecoded = + Assert.IsType>(fory.DeserializeObject(fory.SerializeObject(set))); + Assert.Equal(3, setDecoded.Count); + Assert.Contains("a", setDecoded); + Assert.Contains(7, 
setDecoded); + Assert.Contains(false, setDecoded); + } + + [Fact] + public void GeneratedSerializerSupportsObjectKeyMap() + { + ForyRuntime fory = ForyRuntime.Builder().TrackRef(true).Build(); + fory.Register(400); + + DynamicAnyHolder source = new() + { + AnyValue = new Dictionary + { + ["inner"] = 9, + [10] = "ten", + }, + AnySet = ["x", 123], + AnyMap = new Dictionary + { + ["key1"] = null, + [99] = new List { "n", 1 }, + }, + }; + + DynamicAnyHolder decoded = fory.Deserialize(fory.Serialize(source)); + Dictionary dynamicMap = Assert.IsType>(decoded.AnyValue); + Assert.Equal(9, dynamicMap["inner"]); + Assert.Equal("ten", dynamicMap[10]); + Assert.Equal(source.AnySet.Count, decoded.AnySet.Count); + Assert.Contains("x", decoded.AnySet); + Assert.Contains(123, decoded.AnySet); + Assert.Equal(source.AnyMap.Count, decoded.AnyMap.Count); + Assert.True(decoded.AnyMap.ContainsKey("key1")); + Assert.Null(decoded.AnyMap["key1"]); + List nested = Assert.IsType>(decoded.AnyMap[99]); + Assert.Equal("n", nested[0]); + Assert.Equal(1, nested[1]); + } + + [Fact] + public void StringSerializerUsesLatin1WhenAllCharsAreLatin1() + { + (ulong encoding, string decoded) = WriteAndReadString("Hello\u00E9\u00FF"); + Assert.Equal(StringEncodingLatin1, encoding); + Assert.Equal("Hello\u00E9\u00FF", decoded); + } + + [Fact] + public void StringSerializerUsesUtf8WhenAsciiRatioIsHigh() + { + (ulong encoding, string decoded) = WriteAndReadString("abc\u4E16\u754C"); + Assert.Equal(StringEncodingUtf8, encoding); + Assert.Equal("abc\u4E16\u754C", decoded); + } + + [Fact] + public void StringSerializerUsesUtf16WhenAsciiRatioIsLow() + { + (ulong encoding, string decoded) = WriteAndReadString("\u4F60\u597D\u4E16\u754Ca"); + Assert.Equal(StringEncodingUtf16, encoding); + Assert.Equal("\u4F60\u597D\u4E16\u754Ca", decoded); + } + + [Fact] + public void StringSerializerValidatesBeyondSampleForLatin1() + { + string value = new string('a', 64) + "\u4E16"; + (ulong encoding, string decoded) = 
WriteAndReadString(value); + Assert.Equal(StringEncodingUtf8, encoding); + Assert.Equal(value, decoded); + } + + private static (ulong Encoding, string Decoded) WriteAndReadString(string value) + { + ByteWriter writer = new(); + TypeResolver resolver = new(); + WriteContext writeContext = new(writer, resolver, trackRef: false, compatible: false); + StringSerializer.WriteString(ref writeContext, value); + + byte[] payload = writer.ToArray(); + ByteReader headerReader = new(payload); + ulong header = headerReader.ReadVarUInt36Small(); + ulong encoding = header & 0x03; + int byteLength = checked((int)(header >> 2)); + Assert.Equal(payload.Length - headerReader.Cursor, byteLength); + + ReadContext readContext = new(new ByteReader(payload), resolver, trackRef: false, compatible: false); + string decoded = StringSerializer.ReadString(ref readContext); + Assert.Equal(0, readContext.Reader.Remaining); + return (encoding, decoded); + } +} diff --git a/csharp/tests/Fory.Tests/GlobalUsings.cs b/csharp/tests/Fory.Tests/GlobalUsings.cs new file mode 100644 index 0000000000..32c5277813 --- /dev/null +++ b/csharp/tests/Fory.Tests/GlobalUsings.cs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +global using Xunit; diff --git a/csharp/tests/Fory.XlangPeer/Fory.XlangPeer.csproj b/csharp/tests/Fory.XlangPeer/Fory.XlangPeer.csproj new file mode 100644 index 0000000000..8c92630101 --- /dev/null +++ b/csharp/tests/Fory.XlangPeer/Fory.XlangPeer.csproj @@ -0,0 +1,14 @@ + + + Exe + net8.0 + 12.0 + enable + enable + + + + + + + diff --git a/csharp/tests/Fory.XlangPeer/Program.cs b/csharp/tests/Fory.XlangPeer/Program.cs new file mode 100644 index 0000000000..b9a70cd999 --- /dev/null +++ b/csharp/tests/Fory.XlangPeer/Program.cs @@ -0,0 +1,1321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +using System.Buffers; +using System.Text; +using Apache.Fory; +using ForyRuntime = Apache.Fory.Fory; + +namespace Apache.Fory.XlangPeer; + +internal static class Program +{ + private const string DataFileEnv = "DATA_FILE"; + + private static readonly string[] StringSamples = + [ + "ab", + "Rust123", + "Çüéâäàåçêëèïî", + "こんにちは", + "Привет", + "𝄞🎵🎶", + "Hello, 世界", + ]; + + private static readonly int[] VarInt32Values = + [ + int.MinValue, + int.MinValue + 1, + -1_000_000, + -1000, + -128, + -1, + 0, + 1, + 127, + 128, + 16_383, + 16_384, + 2_097_151, + 2_097_152, + 268_435_455, + 268_435_456, + int.MaxValue - 1, + int.MaxValue, + ]; + + private static readonly uint[] VarUInt32Values = + [ + 0, + 1, + 127, + 128, + 16_383, + 16_384, + 2_097_151, + 2_097_152, + 268_435_455, + 268_435_456, + 2_147_483_646, + 2_147_483_647, + ]; + + private static readonly ulong[] VarUInt64Values = + [ + 0UL, + 1UL, + 127UL, + 128UL, + 16_383UL, + 16_384UL, + 2_097_151UL, + 2_097_152UL, + 268_435_455UL, + 268_435_456UL, + 34_359_738_367UL, + 34_359_738_368UL, + 4_398_046_511_103UL, + 4_398_046_511_104UL, + 562_949_953_421_311UL, + 562_949_953_421_312UL, + 72_057_594_037_927_935UL, + 72_057_594_037_927_936UL, + long.MaxValue, + ]; + + private static readonly long[] VarInt64Values = + [ + long.MinValue, + long.MinValue + 1, + -1_000_000_000_000L, + -1_000_000L, + -1000L, + -128L, + -1L, + 0L, + 1L, + 127L, + 1000L, + 1_000_000L, + 1_000_000_000_000L, + long.MaxValue - 1, + long.MaxValue, + ]; + + private static int Main(string[] args) + { + try + { + string caseName = ParseCaseName(args); + string dataFile = RequireDataFile(); + byte[] input = File.ReadAllBytes(dataFile); + byte[] output = ExecuteCase(caseName, input); + File.WriteAllBytes(dataFile, output); + Console.WriteLine($"case {caseName} passed"); + return 0; + } + catch (Exception ex) + { + Console.Error.WriteLine($"xlang peer failed: {ex}"); + return 1; + } + } + + private static string ParseCaseName(string[] args) + { + for 
(int i = 0; i < args.Length; i++) + { + if (args[i] == "--case" && i + 1 < args.Length) + { + return args[i + 1]; + } + } + + if (args.Length == 1) + { + return args[0]; + } + + throw new InvalidOperationException("Usage: Fory.XlangPeer --case "); + } + + private static string RequireDataFile() + { + string? dataFile = Environment.GetEnvironmentVariable(DataFileEnv); + if (string.IsNullOrWhiteSpace(dataFile)) + { + throw new InvalidOperationException($"{DataFileEnv} environment variable is required"); + } + + return dataFile; + } + + private static byte[] ExecuteCase(string caseName, byte[] input) + { + return caseName switch + { + "test_buffer" => CaseBuffer(input), + "test_buffer_var" => CaseBufferVar(input), + "test_murmurhash3" => CaseMurmurHash3(input), + "test_string_serializer" => CaseStringSerializer(input), + "test_cross_language_serializer" => CaseCrossLanguageSerializer(input), + "test_simple_struct" => CaseSimpleStruct(input), + "test_named_simple_struct" => CaseNamedSimpleStruct(input), + "test_list" => CaseList(input), + "test_map" => CaseMap(input), + "test_integer" => CaseInteger(input), + "test_item" => CaseItem(input), + "test_color" => CaseColor(input), + "test_union_xlang" => CaseUnionXlang(input), + "test_struct_with_list" => CaseStructWithList(input), + "test_struct_with_map" => CaseStructWithMap(input), + "test_skip_id_custom" => CaseSkipIdCustom(input), + "test_skip_name_custom" => CaseSkipNameCustom(input), + "test_consistent_named" => CaseConsistentNamed(input), + "test_struct_version_check" => CaseStructVersionCheck(input), + "test_polymorphic_list" => CasePolymorphicList(input), + "test_polymorphic_map" => CasePolymorphicMap(input), + "test_one_field_struct_compatible" => CaseOneFieldStructCompatible(input), + "test_one_field_struct_schema" => CaseOneFieldStructSchema(input), + "test_one_string_field_schema" => CaseOneStringFieldSchema(input), + "test_one_string_field_compatible" => CaseOneStringFieldCompatible(input), + 
"test_two_string_field_compatible" => CaseTwoStringFieldCompatible(input), + "test_schema_evolution_compatible" => CaseSchemaEvolutionCompatible(input), + "test_schema_evolution_compatible_reverse" => CaseSchemaEvolutionCompatibleReverse(input), + "test_one_enum_field_schema" => CaseOneEnumFieldSchema(input), + "test_one_enum_field_compatible" => CaseOneEnumFieldCompatible(input), + "test_two_enum_field_compatible" => CaseTwoEnumFieldCompatible(input), + "test_enum_schema_evolution_compatible" => CaseEnumSchemaEvolutionCompatible(input), + "test_enum_schema_evolution_compatible_reverse" => CaseEnumSchemaEvolutionCompatibleReverse(input), + "test_nullable_field_schema_consistent_not_null" => CaseNullableFieldSchemaConsistentNotNull(input), + "test_nullable_field_schema_consistent_null" => CaseNullableFieldSchemaConsistentNull(input), + "test_nullable_field_compatible_not_null" => CaseNullableFieldCompatibleNotNull(input), + "test_nullable_field_compatible_null" => CaseNullableFieldCompatibleNull(input), + "test_ref_schema_consistent" => CaseRefSchemaConsistent(input), + "test_ref_compatible" => CaseRefCompatible(input), + "test_collection_element_ref_override" => CaseCollectionElementRefOverride(input), + "test_circular_ref_schema_consistent" => CaseCircularRefSchemaConsistent(input), + "test_circular_ref_compatible" => CaseCircularRefCompatible(input), + "test_unsigned_schema_consistent_simple" => CaseUnsignedSchemaConsistentSimple(input), + "test_unsigned_schema_consistent" => CaseUnsignedSchemaConsistent(input), + "test_unsigned_schema_compatible" => CaseUnsignedSchemaCompatible(input), + _ => throw new InvalidOperationException($"unknown test case {caseName}"), + }; + } + + private static byte[] CaseBuffer(byte[] input) + { + ByteReader reader = new(input); + Ensure(reader.ReadUInt8() == 1, "bool mismatch"); + Ensure(reader.ReadInt8() == sbyte.MaxValue, "byte mismatch"); + Ensure(reader.ReadInt16() == short.MaxValue, "int16 mismatch"); + 
Ensure(reader.ReadInt32() == int.MaxValue, "int32 mismatch"); + Ensure(reader.ReadInt64() == long.MaxValue, "int64 mismatch"); + Ensure(Math.Abs(reader.ReadFloat32() - (-1.1f)) < 0.0001f, "float32 mismatch"); + Ensure(Math.Abs(reader.ReadFloat64() - (-1.1d)) < 0.000001d, "float64 mismatch"); + Ensure(reader.ReadVarUInt32() == 100, "varuint32 mismatch"); + int size = reader.ReadInt32(); + byte[] payload = reader.ReadBytes(size); + Ensure(payload.SequenceEqual("ab"u8.ToArray()), "binary mismatch"); + Ensure(reader.Remaining == 0, "buffer should be fully consumed"); + + ByteWriter writer = new(); + writer.WriteUInt8(1); + writer.WriteInt8(sbyte.MaxValue); + writer.WriteInt16(short.MaxValue); + writer.WriteInt32(int.MaxValue); + writer.WriteInt64(long.MaxValue); + writer.WriteFloat32(-1.1f); + writer.WriteFloat64(-1.1d); + writer.WriteVarUInt32(100); + writer.WriteInt32(2); + writer.WriteBytes("ab"u8); + return writer.ToArray(); + } + + private static byte[] CaseBufferVar(byte[] input) + { + ByteReader reader = new(input); + foreach (int expected in VarInt32Values) + { + Ensure(reader.ReadVarInt32() == expected, $"varint32 mismatch {expected}"); + } + + foreach (uint expected in VarUInt32Values) + { + Ensure(reader.ReadVarUInt32() == expected, $"varuint32 mismatch {expected}"); + } + + foreach (ulong expected in VarUInt64Values) + { + Ensure(reader.ReadVarUInt64() == expected, $"varuint64 mismatch {expected}"); + } + + foreach (long expected in VarInt64Values) + { + Ensure(reader.ReadVarInt64() == expected, $"varint64 mismatch {expected}"); + } + + Ensure(reader.Remaining == 0, "buffer var should be fully consumed"); + + ByteWriter writer = new(); + foreach (int value in VarInt32Values) + { + writer.WriteVarInt32(value); + } + + foreach (uint value in VarUInt32Values) + { + writer.WriteVarUInt32(value); + } + + foreach (ulong value in VarUInt64Values) + { + writer.WriteVarUInt64(value); + } + + foreach (long value in VarInt64Values) + { + writer.WriteVarInt64(value); + 
} + + return writer.ToArray(); + } + + private static byte[] CaseMurmurHash3(byte[] input) + { + if (input.Length == 32) + { + (ulong h1a, ulong h1b) = MurmurHash3.X64_128([1, 2, 8], 47); + (ulong h2a, ulong h2b) = MurmurHash3.X64_128(Encoding.UTF8.GetBytes("01234567890123456789"), 47); + ByteWriter writer = new(); + writer.WriteInt64(unchecked((long)h1a)); + writer.WriteInt64(unchecked((long)h1b)); + writer.WriteInt64(unchecked((long)h2a)); + writer.WriteInt64(unchecked((long)h2b)); + return writer.ToArray(); + } + + if (input.Length == 16) + { + ByteReader reader = new(input); + long h1 = reader.ReadInt64(); + long h2 = reader.ReadInt64(); + (ulong expected1, ulong expected2) = MurmurHash3.X64_128([1, 2, 8], 47); + Ensure(h1 == unchecked((long)expected1), "murmur hash h1 mismatch"); + Ensure(h2 == unchecked((long)expected2), "murmur hash h2 mismatch"); + return []; + } + + throw new InvalidOperationException($"unexpected murmur hash input length {input.Length}"); + } + + private static byte[] CaseStringSerializer(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + ReadOnlySequence sequence = new(input); + foreach (string expected in StringSamples) + { + string value = fory.Deserialize(ref sequence); + Ensure(value == expected, "string value mismatch"); + } + + EnsureConsumed(sequence, nameof(CaseStringSerializer)); + List output = []; + foreach (string sample in StringSamples) + { + Append(output, fory.SerializeObject(sample)); + } + + return output.ToArray(); + } + + private static byte[] CaseCrossLanguageSerializer(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(101); + ReadOnlySequence sequence = new(input); + + bool b1 = fory.Deserialize(ref sequence); + bool b2 = fory.Deserialize(ref sequence); + int i32 = fory.Deserialize(ref sequence); + sbyte i8a = fory.Deserialize(ref sequence); + sbyte i8b = fory.Deserialize(ref sequence); + short i16a = fory.Deserialize(ref sequence); + short i16b = 
fory.Deserialize(ref sequence); + int i32a = fory.Deserialize(ref sequence); + int i32b = fory.Deserialize(ref sequence); + long i64a = fory.Deserialize(ref sequence); + long i64b = fory.Deserialize(ref sequence); + float f32 = fory.Deserialize(ref sequence); + double f64 = fory.Deserialize(ref sequence); + string str = fory.Deserialize(ref sequence); + DateOnly day = fory.Deserialize(ref sequence); + DateTimeOffset timestamp = fory.Deserialize(ref sequence); + bool[] bools = fory.Deserialize(ref sequence); + byte[] bytes = fory.Deserialize(ref sequence); + short[] int16s = fory.Deserialize(ref sequence); + int[] int32s = fory.Deserialize(ref sequence); + long[] int64s = fory.Deserialize(ref sequence); + float[] floats = fory.Deserialize(ref sequence); + double[] doubles = fory.Deserialize(ref sequence); + List list = fory.Deserialize>(ref sequence); + HashSet set = fory.Deserialize>(ref sequence); + Dictionary map = fory.Deserialize>(ref sequence); + Color color = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseCrossLanguageSerializer)); + + Ensure(b1, "bool1 mismatch"); + Ensure(!b2, "bool2 mismatch"); + Ensure(i32 == -1, "int mismatch"); + Ensure(str == "str", "string mismatch"); + Ensure(day == new DateOnly(2021, 11, 23), "date mismatch"); + Ensure(timestamp.ToUnixTimeSeconds() == 100, "timestamp mismatch"); + Ensure(color == Color.White, "color mismatch"); + + List output = []; + Append(output, fory.SerializeObject(b1)); + Append(output, fory.SerializeObject(b2)); + Append(output, fory.SerializeObject(i32)); + Append(output, fory.SerializeObject(i8a)); + Append(output, fory.SerializeObject(i8b)); + Append(output, fory.SerializeObject(i16a)); + Append(output, fory.SerializeObject(i16b)); + Append(output, fory.SerializeObject(i32a)); + Append(output, fory.SerializeObject(i32b)); + Append(output, fory.SerializeObject(i64a)); + Append(output, fory.SerializeObject(i64b)); + Append(output, fory.SerializeObject(f32)); + Append(output, 
fory.SerializeObject(f64)); + Append(output, fory.SerializeObject(str)); + Append(output, fory.SerializeObject(day)); + Append(output, fory.SerializeObject(timestamp)); + Append(output, fory.SerializeObject(bools)); + Append(output, fory.SerializeObject(bytes)); + Append(output, fory.SerializeObject(int16s)); + Append(output, fory.SerializeObject(int32s)); + Append(output, fory.SerializeObject(int64s)); + Append(output, fory.SerializeObject(floats)); + Append(output, fory.SerializeObject(doubles)); + Append(output, fory.SerializeObject(list)); + Append(output, fory.SerializeObject(set)); + Append(output, fory.SerializeObject(map)); + Append(output, fory.SerializeObject(color)); + return output.ToArray(); + } + + private static byte[] CaseSimpleStruct(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + RegisterSimpleById(fory); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseNamedSimpleStruct(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + RegisterSimpleByName(fory); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseList(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(102); + ReadOnlySequence sequence = new(input); + List strList = fory.Deserialize>(ref sequence); + List strList2 = fory.Deserialize>(ref sequence); + List itemList = fory.Deserialize>(ref sequence); + List itemList2 = fory.Deserialize>(ref sequence); + EnsureConsumed(sequence, nameof(CaseList)); + + List output = []; + Append(output, fory.SerializeObject(strList)); + Append(output, fory.SerializeObject(strList2)); + Append(output, fory.SerializeObject(itemList)); + Append(output, fory.SerializeObject(itemList2)); + return output.ToArray(); + } + + private static byte[] CaseMap(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(102); + ReadOnlySequence sequence = new(input); + NullableKeyDictionary strMap = fory.Deserialize>(ref sequence); + 
NullableKeyDictionary itemMap = fory.Deserialize>(ref sequence); + EnsureConsumed(sequence, nameof(CaseMap)); + + List output = []; + Append(output, fory.SerializeObject(strMap)); + Append(output, fory.SerializeObject(itemMap)); + return output.ToArray(); + } + + private static byte[] CaseInteger(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(101); + ReadOnlySequence sequence = new(input); + Item1 obj = fory.Deserialize(ref sequence); + int f1 = fory.Deserialize(ref sequence); + int f2 = fory.Deserialize(ref sequence); + int f3 = fory.Deserialize(ref sequence); + int f4 = fory.Deserialize(ref sequence); + int f5 = fory.Deserialize(ref sequence); + int f6 = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseInteger)); + + Ensure(obj.F1 == 1 && obj.F2 == 2, "item1 primitive fields mismatch"); + Ensure(obj.F3 == 3 && obj.F4 == 4 && obj.F5 == 0 && obj.F6 == 0, "item1 boxed fields mismatch"); + + List output = []; + Append(output, fory.SerializeObject(obj)); + Append(output, fory.SerializeObject(f1)); + Append(output, fory.SerializeObject(f2)); + Append(output, fory.SerializeObject(f3)); + Append(output, fory.SerializeObject(f4)); + Append(output, fory.SerializeObject(f5)); + Append(output, fory.SerializeObject(f6)); + return output.ToArray(); + } + + private static byte[] CaseItem(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(102); + ReadOnlySequence sequence = new(input); + Item i1 = fory.Deserialize(ref sequence); + Item i2 = fory.Deserialize(ref sequence); + Item i3 = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseItem)); + + List output = []; + Append(output, fory.SerializeObject(i1)); + Append(output, fory.SerializeObject(i2)); + Append(output, fory.SerializeObject(i3)); + return output.ToArray(); + } + + private static byte[] CaseColor(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(101); + ReadOnlySequence 
sequence = new(input); + Color c1 = fory.Deserialize(ref sequence); + Color c2 = fory.Deserialize(ref sequence); + Color c3 = fory.Deserialize(ref sequence); + Color c4 = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseColor)); + + List output = []; + Append(output, fory.SerializeObject(c1)); + Append(output, fory.SerializeObject(c2)); + Append(output, fory.SerializeObject(c3)); + Append(output, fory.SerializeObject(c4)); + return output.ToArray(); + } + + private static byte[] CaseStructWithList(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(201); + ReadOnlySequence sequence = new(input); + StructWithList s1 = fory.Deserialize(ref sequence); + StructWithList s2 = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseStructWithList)); + + List output = []; + Append(output, fory.SerializeObject(s1)); + Append(output, fory.SerializeObject(s2)); + return output.ToArray(); + } + + private static byte[] CaseStructWithMap(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(202); + ReadOnlySequence sequence = new(input); + StructWithMap s1 = fory.Deserialize(ref sequence); + StructWithMap s2 = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseStructWithMap)); + + List output = []; + Append(output, fory.SerializeObject(s1)); + Append(output, fory.SerializeObject(s2)); + return output.ToArray(); + } + + private static byte[] CaseUnionXlang(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(301); + + ReadOnlySequence sequence = new(input); + StructWithUnion2 first = fory.Deserialize(ref sequence); + StructWithUnion2 second = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseUnionXlang)); + + Ensure(first.Union.Index == 0, "union case index mismatch for first value"); + Ensure(first.Union.Value is string firstValue && firstValue == "hello", "union case value mismatch for first value"); + 
Ensure(second.Union.Index == 1, "union case index mismatch for second value"); + Ensure(second.Union.Value is long secondValue && secondValue == 42L, "union case value mismatch for second value"); + + List output = []; + Append(output, fory.SerializeObject(first)); + Append(output, fory.SerializeObject(second)); + return output.ToArray(); + } + + private static byte[] CaseSkipIdCustom(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(101); + fory.Register(102); + fory.Register(103); + fory.Register(104); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseSkipNameCustom(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register("color"); + fory.Register("my_struct"); + fory.Register(string.Empty, "my_ext"); + fory.Register("my_wrapper"); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseConsistentNamed(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false, checkStructVersion: true); + fory.Register("color"); + fory.Register("my_struct"); + fory.Register(string.Empty, "my_ext"); + + ReadOnlySequence sequence = new(input); + List output = []; + for (int i = 0; i < 3; i++) + { + Color color = fory.Deserialize(ref sequence); + Append(output, fory.SerializeObject(color)); + } + + for (int i = 0; i < 3; i++) + { + MyStruct myStruct = fory.Deserialize(ref sequence); + Append(output, fory.SerializeObject(myStruct)); + } + + for (int i = 0; i < 3; i++) + { + MyExt myExt = fory.Deserialize(ref sequence); + Append(output, fory.SerializeObject(myExt)); + } + + EnsureConsumed(sequence, nameof(CaseConsistentNamed)); + return output.ToArray(); + } + + private static byte[] CaseStructVersionCheck(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false, checkStructVersion: true); + fory.Register(201); + return RoundTripSingle(input, fory); + } + + private static byte[] CasePolymorphicList(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + 
fory.Register(302); + fory.Register(303); + fory.Register(304); + + ReadOnlySequence sequence = new(input); + List animals = fory.Deserialize>(ref sequence); + AnimalListHolder holder = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CasePolymorphicList)); + + List output = []; + Append(output, fory.SerializeObject(animals)); + Append(output, fory.SerializeObject(holder)); + return output.ToArray(); + } + + private static byte[] CasePolymorphicMap(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(302); + fory.Register(303); + fory.Register(305); + + ReadOnlySequence sequence = new(input); + Dictionary map = fory.Deserialize>(ref sequence); + AnimalMapHolder holder = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CasePolymorphicMap)); + + List output = []; + Append(output, fory.SerializeObject(map)); + Append(output, fory.SerializeObject(holder)); + return output.ToArray(); + } + + private static byte[] CaseOneFieldStructCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseOneFieldStructSchema(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseOneStringFieldSchema(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseOneStringFieldCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseTwoStringFieldCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(201); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseSchemaEvolutionCompatible(byte[] input) + { + ForyRuntime fory = 
BuildFory(compatible: true); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseSchemaEvolutionCompatibleReverse(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(200); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseOneEnumFieldSchema(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(210); + fory.Register(211); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseOneEnumFieldCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(210); + fory.Register(211); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseTwoEnumFieldCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(210); + fory.Register(212); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseEnumSchemaEvolutionCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(210); + fory.Register(211); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseEnumSchemaEvolutionCompatibleReverse(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(210); + fory.Register(211); + + ReadOnlySequence sequence = new(input); + TwoEnumFieldStruct value = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseEnumSchemaEvolutionCompatibleReverse)); + Ensure(value.F1 == TestEnum.ValueC, "enum schema evolution reverse F1 mismatch"); + Ensure(value.F2 == TestEnum.ValueA, "enum schema evolution reverse F2 default mismatch"); + return fory.SerializeObject(value); + } + + private static byte[] CaseNullableFieldSchemaConsistentNotNull(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(401); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseNullableFieldSchemaConsistentNull(byte[] input) + { + 
ForyRuntime fory = BuildFory(compatible: false); + fory.Register(401); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseNullableFieldCompatibleNotNull(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(402); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseNullableFieldCompatibleNull(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(402); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseRefSchemaConsistent(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false, trackRef: true); + fory.Register(501); + fory.Register(502); + + ReadOnlySequence sequence = new(input); + RefOuterSchemaConsistent outer = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseRefSchemaConsistent)); + Ensure(ReferenceEquals(outer.Inner1, outer.Inner2), "reference tracking mismatch"); + return fory.SerializeObject(outer); + } + + private static byte[] CaseRefCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true, trackRef: true); + fory.Register(503); + fory.Register(504); + + ReadOnlySequence sequence = new(input); + RefOuterCompatible outer = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseRefCompatible)); + Ensure(ReferenceEquals(outer.Inner1, outer.Inner2), "reference tracking mismatch"); + return fory.SerializeObject(outer); + } + + private static byte[] CaseCollectionElementRefOverride(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false, trackRef: true); + fory.Register(701); + fory.Register(702); + + ReadOnlySequence sequence = new(input); + RefOverrideContainer container = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseCollectionElementRefOverride)); + + if (container.ListField.Count > 0) + { + RefOverrideElement? 
shared = container.ListField[0]; + if (shared is not null) + { + if (container.ListField.Count > 1) + { + container.ListField[1] = shared; + } + + if (container.MapField.ContainsKey("k1")) + { + container.MapField["k1"] = shared; + } + + if (container.MapField.ContainsKey("k2")) + { + container.MapField["k2"] = shared; + } + } + } + + return fory.SerializeObject(container); + } + + private static byte[] CaseCircularRefSchemaConsistent(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false, trackRef: true); + fory.Register(601); + + ReadOnlySequence sequence = new(input); + CircularRefStruct value = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseCircularRefSchemaConsistent)); + Ensure(ReferenceEquals(value, value.SelfRef), "circular ref mismatch"); + return fory.SerializeObject(value); + } + + private static byte[] CaseCircularRefCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true, trackRef: true); + fory.Register(602); + + ReadOnlySequence sequence = new(input); + CircularRefStruct value = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseCircularRefCompatible)); + Ensure(ReferenceEquals(value, value.SelfRef), "circular ref mismatch"); + return fory.SerializeObject(value); + } + + private static byte[] CaseUnsignedSchemaConsistentSimple(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(1); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseUnsignedSchemaConsistent(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(501); + return RoundTripSingle(input, fory); + } + + private static byte[] CaseUnsignedSchemaCompatible(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(502); + return RoundTripSingle(input, fory); + } + + private static byte[] RoundTripSingle(byte[] input, ForyRuntime fory) + { + ReadOnlySequence sequence = new(input); + T value = fory.Deserialize(ref 
sequence); + EnsureConsumed(sequence, typeof(T).Name); + return fory.SerializeObject(value); + } + + private static void RegisterSimpleById(ForyRuntime fory) + { + fory.Register(101); + fory.Register(102); + fory.Register(103); + } + + private static void RegisterSimpleByName(ForyRuntime fory) + { + fory.Register("demo", "color"); + fory.Register("demo", "item"); + fory.Register("demo", "simple_struct"); + } + + private static ForyRuntime BuildFory(bool compatible, bool trackRef = false, bool checkStructVersion = false) + { + return ForyRuntime.Builder() + .Compatible(compatible) + .TrackRef(trackRef) + .CheckStructVersion(checkStructVersion) + .Build(); + } + + private static void Append(List target, byte[] payload) + { + target.AddRange(payload); + } + + private static void EnsureConsumed(ReadOnlySequence sequence, string caseName) + { + Ensure(sequence.Length == 0, $"case {caseName} did not consume full payload"); + } + + private static void Ensure(bool condition, string message) + { + if (!condition) + { + throw new InvalidOperationException(message); + } + } +} + +[ForyObject] +public enum Color +{ + Green, + Red, + Blue, + White, +} + +[ForyObject] +public sealed class Item +{ + public string Name { get; set; } = string.Empty; +} + +[ForyObject] +public sealed class SimpleStruct +{ + public Dictionary F1 { get; set; } = []; + public int F2 { get; set; } + public Item F3 { get; set; } = new(); + public string F4 { get; set; } = string.Empty; + public Color F5 { get; set; } + public List F6 { get; set; } = []; + public int F7 { get; set; } + public int F8 { get; set; } + public int Last { get; set; } +} + +[ForyObject] +public sealed class Item1 +{ + public int F1 { get; set; } + public int F2 { get; set; } + public int F3 { get; set; } + public int F4 { get; set; } + public int F5 { get; set; } + public int F6 { get; set; } +} + +[ForyObject] +public sealed class StructWithList +{ + public List Items { get; set; } = []; +} + +[ForyObject] +public sealed class 
StructWithMap +{ + public NullableKeyDictionary Data { get; set; } = new(); +} + +[ForyObject] +public sealed class StructWithUnion2 +{ + public Union2 Union { get; set; } = Union2.OfT1(string.Empty); +} + +[ForyObject] +public sealed class MyStruct +{ + public int Id { get; set; } +} + +[ForyObject] +public sealed class MyExt +{ + public int Id { get; set; } +} + +public sealed class MyExtSerializer : Serializer +{ + public override TypeId StaticTypeId => TypeId.Ext; + public override bool IsNullableType => true; + public override bool IsReferenceTrackableType => true; + public override MyExt DefaultValue => null!; + public override bool IsNone(in MyExt value) => value is null; + + public override void WriteData(ref WriteContext context, in MyExt value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteVarInt32((value ?? new MyExt()).Id); + } + + public override MyExt ReadData(ref ReadContext context) + { + return new MyExt { Id = context.Reader.ReadVarInt32() }; + } +} + +[ForyObject] +public sealed class MyWrapper +{ + public Color Color { get; set; } + public MyExt MyExt { get; set; } = new(); + public MyStruct MyStruct { get; set; } = new(); +} + +[ForyObject] +public sealed class EmptyWrapper +{ +} + +[ForyObject] +public sealed class VersionCheckStruct +{ + public int F1 { get; set; } + public string? F2 { get; set; } + public double F3 { get; set; } +} + +[ForyObject] +public sealed class Dog +{ + public int Age { get; set; } + public string? 
Name { get; set; } +} + +[ForyObject] +public sealed class Cat +{ + public int Age { get; set; } + public int Lives { get; set; } +} + +[ForyObject] /* NOTE(review): List, Dictionary, HashSet and NullableKeyDictionary appear without generic type arguments throughout this chunk; presumably they were lost in extraction - confirm against the real file. */ +public sealed class AnimalListHolder +{ + public List Animals { get; set; } = []; +} + +[ForyObject] +public sealed class AnimalMapHolder +{ + public Dictionary AnimalMap { get; set; } = []; +} + +[ForyObject] +public sealed class OneFieldStruct +{ + public int Value { get; set; } +} + +[ForyObject] /* Single string property; it is declared nullable and has no initializer. */ +public sealed class OneStringFieldStruct +{ + public string? F1 { get; set; } +} + +[ForyObject] /* Two non-nullable string properties, both initialized to string.Empty. */ +public sealed class TwoStringFieldStruct +{ + public string F1 { get; set; } = string.Empty; + public string F2 { get; set; } = string.Empty; +} + +[ForyObject] /* Three-member enum; no explicit numeric values are assigned. */ +public enum TestEnum +{ + ValueA, + ValueB, + ValueC, +} + +[ForyObject] +public sealed class OneEnumFieldStruct +{ + public TestEnum F1 { get; set; } +} + +[ForyObject] +public sealed class TwoEnumFieldStruct +{ + public TestEnum F1 { get; set; } + public TestEnum F2 { get; set; } +} + +[ForyObject] /* Non-nullable primitives, string and collections, followed by nullable (?) properties of the same kinds. */ +public sealed class NullableComprehensiveSchemaConsistent +{ + public sbyte ByteField { get; set; } + public short ShortField { get; set; } + public int IntField { get; set; } + public long LongField { get; set; } + public float FloatField { get; set; } + public double DoubleField { get; set; } + public bool BoolField { get; set; } + + public string StringField { get; set; } = string.Empty; + public List ListField { get; set; } = []; + public HashSet SetField { get; set; } = []; + public NullableKeyDictionary MapField { get; set; } = new(); + + public int? NullableInt { get; set; } + public long? NullableLong { get; set; } + public float? NullableFloat { get; set; } + public double? NullableDouble { get; set; } + public bool? NullableBool { get; set; } + public string? NullableString { get; set; } + public List? NullableList { get; set; } + public HashSet? NullableSet { get; set; } + public NullableKeyDictionary?
NullableMap { get; set; } +} + +[ForyObject] /* Every property here is declared non-nullable, including the ones named Boxed... and Nullable...; reference-typed ones are initialized to empty values. NOTE(review): presumably the other side of the exchange declares some of these as nullable - confirm against the peer test. */ +public sealed class NullableComprehensiveCompatible +{ + public sbyte ByteField { get; set; } + public short ShortField { get; set; } + public int IntField { get; set; } + public long LongField { get; set; } + public float FloatField { get; set; } + public double DoubleField { get; set; } + public bool BoolField { get; set; } + + public int BoxedInt { get; set; } + public long BoxedLong { get; set; } + public float BoxedFloat { get; set; } + public double BoxedDouble { get; set; } + public bool BoxedBool { get; set; } + + public string StringField { get; set; } = string.Empty; + public List ListField { get; set; } = []; + public HashSet SetField { get; set; } = []; + public NullableKeyDictionary MapField { get; set; } = new(); + + public int NullableInt1 { get; set; } + public long NullableLong1 { get; set; } + public float NullableFloat1 { get; set; } + public double NullableDouble1 { get; set; } + public bool NullableBool1 { get; set; } + + public string NullableString2 { get; set; } = string.Empty; + public List NullableList2 { get; set; } = []; + public HashSet NullableSet2 { get; set; } = []; + public NullableKeyDictionary NullableMap2 { get; set; } = new(); +} + +[ForyObject] +public sealed class RefInnerSchemaConsistent +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyObject] /* Holds two nullable properties of the same inner type. */ +public sealed class RefOuterSchemaConsistent +{ + public RefInnerSchemaConsistent? Inner1 { get; set; } + public RefInnerSchemaConsistent? Inner2 { get; set; } +} + +[ForyObject] +public sealed class RefInnerCompatible +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyObject] /* Holds two nullable properties of the same inner type. */ +public sealed class RefOuterCompatible +{ + public RefInnerCompatible? Inner1 { get; set; } + public RefInnerCompatible?
Inner2 { get; set; } +} + +[ForyObject] +public sealed class RefOverrideElement +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyObject] /* One list property and one dictionary property, both initialized empty. */ +public sealed class RefOverrideContainer +{ + public List ListField { get; set; } = []; + public Dictionary MapField { get; set; } = []; +} + +[ForyObject] /* SelfRef has the declaring type, so an instance can point at another instance or at itself. */ +public sealed class CircularRefStruct +{ + public string Name { get; set; } = string.Empty; + public CircularRefStruct? SelfRef { get; set; } +} + +[ForyObject] /* Two ulong properties, one of them nullable, both annotated with FieldEncoding.Tagged. */ +public sealed class UnsignedSchemaConsistentSimple +{ + [Field(Encoding = FieldEncoding.Tagged)] + public ulong U64Tagged { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong? U64TaggedNullable { get; set; } +} + +[ForyObject] /* Unsigned properties in non-nullable and nullable form. The U32Fixed and U64Fixed properties carry FieldEncoding.Fixed, the U64Tagged properties carry FieldEncoding.Tagged, and the remaining ones have no [Field] attribute. */ +public sealed class UnsignedSchemaConsistent +{ + public byte U8Field { get; set; } + public ushort U16Field { get; set; } + public uint U32VarField { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public uint U32FixedField { get; set; } + + public ulong U64VarField { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public ulong U64FixedField { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong U64TaggedField { get; set; } + + public byte? U8NullableField { get; set; } + public ushort? U16NullableField { get; set; } + public uint? U32VarNullableField { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public uint? U32FixedNullableField { get; set; } + + public ulong? U64VarNullableField { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public ulong? U64FixedNullableField { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong? U64TaggedNullableField { get; set; } +} + +[ForyObject] /* Same unsigned kinds in two groups: the properties ending in Field1 are nullable, the ones ending in Field2 are non-nullable. */ +public sealed class UnsignedSchemaCompatible +{ + public byte? U8Field1 { get; set; } + public ushort? U16Field1 { get; set; } + public uint? U32VarField1 { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public uint? U32FixedField1 { get; set; } + + public ulong?
U64VarField1 { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public ulong? U64FixedField1 { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong? U64TaggedField1 { get; set; } + + public byte U8Field2 { get; set; } + public ushort U16Field2 { get; set; } + public uint U32VarField2 { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public uint U32FixedField2 { get; set; } + + public ulong U64VarField2 { get; set; } + + [Field(Encoding = FieldEncoding.Fixed)] + public ulong U64FixedField2 { get; set; } + + [Field(Encoding = FieldEncoding.Tagged)] + public ulong U64TaggedField2 { get; set; } +} diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java new file mode 100644 index 0000000000..17005fa0c8 --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.fory.xlang; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.fory.test.TestUtils; +import org.testng.SkipException; +import org.testng.annotations.Test; + +/** Executes cross-language tests against the C# implementation. */ +@Test +public class CSharpXlangTest extends XlangTestBase { + private static final String CSHARP_DLL = "Fory.XlangPeer.dll"; + private static final File CSHARP_DIR = new File("../../csharp"); + private static final File CSHARP_BINARY_DIR = + new File(CSHARP_DIR, "tests/Fory.XlangPeer/bin/Debug/net8.0"); + private volatile boolean peerBuilt; /* set to true only after the build succeeded and the peer DLL was found */ + + @Override + protected void ensurePeerReady() { /* throws SkipException unless FORY_CSHARP_JAVA_CI is "1" and dotnet is runnable; then builds the peer, wrapping IOException in RuntimeException */ + String enabled = System.getenv("FORY_CSHARP_JAVA_CI"); + if (!"1".equals(enabled)) { + throw new SkipException("Skipping CSharpXlangTest: FORY_CSHARP_JAVA_CI not set to 1"); + } + + if (!isDotnetAvailable()) { + throw new SkipException("Skipping CSharpXlangTest: dotnet is not available"); + } + + try { + ensurePeerBuilt(); + } catch (IOException e) { + throw new RuntimeException("Failed to build C# peer", e); + } + } + + @Override + protected CommandContext buildCommandContext(String caseName, Path dataFile) throws IOException { /* command is: dotnet <absolute peer DLL path> --case <caseName>; CSHARP_BINARY_DIR is passed as the context directory */ + ensurePeerBuilt(); + + List command = new ArrayList<>(); + command.add("dotnet"); + command.add(new File(CSHARP_BINARY_DIR, CSHARP_DLL).getAbsolutePath()); + command.add("--case"); + command.add(caseName); + + Map env = envBuilder(dataFile); + return new CommandContext(command, env, CSHARP_BINARY_DIR); + } + + private boolean isDotnetAvailable() { /* true only when "dotnet --version" exits with 0 within 30 seconds; a timed-out process is killed */ + try { + Process process = new ProcessBuilder("dotnet", "--version").start(); + if (!process.waitFor(30, TimeUnit.SECONDS)) { + process.destroyForcibly(); + return false; + } + return process.exitValue() == 0; + } catch (IOException |
InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); /* restore the interrupt flag before reporting "not available" */ + } + return false; + } + } + + private void ensurePeerBuilt() throws IOException { /* runs "dotnet build" for the peer project at most once per successful build on this instance; peerBuilt is re-checked inside the synchronized block */ + if (peerBuilt) { + return; + } + + synchronized (this) { + if (peerBuilt) { + return; + } + + List buildCommand = + Arrays.asList( + "dotnet", "build", "tests/Fory.XlangPeer/Fory.XlangPeer.csproj", "-c", "Debug"); + boolean built = + TestUtils.executeCommand(buildCommand, 180, Collections.emptyMap(), CSHARP_DIR); + if (!built) { + throw new IOException("dotnet build failed for csharp/tests/Fory.XlangPeer"); + } + + File dll = new File(CSHARP_BINARY_DIR, CSHARP_DLL); + if (!dll.exists()) { + throw new IOException("C# peer assembly not found: " + dll.getAbsolutePath()); + } + + peerBuilt = true; + } + } + + // ============================================================================ + // Test methods - duplicated from XlangTestBase for Maven Surefire discovery + // ============================================================================ + + @Test(groups = "xlang") + public void testBuffer() throws java.io.IOException { + super.testBuffer(); + } + + @Test(groups = "xlang") + public void testBufferVar() throws java.io.IOException { + super.testBufferVar(); + } + + @Test(groups = "xlang") + public void testMurmurHash3() throws java.io.IOException { + super.testMurmurHash3(); + } + + @Test(groups = "xlang") + public void testStringSerializer() throws Exception { + super.testStringSerializer(); + } + + @Test(groups = "xlang") + public void testCrossLanguageSerializer() throws Exception { + super.testCrossLanguageSerializer(); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSimpleStruct(boolean enableCodegen) throws java.io.IOException { + super.testSimpleStruct(enableCodegen); + } + + @Test(groups = "xlang") + public void testSimpleNamedStructCodegenEnabled() throws java.io.IOException { + super.testSimpleNamedStruct(true); /* the enabled variant must pass true; the CodegenDisabled variant covers false */ + } + +
@Test(groups = "xlang") + public void testSimpleNamedStructCodegenDisabled() throws java.io.IOException { + super.testSimpleNamedStruct(false); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testList(boolean enableCodegen) throws java.io.IOException { + super.testList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testMap(boolean enableCodegen) throws java.io.IOException { + super.testMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testInteger(boolean enableCodegen) throws java.io.IOException { + super.testInteger(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testItem(boolean enableCodegen) throws java.io.IOException { + super.testItem(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testColor(boolean enableCodegen) throws java.io.IOException { + super.testColor(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructWithList(boolean enableCodegen) throws java.io.IOException { + super.testStructWithList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructWithMap(boolean enableCodegen) throws java.io.IOException { + super.testStructWithMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCollectionElementRefOverride(boolean enableCodegen) throws java.io.IOException { + super.testCollectionElementRefOverride(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSkipIdCustom(boolean enableCodegen) throws java.io.IOException { + super.testSkipIdCustom(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSkipNameCustom(boolean enableCodegen)
throws java.io.IOException { + super.testSkipNameCustom(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testConsistentNamed(boolean enableCodegen) throws java.io.IOException { + super.testConsistentNamed(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testStructVersionCheck(boolean enableCodegen) throws java.io.IOException { + super.testStructVersionCheck(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testPolymorphicList(boolean enableCodegen) throws java.io.IOException { + super.testPolymorphicList(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testPolymorphicMap(boolean enableCodegen) throws java.io.IOException { + super.testPolymorphicMap(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneStringFieldSchemaConsistent(boolean enableCodegen) throws java.io.IOException { + super.testOneStringFieldSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneStringFieldCompatible(boolean enableCodegen) throws java.io.IOException { + super.testOneStringFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testTwoStringFieldCompatible(boolean enableCodegen) throws java.io.IOException { + super.testTwoStringFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testSchemaEvolutionCompatible(boolean enableCodegen) throws java.io.IOException { + super.testSchemaEvolutionCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneEnumFieldSchemaConsistent(boolean enableCodegen) throws java.io.IOException { +
super.testOneEnumFieldSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testOneEnumFieldCompatible(boolean enableCodegen) throws java.io.IOException { + super.testOneEnumFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testTwoEnumFieldCompatible(boolean enableCodegen) throws java.io.IOException { + super.testTwoEnumFieldCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testEnumSchemaEvolutionCompatible(boolean enableCodegen) throws java.io.IOException { + super.testEnumSchemaEvolutionCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldSchemaConsistentNotNull(boolean enableCodegen) + throws java.io.IOException { + super.testNullableFieldSchemaConsistentNotNull(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldSchemaConsistentNull(boolean enableCodegen) + throws java.io.IOException { + super.testNullableFieldSchemaConsistentNull(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldCompatibleNotNull(boolean enableCodegen) throws java.io.IOException { + super.testNullableFieldCompatibleNotNull(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testNullableFieldCompatibleNull(boolean enableCodegen) throws java.io.IOException { + super.testNullableFieldCompatibleNull(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnionXlang(boolean enableCodegen) throws java.io.IOException { + super.testUnionXlang(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testRefSchemaConsistent(boolean
enableCodegen) throws java.io.IOException { + super.testRefSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testRefCompatible(boolean enableCodegen) throws java.io.IOException { + super.testRefCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCircularRefSchemaConsistent(boolean enableCodegen) throws java.io.IOException { + super.testCircularRefSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testCircularRefCompatible(boolean enableCodegen) throws java.io.IOException { + super.testCircularRefCompatible(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaConsistent(boolean enableCodegen) throws java.io.IOException { + super.testUnsignedSchemaConsistent(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaConsistentSimple(boolean enableCodegen) throws java.io.IOException { + super.testUnsignedSchemaConsistentSimple(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testUnsignedSchemaCompatible(boolean enableCodegen) throws java.io.IOException { + super.testUnsignedSchemaCompatible(enableCodegen); + } +} diff --git a/licenserc.toml b/licenserc.toml index e4982ff2e8..cc0ecab0be 100644 --- a/licenserc.toml +++ b/licenserc.toml @@ -61,5 +61,5 @@ excludes = [ ] [mapping.DOUBLESLASH_STYLE] -extensions = ['go'] +extensions = ['go', 'cs'] filenames = ['go.mod']