diff --git a/src/windows/common/DnsResolver.cpp b/src/windows/common/DnsResolver.cpp index 090e9b426..03f80f9a1 100644 --- a/src/windows/common/DnsResolver.cpp +++ b/src/windows/common/DnsResolver.cpp @@ -8,6 +8,10 @@ using wsl::core::networking::DnsResolver; static constexpr auto c_dnsModuleName = L"dnsapi.dll"; +// Size of the 2-byte length prefix used in DNS-over-TCP (RFC 1035, section 4.2.2). +// This matches c_byteCountTcpRequestLength on the Linux side (DnsServer.h). +static constexpr size_t c_byteCountTcpRequestLength = 2; + std::optional> DnsResolver::s_dnsQueryRaw; std::optional> DnsResolver::s_dnsCancelQueryRaw; std::optional> DnsResolver::s_dnsQueryRawResultFree; @@ -223,6 +227,26 @@ try auto [it, _] = m_dnsRequests.emplace(requestId, std::move(context)); const auto localContext = it->second.get(); + // Store the DNS transaction ID for constructing SERVFAIL responses. + // For UDP, the DNS header starts at offset 0; for TCP, the first 2 bytes are a length prefix. + size_t transactionIdOffset = 0; + if (dnsClientIdentifier.Protocol == IPPROTO_TCP) + { + transactionIdOffset = c_byteCountTcpRequestLength; + } + + if (dnsBuffer.size() >= transactionIdOffset + sizeof(localContext->m_dnsTransactionId)) + { + memcpy(&localContext->m_dnsTransactionId, dnsBuffer.data() + transactionIdOffset, sizeof(localContext->m_dnsTransactionId)); + } + else + { + WSL_LOG( + "DnsResolver::ProcessDnsRequest - DNS buffer too small to extract transaction ID", + TraceLoggingValue(dnsBuffer.size(), "DNS buffer size"), + TraceLoggingValue(transactionIdOffset, "transactionIdOffset")); + } + auto removeContextOnError = wil::scope_exit([&] { WI_VERIFY(m_dnsRequests.erase(requestId) == 1); }); // Fill DNS request structure @@ -264,6 +288,10 @@ try TraceLoggingValue(requestId, "requestId"), TraceLoggingValue(result, "result"), TraceLoggingValue("DnsQueryRaw", "executionStep")); + + // Send SERVFAIL back to Linux so the DNS client gets an immediate error + // instead of waiting for a timeout. + SendServfailResponse(localContext->m_dnsTransactionId, localContext->m_dnsClientIdentifier); return; } @@ -322,7 +350,9 @@ try if (!m_stopped && queryResults != nullptr && queryResults->queryRawResponse != nullptr) { - // Copy DNS response buffer + // Copy DNS response buffer. + // Note: For TCP, queryRawResponse includes the 2-byte length prefix per the DnsQueryRaw API contract, + // which matches what HandleTcpDnsResponse on the Linux side expects when writing to the TCP socket. std::vector dnsResponse(queryResults->queryRawResponseSize); CopyMemory(dnsResponse.data(), queryResults->queryRawResponse, queryResults->queryRawResponseSize); @@ -337,6 +367,12 @@ try m_dnsChannel.SendDnsMessage(gsl::make_span(dnsResponse), dnsClientIdentifier); }); } + else if (!m_stopped) + { + // The Windows DNS API failed to resolve the request. Send a SERVFAIL response to the Linux DNS client + // so it gets an immediate error instead of waiting for a timeout (which can take 5-10 seconds). + SendServfailResponse(queryContext->m_dnsTransactionId, queryContext->m_dnsClientIdentifier); + } // Stop tracking this DNS request and delete the request context WI_VERIFY(m_dnsRequests.erase(queryContext->m_id) == 1); @@ -349,6 +385,43 @@ try } CATCH_LOG() +void DnsResolver::SendServfailResponse(uint16_t transactionId, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier) +{ + // Build a minimal DNS SERVFAIL response per RFC 1035 section 4.1.1. + // This allows the Linux DNS client to immediately learn the query failed, + // rather than waiting for a retransmit timeout (typically 5-10 seconds). + // + // For TCP, the response must include a 2-byte length prefix (RFC 1035 section 4.2.2) + // to maintain TCP stream framing, matching the format used by the success path. + constexpr size_t c_dnsHeaderSize = 12; + const bool isTcp = (dnsClientIdentifier.Protocol == IPPROTO_TCP); + const size_t prefixSize = isTcp ? c_byteCountTcpRequestLength : 0; + const size_t totalSize = prefixSize + c_dnsHeaderSize; + + std::vector servfail(totalSize, gsl::byte{0}); + + if (isTcp) + { + // Write the 2-byte length prefix in network byte order + uint16_t dnsLength = htons(static_cast(c_dnsHeaderSize)); + memcpy(servfail.data(), &dnsLength, sizeof(dnsLength)); + } + + auto* dnsHeader = servfail.data() + prefixSize; + memcpy(dnsHeader, &transactionId, sizeof(transactionId)); // Transaction ID (network byte order, copied as-is) + dnsHeader[2] = gsl::byte{0x80}; // QR=1 (response), OPCODE=0 (standard query) + dnsHeader[3] = gsl::byte{0x02}; // RA=0, Z=0, RCODE=2 (Server Failure) + + WSL_LOG( + "DnsResolver::SendServfailResponse", + TraceLoggingValue(dnsClientIdentifier.Protocol == IPPROTO_UDP ? "UDP" : "TCP", "Protocol"), + TraceLoggingValue(dnsClientIdentifier.DnsClientId, "DNS client id")); + + m_dnsResponseQueue.submit([this, servfail = std::move(servfail), dnsClientIdentifier]() mutable { + m_dnsChannel.SendDnsMessage(gsl::make_span(servfail), dnsClientIdentifier); + }); +} + void DnsResolver::ResolveExternalInterfaceConstraintIndex() noexcept try { diff --git a/src/windows/common/DnsResolver.h b/src/windows/common/DnsResolver.h index a186737e8..6be6555bb 100644 --- a/src/windows/common/DnsResolver.h +++ b/src/windows/common/DnsResolver.h @@ -43,6 +43,11 @@ class DnsResolver // Unique query id. uint32_t m_id{}; + // Transaction ID field from the original DNS message header (stored in network byte order). + // Note: For DNS-over-TCP the tunneled buffer includes a 2-byte length prefix before the + // DNS header; this value is taken from the DNS header itself, not assumed to be at offset 0. + uint16_t m_dnsTransactionId{}; + // Callback to the parent object to notify about the DNS query completion. std::function m_handleQueryCompletion; @@ -78,6 +83,10 @@ class DnsResolver // queryResults - structure containing result of the DNS request. void HandleDnsQueryCompletion(_Inout_ DnsQueryContext* dnsQueryContext, _Inout_opt_ DNS_QUERY_RAW_RESULT* queryResults) noexcept; + // Build and send a minimal DNS SERVFAIL response (RFC 1035, RCODE=2) back to the Linux DNS client. + // This is used when the Windows DNS API fails, to prevent the Linux client from waiting until timeout. + void SendServfailResponse(uint16_t transactionId, const LX_GNS_DNS_CLIENT_IDENTIFIER& dnsClientIdentifier); + void ResolveExternalInterfaceConstraintIndex() noexcept; // Callback that will be invoked by the DNS API whenever a request finishes. The callback is invoked on success, error or when request is cancelled. @@ -105,7 +114,7 @@ class DnsResolver _Guarded_by_(m_dnsLock) uint32_t m_currentRequestId = 0; // Mapping request id to the request context structure. - _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests {}; + _Guarded_by_(m_dnsLock) std::unordered_map> m_dnsRequests{}; // Event that is set when all tracked DNS requests have completed. wil::unique_event m_allRequestsFinished{wil::EventOptions::ManualReset};