Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 97 additions & 8 deletions backend/src/Taskdeck.Application/Services/EgressEnvelopeHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace Taskdeck.Application.Services;
/// </summary>
public sealed class EgressEnvelopeHandler : DelegatingHandler
{
internal const long MaxRedirectReplayContentBytes = 1_048_576;

private readonly IEgressRegistry _egressRegistry;
private readonly ILogger<EgressEnvelopeHandler>? _logger;
private readonly string? _sourceComponent;
Expand Down Expand Up @@ -42,7 +44,9 @@ protected override async Task<HttpResponseMessage> SendAsync(
// against the egress allowlist. If auto-redirect is enabled, the handler only
// sees the final response and the redirect check becomes ineffective.

var response = await base.SendAsync(request, cancellationToken);
var replayContent = await PrepareReplayableContentAsync(request, cancellationToken);
var currentRequest = request;
var response = await base.SendAsync(currentRequest, cancellationToken);

// Manually follow redirects, validating each target against the egress envelope
var redirectCount = 0;
Expand All @@ -52,7 +56,7 @@ protected override async Task<HttpResponseMessage> SendAsync(

var resolvedRedirectUri = redirectUri.IsAbsoluteUri
? redirectUri
: new Uri(request.RequestUri!, redirectUri);
: new Uri(currentRequest.RequestUri!, redirectUri);

var redirectHost = resolvedRedirectUri.Host;

Expand All @@ -67,31 +71,53 @@ protected override async Task<HttpResponseMessage> SendAsync(

_logger?.LogError(
"EgressViolation: redirect to '{Host}' not in egress envelope. OriginalURI={OriginalUri}, RedirectURI={RedirectUri}, Source={Source}",
redirectHost, request.RequestUri, resolvedRedirectUri, _sourceComponent);
redirectHost, currentRequest.RequestUri, resolvedRedirectUri, _sourceComponent);

throw new EgressViolationException(violation);
}

// Follow the redirect: create a new request preserving the method for 307/308
var statusCode = (int)response.StatusCode;
var previousRequest = currentRequest;
var redirectRequest = new HttpRequestMessage
{
RequestUri = resolvedRedirectUri,
Version = request.Version,
Method = statusCode is 307 or 308 ? request.Method : HttpMethod.Get,
Version = previousRequest.Version,
Method = statusCode is 307 or 308 ? previousRequest.Method : HttpMethod.Get,
};

// 307/308 require preserving the original headers and body
// 307/308 require preserving the original body and safe headers
if (statusCode is 307 or 308)
{
redirectRequest.Content = request.Content;
foreach (var header in request.Headers)
if (previousRequest.Content is not null)
{
if (replayContent is null)
{
response.Dispose();
throw new InvalidOperationException(
$"Cannot replay request content across a 307/308 redirect because the content length is unknown or exceeds {MaxRedirectReplayContentBytes} bytes.");
}

redirectRequest.Content = CreateReplayContent(replayContent);
}

var isCrossOrigin = !IsSameOrigin(previousRequest.RequestUri, resolvedRedirectUri);
foreach (var header in previousRequest.Headers)
{
if (string.Equals(header.Key, "Host", StringComparison.OrdinalIgnoreCase))
continue;
if (isCrossOrigin && IsSensitiveRedirectHeader(header.Key))
continue;
redirectRequest.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}
else
{
replayContent = null;
}

response.Dispose();
currentRequest = redirectRequest;
response = await base.SendAsync(redirectRequest, cancellationToken);
}

Expand Down Expand Up @@ -139,6 +165,69 @@ private static bool IsRedirect(HttpResponseMessage response)
var statusCode = (int)response.StatusCode;
return statusCode is >= 300 and < 400;
}

private static async Task<ReplayableContent?> PrepareReplayableContentAsync(
HttpRequestMessage request,
CancellationToken cancellationToken)
{
if (request.Content is null)
return null;

var contentLength = request.Content.Headers.ContentLength;
if (contentLength is null || contentLength > MaxRedirectReplayContentBytes)
return null;

var headers = request.Content.Headers
.Where(header => !string.Equals(header.Key, "Content-Length", StringComparison.OrdinalIgnoreCase))
.Select(header => new KeyValuePair<string, string[]>(header.Key, header.Value.ToArray()))
.ToArray();

var content = await request.Content.ReadAsByteArrayAsync(cancellationToken);
if (content.LongLength > MaxRedirectReplayContentBytes)
{
throw new InvalidOperationException(
$"Request content exceeds the {MaxRedirectReplayContentBytes} byte redirect replay limit.");
}

var originalContent = request.Content;
var replayContent = new ReplayableContent(content, headers);
request.Content = CreateReplayContent(replayContent);
Comment thread
Chris0Jeky marked this conversation as resolved.
originalContent.Dispose();
return replayContent;
}

private static ByteArrayContent CreateReplayContent(ReplayableContent replayContent)
{
var content = new ByteArrayContent(replayContent.Content);
foreach (var header in replayContent.Headers)
{
content.Headers.TryAddWithoutValidation(header.Key, header.Value);
}

return content;
}

private static bool IsSameOrigin(Uri? left, Uri right)
{
if (left is null)
return false;

return string.Equals(left.Scheme, right.Scheme, StringComparison.OrdinalIgnoreCase)
&& string.Equals(left.Host, right.Host, StringComparison.OrdinalIgnoreCase)
&& left.Port == right.Port;
}

private static bool IsSensitiveRedirectHeader(string header)
=> string.Equals(header, "Authorization", StringComparison.OrdinalIgnoreCase)
|| string.Equals(header, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)
|| string.Equals(header, "Cookie", StringComparison.OrdinalIgnoreCase)
|| string.Equals(header, "x-goog-api-key", StringComparison.OrdinalIgnoreCase)
|| string.Equals(header, "x-api-key", StringComparison.OrdinalIgnoreCase)
|| string.Equals(header, "api-key", StringComparison.OrdinalIgnoreCase)
|| header.Contains("token", StringComparison.OrdinalIgnoreCase)
|| header.Contains("secret", StringComparison.OrdinalIgnoreCase);

private sealed record ReplayableContent(byte[] Content, IReadOnlyList<KeyValuePair<string, string[]>> Headers);
}

/// <summary>
Expand Down
Loading
Loading