Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions include/mcp/filter/http_routing_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,11 @@ class HttpRoutingFilter : public HttpCodecFilter::MessageCallbacks {
// Default handler for unmatched requests
HandlerFunc default_handler_;

// Following production pattern - completely stateless filter
// No request state stored - all decisions made immediately
// State for POST requests that need body
bool pending_post_request_ = false;
RequestContext pending_context_;
HandlerFunc pending_handler_;
std::string accumulated_body_;
};

/**
Expand Down
59 changes: 59 additions & 0 deletions include/mcp/filter/http_sse_filter_chain_factory.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <memory>
#include <string>

Expand All @@ -23,6 +24,13 @@ class MetricsFilter;
namespace mcp {
namespace filter {

/**
* Callback type for registering custom HTTP routes
* Called with the HttpRoutingFilter when setting up the filter chain.
* Use this to register custom endpoints (e.g., OAuth discovery, health checks).
*/
using HttpRouteRegistrationCallback = std::function<void(HttpRoutingFilter*)>;

/**
* MCP HTTP+SSE Filter Chain Factory
*
Expand Down Expand Up @@ -104,6 +112,48 @@ class HttpSseFilterChainFactory : public network::FilterChainFactory {
*/
void enableMetrics(bool enable = true) { enable_metrics_ = enable; }

/**
* Add a filter factory that runs before protocol filters
* Filter factories are invoked in order during chain creation.
* The created filters process data before HTTP/SSE/JSON-RPC protocol filters.
* Useful for authentication, logging, or other cross-cutting concerns.
*
* This follows the existing FilterFactoryCb pattern used by
* FilterChainFactoryImpl and createNetworkFilterChain().
*
* @param factory Factory callback that creates a filter instance
*/
void addFilterFactory(network::FilterFactoryCb factory) {
filter_factories_.push_back(std::move(factory));
}

/**
* Get the list of filter factories
* @return Vector of filter factories
*/
const std::vector<network::FilterFactoryCb>& getFilterFactories() const {
return filter_factories_;
}

/**
* Set callback for registering custom HTTP routes
* The callback will be invoked when the filter chain is created,
* allowing registration of custom endpoints like OAuth discovery.
*
* @param callback Function to call with the HttpRoutingFilter
*/
void setRouteRegistrationCallback(HttpRouteRegistrationCallback callback) {
route_registration_callback_ = std::move(callback);
}

/**
* Get the route registration callback
* @return The callback, or nullptr if not set
*/
const HttpRouteRegistrationCallback& getRouteRegistrationCallback() const {
return route_registration_callback_;
}

/**
* Send a response through the connection's filter chain
* Following production pattern: connection context flows through
Expand All @@ -124,6 +174,15 @@ class HttpSseFilterChainFactory : public network::FilterChainFactory {

// Store filters for lifetime management
mutable std::vector<network::FilterSharedPtr> filters_;

// Filter factories added by user (authentication, logging, etc.)
// These are invoked during chain creation to add filters before protocol
// filters Following the existing FilterFactoryCb pattern from
// FilterChainFactoryImpl
std::vector<network::FilterFactoryCb> filter_factories_;

// Callback for registering custom HTTP routes
HttpRouteRegistrationCallback route_registration_callback_;
};

} // namespace filter
Expand Down
16 changes: 16 additions & 0 deletions include/mcp/server/mcp_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "mcp/event/event_loop.h"
#include "mcp/filter/filter_chain_callbacks.h"
#include "mcp/filter/filter_chain_event_hub.h"
#include "mcp/filter/http_sse_filter_chain_factory.h"
#include "mcp/filter/metrics_filter.h"
#include "mcp/json/json_bridge.h"
#include "mcp/logging/log_macros.h"
Expand Down Expand Up @@ -112,6 +113,21 @@ struct McpServerConfig : public application::ApplicationBase::Config {
// If provided, uses ConfigurableFilterChainFactory instead of hardcoded
// factories
optional<json::JsonValue> filter_chain_config;

// Filter factories for HTTP-level processing (optional)
// These factories are invoked during chain creation to add filters
// that run before protocol filters (HTTP/SSE/JSON-RPC).
// Useful for authentication, logging, or other cross-cutting concerns.
// This follows the existing FilterFactoryCb pattern used throughout
// gopher-mcp. Example: Add an OAuth auth filter factory to validate tokens
// before processing
std::vector<network::FilterFactoryCb> filter_factories;

// Callback for registering custom HTTP routes (optional)
// Called when filter chain is created, allowing registration of custom
// endpoints like OAuth discovery (/.well-known/oauth-protected-resource).
// Example: registerOAuthEndpoints(router, config);
filter::HttpRouteRegistrationCallback route_registration_callback;
};

/**
Expand Down
79 changes: 64 additions & 15 deletions src/filter/http_routing_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,20 @@ void HttpRoutingFilter::onHeaders(

// Server mode: route incoming requests
std::string method = extractMethod(headers);
std::string path = extractPath(headers);
std::string path =
extractPath(headers); // Path without query string for routing

// Get full URL including query string for handler context
std::string full_url = path;
auto url_it = headers.find("url");
if (url_it != headers.end()) {
full_url = url_it->second;
} else {
auto path_it = headers.find(":path");
if (path_it != headers.end()) {
full_url = path_it->second;
}
}

GOPHER_LOG_DEBUG("HttpRoutingFilter: method={} path={}", method, path);

Expand All @@ -74,14 +87,24 @@ void HttpRoutingFilter::onHeaders(
auto handler_it = handlers_.find(key);

if (handler_it != handlers_.end()) {
// We have a handler - execute it immediately with available info
// We have a handler
RequestContext ctx;
ctx.method = method;
ctx.path = path;
ctx.path = full_url; // Full URL with query string for handler
ctx.headers = headers;
ctx.keep_alive = keep_alive;
// Note: body not available yet in onHeaders

// For POST/PUT requests, defer handler until we have the body
if (method == "POST" || method == "PUT" || method == "PATCH") {
pending_post_request_ = true;
pending_context_ = ctx;
pending_handler_ = handler_it->second;
accumulated_body_.clear();
return; // Wait for body
}

// For GET/OPTIONS etc, execute immediately
Response resp = handler_it->second(ctx);
if (resp.status_code != 0) {
// Handler wants to handle this - send response immediately
Expand All @@ -98,8 +121,13 @@ void HttpRoutingFilter::onHeaders(
}

void HttpRoutingFilter::onBody(const std::string& data, bool end_stream) {
// Stateless - always pass through
// If we handled the request in onHeaders, this won't be called
// If we're accumulating body for a POST handler, buffer it
if (pending_post_request_) {
accumulated_body_ += data;
return; // Don't forward - we'll handle in onMessageComplete
}

// Otherwise pass through
if (next_callbacks_) {
next_callbacks_->onBody(data, end_stream);
}
Expand All @@ -108,8 +136,20 @@ void HttpRoutingFilter::onBody(const std::string& data, bool end_stream) {
void HttpRoutingFilter::onMessageComplete() {
GOPHER_LOG_DEBUG("HttpRoutingFilter::onMessageComplete called");

// Stateless - always pass through
// If we handled the request in onHeaders, this won't be called
// If we have a pending POST request, now we have the complete body
if (pending_post_request_) {
pending_context_.body = accumulated_body_;
Response resp = pending_handler_(pending_context_);
if (resp.status_code != 0) {
sendResponse(resp);
}
// Reset state
pending_post_request_ = false;
accumulated_body_.clear();
return;
}

// Otherwise pass through
if (next_callbacks_) {
next_callbacks_->onMessageComplete();
}
Expand Down Expand Up @@ -211,20 +251,29 @@ std::string HttpRoutingFilter::extractMethod(

std::string HttpRoutingFilter::extractPath(
const std::map<std::string, std::string>& headers) {
std::string full_path;

// Check for :path pseudo-header (HTTP/2 style)
auto it = headers.find(":path");
if (it != headers.end()) {
return it->second;
full_path = it->second;
} else {
// For HTTP/1.1, the codec stores the URL in a "url" header
it = headers.find("url");
if (it != headers.end()) {
full_path = it->second;
} else {
// Default to root if not found
return "/";
}
}

// For HTTP/1.1, the codec stores the URL in a "url" header
it = headers.find("url");
if (it != headers.end()) {
return it->second;
// Strip query string for routing purposes
size_t query_pos = full_path.find('?');
if (query_pos != std::string::npos) {
return full_path.substr(0, query_pos);
}

// Default to root if not found
return "/";
return full_path;
}

// Factory methods
Expand Down
50 changes: 39 additions & 11 deletions src/filter/http_sse_filter_chain_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,21 @@ class HttpSseJsonRpcProtocolFilter
friend void HttpSseFilterChainFactory::sendHttpResponse(
const jsonrpc::Response&, network::Connection&);

HttpSseJsonRpcProtocolFilter(event::Dispatcher& dispatcher,
McpProtocolCallbacks& mcp_callbacks,
bool is_server,
const std::string& http_path = "/rpc",
const std::string& http_host = "localhost",
bool use_sse = true)
HttpSseJsonRpcProtocolFilter(
event::Dispatcher& dispatcher,
McpProtocolCallbacks& mcp_callbacks,
bool is_server,
const std::string& http_path = "/rpc",
const std::string& http_host = "localhost",
bool use_sse = true,
const HttpRouteRegistrationCallback& route_callback = nullptr)
: dispatcher_(dispatcher),
mcp_callbacks_(mcp_callbacks),
is_server_(is_server),
http_path_(http_path),
http_host_(http_host),
use_sse_(use_sse) {
use_sse_(use_sse),
route_registration_callback_(route_callback) {
// Following production pattern: all operations for this filter
// happen in the single dispatcher thread
// Create routing filter first (it will receive HTTP callbacks)
Expand Down Expand Up @@ -887,6 +890,12 @@ class HttpSseJsonRpcProtocolFilter
resp.status_code = 0;
return resp;
});

// Call custom route registration callback if provided
// This allows users to register additional endpoints like OAuth discovery
if (route_registration_callback_) {
route_registration_callback_(routing_filter_.get());
}
}

event::Dispatcher& dispatcher_;
Expand Down Expand Up @@ -931,6 +940,9 @@ class HttpSseJsonRpcProtocolFilter
// Buffered data
OwnedBuffer pending_json_data_;
OwnedBuffer pending_sse_data_; // For accumulating SSE event stream data

// Custom route registration callback
HttpRouteRegistrationCallback route_registration_callback_;
};

// RequestStream method implementation (after HttpSseJsonRpcProtocolFilter
Expand Down Expand Up @@ -964,9 +976,23 @@ void HttpSseFilterChainFactory::sendHttpResponse(
bool HttpSseFilterChainFactory::createFilterChain(
network::FilterManager& filter_manager) const {
// Following production pattern: create filters in order
// 1. HTTP Routing Filter (handles arbitrary HTTP endpoints)
// 2. Combined Protocol Filter (HTTP/SSE/JSON-RPC)
// 3. Metrics Filter (collects statistics)
// 1. Pre-filters (authentication, logging, etc.) - added by user
// 2. Metrics Filter (collects statistics)
// 3. Combined Protocol Filter (HTTP/SSE/JSON-RPC)

// Invoke user-provided filter factories first (e.g., auth filters)
// These filters run before protocol filters and can intercept/reject requests
// Following the existing FilterFactoryCb pattern from FilterChainFactoryImpl
for (const auto& factory : filter_factories_) {
if (factory) {
auto filter = factory();
if (filter) {
filter_manager.addReadFilter(filter);
filter_manager.addWriteFilter(filter);
filters_.push_back(filter);
}
}
}

// Create metrics filter if enabled
if (enable_metrics_) {
Expand Down Expand Up @@ -1000,9 +1026,11 @@ bool HttpSseFilterChainFactory::createFilterChain(
// No separate routing filter needed

// Create the combined protocol filter
// Pass the route registration callback so custom HTTP routes can be
// registered
auto combined_filter = std::make_shared<HttpSseJsonRpcProtocolFilter>(
dispatcher_, message_callbacks_, is_server_, http_path_, http_host_,
use_sse_);
use_sse_, route_registration_callback_);

// Add as both read and write filter
filter_manager.addReadFilter(combined_filter);
Expand Down
Loading
Loading