75 changes: 74 additions & 1 deletion src/windows/common/DnsResolver.cpp
@@ -8,6 +8,10 @@ using wsl::core::networking::DnsResolver;

static constexpr auto c_dnsModuleName = L"dnsapi.dll";

// Size of the 2-byte length prefix used in DNS-over-TCP (RFC 1035, section 4.2.2).
// This matches c_byteCountTcpRequestLength on the Linux side (DnsServer.h).
static constexpr size_t c_byteCountTcpRequestLength = 2;

std::optional<LxssDynamicFunction<decltype(DnsQueryRaw)>> DnsResolver::s_dnsQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsCancelQueryRaw)>> DnsResolver::s_dnsCancelQueryRaw;
std::optional<LxssDynamicFunction<decltype(DnsQueryRawResultFree)>> DnsResolver::s_dnsQueryRawResultFree;
@@ -223,6 +227,26 @@ try
auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context));
const auto localContext = it->second.get();

// Store the DNS transaction ID for constructing SERVFAIL responses.
// For UDP, the DNS header starts at offset 0; for TCP, the first 2 bytes are a length prefix.
size_t transactionIdOffset = 0;
if (dnsClientIdentifier.Protocol == IPPROTO_TCP)
{
transactionIdOffset = c_byteCountTcpRequestLength;
}

if (dnsBuffer.size() >= transactionIdOffset + sizeof(localContext->m_dnsTransactionId))
{
memcpy(&localContext->m_dnsTransactionId, dnsBuffer.data() + transactionIdOffset, sizeof(localContext->m_dnsTransactionId));
}
else
{
WSL_LOG(
Comment on lines +238 to +244
Copilot AI Mar 31, 2026

If the DNS buffer is too small to contain the transaction ID (e.g., <2 bytes for UDP or <4 bytes for TCP), m_dnsTransactionId remains at its default value and later failure paths can send a SERVFAIL with an incorrect transaction ID. That response won’t correlate to the outstanding query and can reintroduce client timeouts (or, worse, collide with a legitimate query that happens to use ID 0). Consider treating this as invalid input: validate the minimum header size up-front and return early (and/or track whether the transaction ID was successfully extracted and only send SERVFAIL when it was).

Member

This seems legitimate.
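
One way to act on this comment is to make the extraction report failure explicitly, so later paths can tell "no ID captured" apart from "ID is 0". The sketch below is illustrative only: `TryExtractTransactionId` and `kTcpLengthPrefixSize` are hypothetical names that do not appear in the PR, and it uses a plain `std::vector<std::uint8_t>` in place of the `gsl` buffer type. As in the PR, the ID is copied as-is and therefore stays in network byte order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <vector>

// Hypothetical constant mirroring c_byteCountTcpRequestLength (RFC 1035, section 4.2.2).
constexpr std::size_t kTcpLengthPrefixSize = 2;

// Hypothetical helper, not part of the PR. Returns the DNS transaction ID exactly as it
// appears on the wire (network byte order), or std::nullopt when the buffer is too short
// to contain it. A caller could then skip the SERVFAIL path when no ID was captured.
std::optional<std::uint16_t> TryExtractTransactionId(const std::vector<std::uint8_t>& buffer, bool isTcp)
{
    // For TCP the DNS header starts after the 2-byte length prefix; for UDP it starts at 0.
    const std::size_t offset = isTcp ? kTcpLengthPrefixSize : 0;
    if (buffer.size() < offset + sizeof(std::uint16_t))
    {
        return std::nullopt;
    }

    std::uint16_t id = 0;
    std::memcpy(&id, buffer.data() + offset, sizeof(id));
    return id;
}
```

With something like this, `m_dnsTransactionId` could become a `std::optional<uint16_t>`, and the failure paths would send SERVFAIL only when it holds a value. Whether to reject the request outright instead is a separate decision for the PR author.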

"DnsResolver::ProcessDnsRequest - DNS buffer too small to extract transaction ID",
TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"),
TraceLoggingValue(transactionIdOffset, "transactionIdOffset"));
}

auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); });

// Fill DNS request structure
@@ -264,6 +288,10 @@ try
TraceLoggingValue(requestId, "requestId"),
TraceLoggingValue(result, "result"),
TraceLoggingValue("DnsQueryRaw", "executionStep"));

// Send SERVFAIL back to Linux so the DNS client gets an immediate error
// instead of waiting for a timeout.
SendServfailResponse(localContext->m_dnsTransactionId, localContext->m_dnsClientIdentifier);
return;
}

@@ -322,7 +350,9 @@ try

if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr)
{
- // Copy DNS response buffer
+ // Copy DNS response buffer.
// Note: For TCP, queryRawResponse includes the 2-byte length prefix per the DnsQueryRaw API contract,
// which matches what HandleTcpDnsResponse on the Linux side expects when writing to the TCP socket.
std::vector<gsl::byte> dnsResponse(queryResults->queryRawResponseSize);
CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize);

@@ -337,6 +367,12 @@ try
m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier);
});
}
else if (!m_stopped)
{
// The Windows DNS API failed to resolve the request. Send a SERVFAIL response to the Linux DNS client
// so it gets an immediate error instead of waiting for a timeout (which can take 5-10 seconds).
SendServfailResponse(queryContext->m_dnsTransactionId, queryContext->m_dnsClientIdentifier);
}
Comment on lines +370 to +375
Copilot AI Mar 29, 2026

The new behavior adds a SERVFAIL fallback path for when DnsQueryRaw fails or completes without a response. There are existing DNS tunneling integration tests (test/windows/NetworkTests.cpp), but none appear to validate the failure path. Adding a test that forces a DnsQueryRaw failure and asserts that the Linux client receives an immediate SERVFAIL (for both UDP and TCP framing) would help prevent regressions and ensure the timeout/leak fix stays covered.
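
As a starting point for such a test, the check on the received bytes could look roughly like the sketch below. `IsServfailResponse` is a hypothetical helper, not something that exists in NetworkTests.cpp, and forcing DnsQueryRaw to fail is left out because it depends on the test infrastructure. The helper only inspects the wire format described in the PR: an optional 2-byte TCP length prefix, then a 12-byte DNS header with QR=1 and RCODE=2.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical test helper. Returns true when `response` carries a DNS header that is a
// SERVFAIL reply for the transaction ID given as its two wire bytes (idHigh, idLow).
// When isTcp is set, the 2-byte big-endian length prefix is validated as well.
bool IsServfailResponse(const std::vector<std::uint8_t>& response, bool isTcp, std::uint8_t idHigh, std::uint8_t idLow)
{
    constexpr std::size_t headerSize = 12;
    const std::size_t prefixSize = isTcp ? 2 : 0;
    if (response.size() < prefixSize + headerSize)
    {
        return false;
    }

    if (isTcp)
    {
        // The prefix must equal the length of the DNS message that follows it.
        const std::size_t framedLength = (static_cast<std::size_t>(response[0]) << 8) | response[1];
        if (framedLength != response.size() - prefixSize)
        {
            return false;
        }
    }

    const std::uint8_t* header = response.data() + prefixSize;
    const bool idMatches = header[0] == idHigh && header[1] == idLow;
    const bool isResponse = (header[2] & 0x80) != 0; // QR bit
    const bool isServfail = (header[3] & 0x0F) == 2; // RCODE field
    return idMatches && isResponse && isServfail;
}
```

A test would send a query with a known transaction ID over each protocol, force the failure, and assert that this check passes on the reply well before the client's retransmit timeout.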

// Stop tracking this DNS request and delete the request context
WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1);
@@ -349,6 +385,43 @@ try
}
CATCH_LOG()

void DnsResolver::SendServfailResponse(uint16_t transactionId, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier)
{
// Build a minimal DNS SERVFAIL response per RFC 1035 section 4.1.1.
// This allows the Linux DNS client to immediately learn the query failed,
// rather than waiting for a retransmit timeout (typically 5-10 seconds).
//
// For TCP, the response must include a 2-byte length prefix (RFC 1035 section 4.2.2)
// to maintain TCP stream framing, matching the format used by the success path.
constexpr size_t c_dnsHeaderSize = 12;
const bool isTcp = (dnsClientIdentifier.Protocol == IPPROTO_TCP);
const size_t prefixSize = isTcp ? c_byteCountTcpRequestLength : 0;
const size_t totalSize = prefixSize + c_dnsHeaderSize;

std::vector<gsl::byte> servfail(totalSize, gsl::byte{0});

if (isTcp)
{
// Write the 2-byte length prefix in network byte order
uint16_t dnsLength = htons(static_cast<uint16_t>(c_dnsHeaderSize));
memcpy(servfail.data(), &dnsLength, sizeof(dnsLength));
}

auto* dnsHeader = servfail.data() + prefixSize;
memcpy(dnsHeader, &transactionId, sizeof(transactionId)); // Transaction ID (network byte order, copied as-is)
dnsHeader[2] = gsl::byte{0x80}; // QR=1 (response), OPCODE=0 (standard query)
dnsHeader[3] = gsl::byte{0x02}; // RA=0, Z=0, RCODE=2 (Server Failure)
Comment on lines +412 to +413
Copilot AI Mar 29, 2026

The SERVFAIL header flags are hard-coded to clear RD (and RA is forced to 0). Per RFC 1035, RD is copied from the query into the response, and some stub resolvers validate this. Consider capturing the RD bit from the original request (alongside the transaction ID) and reflecting it in the SERVFAIL response; also consider setting RA appropriately to match normal resolver behavior.

Suggested change
- dnsHeader[2] = gsl::byte{0x80}; // QR=1 (response), OPCODE=0 (standard query)
- dnsHeader[3] = gsl::byte{0x02}; // RA=0, Z=0, RCODE=2 (Server Failure)
+ // Flags: QR=1 (response), OPCODE=0 (standard query), AA=0, TC=0, RD=1
+ dnsHeader[2] = gsl::byte{0x81};
+ // Flags: RA=1 (recursion available), Z=0, RCODE=2 (Server Failure)
+ dnsHeader[3] = gsl::byte{0x82};

Member

Is this intended?
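
The suggested change hard-codes RD=1, while the comment itself asks for RD to be copied from the query. A sketch of the latter is shown below, assuming the first flags byte of the original request (byte 2 of the DNS header) is captured alongside the transaction ID. `DnsFlagBytes` and `MakeServfailFlags` are hypothetical names, and whether RA should be set is left as a parameter because the PR discussion does not settle it. Per RFC 1035 section 4.1.1, OPCODE and RD are copied from the query into the response.

```cpp
#include <cstdint>

// Hypothetical pair holding bytes 2 and 3 of a DNS header.
struct DnsFlagBytes
{
    std::uint8_t high; // QR | OPCODE | AA | TC | RD
    std::uint8_t low;  // RA | Z | RCODE
};

// Hypothetical helper: derives the SERVFAIL flag bytes from the query's first flags byte.
// OPCODE and RD are reflected from the query, AA and TC are cleared, and RCODE is 2.
constexpr DnsFlagBytes MakeServfailFlags(std::uint8_t queryFlagsHigh, bool recursionAvailable)
{
    constexpr std::uint8_t qrBit = 0x80;
    constexpr std::uint8_t opcodeMask = 0x78;
    constexpr std::uint8_t rdBit = 0x01;
    constexpr std::uint8_t raBit = 0x80;
    constexpr std::uint8_t rcodeServfail = 0x02;

    const std::uint8_t high = static_cast<std::uint8_t>(qrBit | (queryFlagsHigh & (opcodeMask | rdBit)));
    const std::uint8_t low = static_cast<std::uint8_t>((recursionAvailable ? raBit : 0) | rcodeServfail);
    return {high, low};
}
```

For a typical stub-resolver query (flags byte 0x01, RD set) with `recursionAvailable` true, this yields the same 0x81 and 0x82 as the suggested change. For a query with RD clear and `recursionAvailable` false, it yields the PR's current 0x80 and 0x02.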


WSL_LOG(
"DnsResolver::SendServfailResponse",
TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"),
TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id"));

m_dnsResponseQueue.submit([this, servfail = std::move(servfail), dnsClientIdentifier]() mutable {
m_dnsChannel.SendDnsMessage(gsl::make_span(servfail), dnsClientIdentifier);
});
}

void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept
try
{
11 changes: 10 additions & 1 deletion src/windows/common/DnsResolver.h
@@ -43,6 +43,11 @@ class DnsResolver
// Unique query id.
uint32_t m_id{};

// Transaction ID field from the original DNS message header (stored in network byte order).
// Note: For DNS-over-TCP the tunneled buffer includes a 2-byte length prefix before the
// DNS header; this value is taken from the DNS header itself, not assumed to be at offset 0.
uint16_t m_dnsTransactionId{};

// Callback to the parent object to notify about the DNS query completion.
std::function<void(DnsQueryContext*, DNS_QUERY_RAW_RESULT*)> m_handleQueryCompletion;

@@ -78,6 +83,10 @@ class DnsResolver
// queryResults - structure containing result of the DNS request.
void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept;

// Build and send a minimal DNS SERVFAIL response (RFC 1035, RCODE=2) back to the Linux DNS client.
// This is used when the Windows DNS API fails, to prevent the Linux client from waiting until timeout.
Copilot AI Mar 29, 2026

SendServfailResponse takes a raw uint16_t transactionId but the implementation assumes it is already in network byte order (it memcpy's the value directly). Please document this explicitly on the SendServfailResponse declaration (or normalize to host order in the context and htons() when writing) to prevent future callers from accidentally passing host-order IDs.

Suggested change
- // This is used when the Windows DNS API fails, to prevent the Linux client from waiting until timeout.
+ // This is used when the Windows DNS API fails, to prevent the Linux client from waiting until timeout.
+ //
+ // Arguments:
+ // transactionId - DNS transaction ID to place in the response header, in network byte order.
+ // Callers that have a host-order transaction ID must convert it (e.g. via htons)
+ // before passing it to this method.
+ // dnsClientIdentifier - struct containing protocol (TCP/UDP) and unique id of the Linux DNS client
+ // to which the SERVFAIL response will be sent.
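
The alternative mentioned in the comment (normalize to host order in the context, convert when writing) might look like the sketch below. `ReadTransactionId` and `WriteTransactionId` are hypothetical names. The sketch uses explicit shifts rather than `ntohs`/`htons` so that it does not depend on platform socket headers. That is a substitution made for this example, not what the PR does.

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper: converts the two wire bytes of the transaction ID (big-endian)
// into a host-order value suitable for storing in the query context.
constexpr std::uint16_t ReadTransactionId(std::uint8_t high, std::uint8_t low)
{
    return static_cast<std::uint16_t>((high << 8) | low);
}

// Hypothetical helper: converts a host-order transaction ID back into the two bytes
// that go at the start of the DNS header (network byte order).
constexpr std::array<std::uint8_t, 2> WriteTransactionId(std::uint16_t id)
{
    return {static_cast<std::uint8_t>(id >> 8), static_cast<std::uint8_t>(id & 0xFF)};
}
```

With this shape, `SendServfailResponse` would always receive a host-order ID and do the conversion itself, so callers could not pass the wrong byte order by accident. The cost is one extra conversion on each side compared with the PR's copy-as-is approach.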
void SendServfailResponse(uint16_t transactionId, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier);

void ResolveExternalInterfaceConstraintIndex() noexcept;

// Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled.
@@ -105,7 +114,7 @@ class DnsResolver
_Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0;

// Mapping request id to the request context structure.
- _Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests {};
+ _Guarded_by_(m_dnsLock) std::unordered_map<uint32_t, std::unique_ptr<DnsQueryContext>> m_dnsRequests{};

// Event that is set when all tracked DNS requests have completed.
wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};