Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ public void KeyMatchTest()
var user = "SYSDBA";
var password = "masterkey";
var client = new Srp256Client();
var salt = client.GetSalt();
var salt = Srp256Client.GetSalt();
var serverKeyPair = client.ServerSeed(user, password, salt);
var serverSessionKey = client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
var serverSessionKey = Srp256Client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
client.ClientProof(user, password, salt, serverKeyPair.Item1);
Assert.AreEqual(serverSessionKey.ToString(), client.SessionKey.ToString());
}
Expand Down
4 changes: 2 additions & 2 deletions src/FirebirdSql.Data.FirebirdClient.Tests/SrpClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ public void KeyMatchTest()
var user = "SYSDBA";
var password = "masterkey";
var client = new SrpClient();
var salt = client.GetSalt();
var salt = SrpClient.GetSalt();
var serverKeyPair = client.ServerSeed(user, password, salt);
var serverSessionKey = client.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
var serverSessionKey = SrpClient.GetServerSessionKey(user, password, salt, client.PublicKey, serverKeyPair.Item1, serverKeyPair.Item2);
client.ClientProof(user, password, salt, serverKeyPair.Item1);
Assert.AreEqual(serverSessionKey.ToString(), client.SessionKey.ToString());
}
Expand Down
190 changes: 153 additions & 37 deletions src/FirebirdSql.Data.FirebirdClient/Client/Managed/AuthBlock.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ sealed class AuthBlock

public bool WireCryptInitialized { get; private set; }

private const int STACKALLOC_LIMIT = 512;

public AuthBlock(GdsConnection connection, string user, string password, WireCryptOption wireCrypt)
{
_srp256 = new Srp256Client();
Expand All @@ -68,60 +70,160 @@ public byte[] UserIdentificationData()
{
using (var result = new MemoryStream(256))
{
var userString = Environment.GetEnvironmentVariable("USERNAME") ?? Environment.GetEnvironmentVariable("USER") ?? string.Empty;
var user = Encoding.UTF8.GetBytes(userString);
result.WriteByte(IscCodes.CNCT_user);
result.WriteByte((byte)user.Length);
result.Write(user, 0, user.Length);
{
var userString = Environment.GetEnvironmentVariable("USERNAME") ?? Environment.GetEnvironmentVariable("USER") ?? string.Empty;
var slen = Encoding.UTF8.GetByteCount(userString);
byte[] rented = null;
Span<byte> user = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
int real_len = Encoding.UTF8.GetBytes(userString, user);
result.WriteByte(IscCodes.CNCT_user);
result.WriteByte((byte)real_len);
result.Write(user);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}

var host = Encoding.UTF8.GetBytes(Dns.GetHostName());
result.WriteByte(IscCodes.CNCT_host);
result.WriteByte((byte)host.Length);
result.Write(host, 0, host.Length);
{
var hostName = Dns.GetHostName();
var slen = Encoding.UTF8.GetByteCount(hostName);
byte[] rented = null;
Span<byte> host = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
int real_len = Encoding.UTF8.GetBytes(hostName, host);
result.WriteByte(IscCodes.CNCT_host);
result.WriteByte((byte)real_len);
result.Write(host);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}

result.WriteByte(IscCodes.CNCT_user_verification);
result.WriteByte(0);

if (!string.IsNullOrEmpty(User))
{
var login = Encoding.UTF8.GetBytes(User);
result.WriteByte(IscCodes.CNCT_login);
result.WriteByte((byte)login.Length);
result.Write(login, 0, login.Length);

var pluginNameBytes = Encoding.UTF8.GetBytes(_srp256.Name);
result.WriteByte(IscCodes.CNCT_plugin_name);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
var specificData = Encoding.UTF8.GetBytes(_srp256.PublicKeyHex);
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, specificData);

var plugins = string.Join(",", new[] { _srp256.Name, _srp.Name });
var pluginsBytes = Encoding.UTF8.GetBytes(plugins);
result.WriteByte(IscCodes.CNCT_plugin_list);
result.WriteByte((byte)pluginsBytes.Length);
result.Write(pluginsBytes, 0, pluginsBytes.Length);
{
var slen = Encoding.UTF8.GetByteCount(User);
byte[] rented = null;
Span<byte> bytes = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
int real_len = Encoding.UTF8.GetBytes(User, bytes);
result.WriteByte(IscCodes.CNCT_login);
result.WriteByte((byte)real_len);
result.Write(bytes);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}
{
var slen = Encoding.UTF8.GetByteCount(_srp256.Name);
byte[] rented = null;
Span<byte> bytes = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
int real_len = Encoding.UTF8.GetBytes(_srp256.Name, bytes);
result.WriteByte(IscCodes.CNCT_plugin_name);
result.WriteByte((byte)real_len);
result.Write(bytes[..real_len]);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}
{
var slen = Encoding.UTF8.GetByteCount(_srp256.PublicKeyHex);
byte[] rented = null;
Span<byte> specificData = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
Encoding.UTF8.GetBytes(_srp256.PublicKeyHex.AsSpan(), specificData);
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, specificData);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}
{
var slen1 = Encoding.UTF8.GetByteCount(_srp256.Name);
byte[] rented1 = null;
Span<byte> bytes1 = slen1 > STACKALLOC_LIMIT
? (rented1 = System.Buffers.ArrayPool<byte>.Shared.Rent(slen1)).AsSpan(0, slen1)
: stackalloc byte[slen1];
Span<byte> bytes2 = stackalloc byte[1];
var slen3 = Encoding.UTF8.GetByteCount(_srp.Name);
byte[] rented3 = null;
Span<byte> bytes3 = slen3 > STACKALLOC_LIMIT
? (rented3 = System.Buffers.ArrayPool<byte>.Shared.Rent(slen3)).AsSpan(0, slen3)
: stackalloc byte[slen3];
int l1 = Encoding.UTF8.GetBytes(_srp256.Name.AsSpan(), bytes1);
int l2 = Encoding.UTF8.GetBytes(",".AsSpan(), bytes2);
int l3 = Encoding.UTF8.GetBytes(_srp.Name.AsSpan(), bytes3);
result.WriteByte(IscCodes.CNCT_plugin_list);
result.WriteByte((byte)(l1+l2+l3));
result.Write(bytes1);
result.Write(bytes2);
result.Write(bytes3);
if (rented1 != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented1, clearArray: true);
}
if (rented3 != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented3, clearArray: true);
}
}

result.WriteByte(IscCodes.CNCT_client_crypt);
result.WriteByte(4);
result.Write(TypeEncoder.EncodeInt32(WireCryptOptionValue(WireCrypt)), 0, 4);
{
result.WriteByte(IscCodes.CNCT_client_crypt);
result.WriteByte(4);
Span<byte> bytes = stackalloc byte[4];
if (!BitConverter.TryWriteBytes(bytes, IPAddress.NetworkToHostOrder(WireCryptOptionValue(WireCrypt))))
{
throw new InvalidOperationException("Failed to write wire crypt option bytes.");
}
result.Write(bytes);
}
}
else
{
var pluginNameBytes = Encoding.UTF8.GetBytes(_sspi.Name);
var slen = Encoding.UTF8.GetByteCount(_sspi.Name);
byte[] rented = null;
Span<byte> pluginNameBytes = slen > STACKALLOC_LIMIT
? (rented = System.Buffers.ArrayPool<byte>.Shared.Rent(slen)).AsSpan(0, slen)
: stackalloc byte[slen];
int pluginNameLen = Encoding.UTF8.GetBytes(_sspi.Name.AsSpan(), pluginNameBytes);
result.WriteByte(IscCodes.CNCT_plugin_name);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
result.WriteByte((byte)pluginNameLen);
result.Write(pluginNameBytes[..pluginNameLen]);

var specificData = _sspi.InitializeClientSecurity();
WriteMultiPartHelper(result, IscCodes.CNCT_specific_data, specificData);

result.WriteByte(IscCodes.CNCT_plugin_list);
result.WriteByte((byte)pluginNameBytes.Length);
result.Write(pluginNameBytes, 0, pluginNameBytes.Length);
result.WriteByte((byte)pluginNameLen);
result.Write(pluginNameBytes[..pluginNameLen]);

result.WriteByte(IscCodes.CNCT_client_crypt);
result.WriteByte(4);
result.Write(TypeEncoder.EncodeInt32(IscCodes.WIRE_CRYPT_DISABLED), 0, 4);
Span<byte> wireCryptBytes = stackalloc byte[4];
if (!BitConverter.TryWriteBytes(wireCryptBytes, IPAddress.NetworkToHostOrder(IscCodes.WIRE_CRYPT_DISABLED)))
{
throw new InvalidOperationException("Failed to write wire crypt disabled bytes.");
}
result.Write(wireCryptBytes);
if (rented != null)
{
System.Buffers.ArrayPool<byte>.Shared.Return(rented, clearArray: true);
}
}

return result.ToArray();
Expand Down Expand Up @@ -309,7 +411,21 @@ void ReleaseAuth()
_sspi = null;
}

static void WriteMultiPartHelper(Stream stream, byte code, byte[] data)
static void WriteMultiPartHelper(MemoryStream stream, byte code, byte[] data)
{
const int MaxLength = 255 - 1;
var part = 0;
for (var i = 0; i < data.Length; i += MaxLength) {
stream.WriteByte(code);
var length = Math.Min(data.Length - i, MaxLength);
stream.WriteByte((byte)(length + 1));
stream.WriteByte((byte)part);
stream.Write(data, i, length);
part++;
}
}

static void WriteMultiPartHelper(MemoryStream stream, byte code, ReadOnlySpan<byte> data)
{
const int MaxLength = 255 - 1;
var part = 0;
Expand All @@ -319,7 +435,7 @@ static void WriteMultiPartHelper(Stream stream, byte code, byte[] data)
var length = Math.Min(data.Length - i, MaxLength);
stream.WriteByte((byte)(length + 1));
stream.WriteByte((byte)part);
stream.Write(data, i, length);
stream.Write(data[i..(i+length)]);
part++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

//$Authors = Jiri Cincura (jiri@cincura.net)

using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;

Expand All @@ -36,12 +38,31 @@ public int Read(byte[] buffer, int offset, int count)
{
return _stream.Read(buffer, offset, count);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int Read(Span<byte> buffer, int offset, int count)
{
return _stream.Read(buffer[offset..(offset+count)]);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return new ValueTask<int>(_stream.ReadAsync(buffer, offset, count, cancellationToken));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask<int> ReadAsync(Memory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return _stream.ReadAsync(buffer.Slice(offset, count), cancellationToken);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Write(ReadOnlySpan<byte> buffer)
{
_stream.Write(buffer);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Write(byte[] buffer, int offset, int count)
{
Expand All @@ -53,6 +74,12 @@ public ValueTask WriteAsync(byte[] buffer, int offset, int count, CancellationTo
return new ValueTask(_stream.WriteAsync(buffer, offset, count, cancellationToken));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, int offset, int count, CancellationToken cancellationToken = default)
{
return _stream.WriteAsync(buffer.Slice(offset, count), cancellationToken);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Flush()
{
Expand Down
Loading