Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,10 @@ public virtual void Write(string value)

if (_useFastUtf8)
{
if (value.Length <= 127 / 3)
// If this is a non-derived BinaryWriter, then we can bypass the Write7BitEncodedInt call.
// But when this is a derived instance, call must not bypass it for compatibility reasons
// as it calls the virtual Write(byte) overload.
if (GetType() == typeof(BinaryWriter) && value.Length <= 127 / 3)
{
// Max expansion: each char -> 3 bytes, so 127 bytes max of data, +1 for length prefix
Span<byte> buffer = stackalloc byte[128];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Text;
Expand Down Expand Up @@ -188,26 +188,359 @@ protected virtual Stream CreateStream()

private void WriteTest<T>(T[] testElements, Action<BinaryWriter, T> write, Func<BinaryReader, T> read)
{
// Non-derived BinaryWriter/BinaryReader, UTF-8 encoding
using (Stream memStream = CreateStream())
using (BinaryWriter writer = new BinaryWriter(memStream))
using (BinaryReader reader = new BinaryReader(memStream))
using (var writer = new BinaryWriter(memStream))
using (var reader = new BinaryReader(memStream))
{
for (int i = 0; i < testElements.Length; i++)
{
write(writer, testElements[i]);
}
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Derived BinaryWriter/BinaryReader, UTF-8 encoding
using (Stream memStream = CreateStream())
using (var writer = new TestWriter(memStream))
using (var reader = new TestReader(memStream))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Non-derived BinaryWriter/BinaryReader, UTF-16 encoding
using (Stream memStream = CreateStream())
using (var writer = new BinaryWriter(memStream, Encoding.Unicode))
using (var reader = new BinaryReader(memStream, Encoding.Unicode))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}

// Derived BinaryWriter/BinaryReader, UTF-16 encoding
using (Stream memStream = CreateStream())
using (var writer = new TestWriter(memStream, Encoding.Unicode))
using (var reader = new TestReader(memStream, Encoding.Unicode))
{
WriteTest(memStream, writer, reader, testElements, write, read);
}
}

private void WriteTest<T>(Stream stream, BinaryWriter writer, BinaryReader reader, T[] testElements, Action<BinaryWriter, T> write, Func<BinaryReader, T> read)
{
for (int i = 0; i < testElements.Length; i++)
{
write(writer, testElements[i]);
}

writer.Flush();
stream.Position = 0;

for (int i = 0; i < testElements.Length; i++)
{
Assert.Equal(testElements[i], read(reader));
}

if (writer is TestWriter derivedWriter && reader is TestReader derivedReader)
{
// Checking if the internally tracked positions of a derived reader/writer are in sync (#107265)
Assert.Equal(derivedReader.Position, derivedWriter.Position);
}

// We've reached the end of the stream. Check for expected EndOfStreamException
Assert.Throws<EndOfStreamException>(() => read(reader));
}


private class TestWriter : BinaryWriter
{
private readonly Encoding _encoding;
public long Position { get; private set; }

public TestWriter(Stream stream, Encoding? encoding = null)
: base(stream, encoding ?? Encoding.UTF8)
{
_encoding = encoding ?? Encoding.UTF8;
}

public override void Write(bool value)
{
Advance(sizeof(byte));
base.Write(value);
}

public override void Write(byte value)
{
Advance(sizeof(byte));
base.Write(value);
}

public override void Write(byte[] buffer)
{
Advance(buffer.Length);
base.Write(buffer);
}

public override void Write(byte[] buffer, int index, int count)
{
Advance(count);
base.Write(buffer, index, count);
}

public override void Write(char ch)
{
Advance(_encoding.GetBytes([ch]).Length);
base.Write(ch);
}

public override void Write(char[] chars)
{
Advance(_encoding.GetBytes(chars).Length);
base.Write(chars);
}

public override void Write(char[] chars, int index, int count)
{
Advance(_encoding.GetBytes(chars, index, count).Length);
base.Write(chars, index, count);
}

public override void Write(decimal value)
{
Advance(sizeof(decimal));
base.Write(value);
}

public override void Write(double value)
{
Advance(sizeof(double));
base.Write(value);
}

public override void Write(float value)
{
Advance(sizeof(float));
base.Write(value);
}

public override void Write(int value)
{
Advance(sizeof(int));
base.Write(value);
}

public override void Write(long value)
{
Advance(sizeof(long));
base.Write(value);
}

public override void Write(sbyte value)
{
Advance(sizeof(sbyte));
base.Write(value);
}

public override void Write(short value)
{
Advance(sizeof(short));
base.Write(value);
}

public override void Write(string value)
{
Advance(_encoding.GetBytes(value).Length);
base.Write(value);
}

public override void Write(uint value)
{
Advance(sizeof(uint));
base.Write(value);
}

public override void Write(ulong value)
{
Advance(sizeof(ulong));
base.Write(value);
}

public override void Write(ushort value)
{
Advance(sizeof(ushort));
base.Write(value);
}

public override unsafe void Write(Half value)
{
Advance(sizeof(Half));
base.Write(value);
}

public override void Write(ReadOnlySpan<byte> buffer)
{
Advance(buffer.Length);
base.Write(buffer);
}

public override void Write(ReadOnlySpan<char> chars)
{
Advance(_encoding.GetBytes(chars.ToArray()).Length);
base.Write(chars);
}

private void Advance(int offset) => Position += offset;
}

private class TestReader : BinaryReader
{
private readonly Encoding _encoding;
public long Position { get; private set; }

public TestReader(Stream s, Encoding? encoding = null)
: base(s, encoding ?? Encoding.UTF8)
{
_encoding = encoding ?? Encoding.UTF8;
}

public override int Read()
{
var current = BaseStream.Position;
var result = base.Read();
Advance(BaseStream.Position - current);
return result;
}

writer.Flush();
memStream.Position = 0;
public override int Read(byte[] buffer, int index, int count)
{
var result = base.Read(buffer, index, count);
Advance(result);
return result;
}

for (int i = 0; i < testElements.Length; i++)
{
Assert.Equal(testElements[i], read(reader));
}
public override int Read(char[] buffer, int index, int count)
{
var result = base.Read(buffer, index, count);
Advance(_encoding.GetBytes(buffer, 0, result).Length);
return result;
}

// We've reached the end of the stream. Check for expected EndOfStreamException
Assert.Throws<EndOfStreamException>(() => read(reader));
public override bool ReadBoolean()
{
Advance(sizeof(bool));
return base.ReadBoolean();
}

public override byte ReadByte()
{
Advance(sizeof(byte));
return base.ReadByte();
}

public override byte[] ReadBytes(int count)
{
var result = base.ReadBytes(count);
Advance(result.Length);
return result;
}

public override char ReadChar()
{
var result = base.ReadChar();
Advance(_encoding.GetBytes([result]).Length);
return result;
}

public override char[] ReadChars(int count)
{
var result = base.ReadChars(count);
Advance(_encoding.GetBytes(result).Length);
return result;
}

public override decimal ReadDecimal()
{
Advance(sizeof(decimal));
return base.ReadDecimal();
}

public override double ReadDouble()
{
Advance(sizeof(double));
return base.ReadDouble();
}

public override short ReadInt16()
{
Advance(sizeof(short));
return base.ReadInt16();
}

public override int ReadInt32()
{
Advance(sizeof(int));
return base.ReadInt32();
}

public override long ReadInt64()
{
Advance(sizeof(long));
return base.ReadInt64();
}

public override sbyte ReadSByte()
{
Advance(sizeof(sbyte));
return base.ReadSByte();
}

public override float ReadSingle()
{
Advance(sizeof(float));
return base.ReadSingle();
}

public override string ReadString()
{
var result = base.ReadString();
Advance(_encoding.GetBytes(result).Length);
return result;
}

public override ushort ReadUInt16()
{
Advance(sizeof(ushort));
return base.ReadUInt16();
}

public override uint ReadUInt32()
{
Advance(sizeof(uint));
return base.ReadUInt32();
}

public override ulong ReadUInt64()
{
Advance(sizeof(ulong));
return base.ReadUInt64();
}

public override unsafe Half ReadHalf()
{
Advance(sizeof(Half));
return base.ReadHalf();
}

public override int Read(Span<byte> buffer)
{
var result = base.Read(buffer);
Advance(result);
return result;
}

public override int Read(Span<char> buffer)
{
var result = base.Read(buffer);
Advance(_encoding.GetBytes(buffer[..result].ToArray()).Length);
return result;
}

private void Advance(long offset) => Position += offset;
}
}
}
Loading