diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java index ac6ee280d25333..6497903c093ed7 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/common/Config.java +++ b/fe/fe-common/src/main/java/org/apache/doris/common/Config.java @@ -438,6 +438,13 @@ public class Config extends ConfigBase { "The maximum HTTP POST size of Jetty, in bytes, the default value is 100MB."}) public static int jetty_server_max_http_post_size = 100 * 1024 * 1024; + @ConfField(mutable = true, description = { + "Jetty 在应用未消费完请求体时,额外尝试读取剩余内容的最大次数。" + + "-1 表示不限制,0 表示不额外读取,正数表示最大读取次数。", + "The maximum number of extra reads Jetty performs for unconsumed request content. " + + "-1 means unlimited, 0 means disabled, and a positive value limits the read attempts."}) + public static int jetty_server_max_unconsumed_request_content_reads = -1; + @ConfField(description = {"Jetty 的最大 HTTP header 大小,单位是字节,默认值是 1MB。", "The maximum HTTP header size of Jetty, in bytes, the default value is 1MB."}) public static int jetty_server_max_http_header_size = 1048576; @@ -3305,6 +3312,13 @@ public static int metaServiceRpcRetryTimes() { + "public-private/public/private/direct/random-be and empty string" }) public static String streamload_redirect_policy = ""; + @ConfField(mutable = true, description = { + "Stream Load redirect 场景下,FE 在返回 307 后额外丢弃请求体的最大字节数。" + + "0 表示关闭该兼容逻辑,正数表示最大丢弃字节数。", + "The maximum number of request body bytes FE drains after returning 307 for Stream Load redirects. " + + "0 disables the compatibility logic, and a positive value sets the byte limit."}) + public static long stream_load_redirect_bounded_drain_max_bytes = 0; + @ConfField(mutable = true, description = { "存算分离模式下是否启用group commit的streamload BE转发功能。" + "解决LB随机转发导致group commit攒批失效的问题,通过BE二次转发确保同表请求到达同一BE节点。", diff --git a/fe/fe-core/src/main/java/org/apache/doris/httpv2/config/WebServerFactoryCustomizerConfig.java b/fe/fe-core/src/main/java/org/apache/doris/httpv2/config/WebServerFactoryCustomizerConfig.java index a467a23084481b..71f42ae22f5f35 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/httpv2/config/WebServerFactoryCustomizerConfig.java +++ b/fe/fe-core/src/main/java/org/apache/doris/httpv2/config/WebServerFactoryCustomizerConfig.java @@ -19,6 +19,7 @@ import org.apache.doris.common.Config; +import org.eclipse.jetty.server.Connector; import org.eclipse.jetty.server.HttpConfiguration; import org.eclipse.jetty.server.HttpConnectionFactory; import org.eclipse.jetty.server.ServerConnector; @@ -37,12 +38,16 @@ public void customize(ConfigurableJettyWebServerFactory factory) { ((JettyServletWebServerFactory) factory).setConfigurations( Collections.singletonList(new HttpToHttpsJettyConfig()) ); + } - factory.addServerCustomizers( - server -> { + factory.addServerCustomizers( + server -> { + if (Config.enable_https) { HttpConfiguration httpConfiguration = new HttpConfiguration(); httpConfiguration.setSecurePort(Config.https_port); httpConfiguration.setSecureScheme("https"); + httpConfiguration.setMaxUnconsumedRequestContentReads( + Config.jetty_server_max_unconsumed_request_content_reads); ServerConnector connector = new ServerConnector(server); connector.addConnectionFactory(new HttpConnectionFactory(httpConfiguration)); @@ -50,7 +55,19 @@ public void customize(ConfigurableJettyWebServerFactory factory) { server.addConnector(connector); } - ); - } + + for (Connector connector : server.getConnectors()) { + if (!(connector instanceof ServerConnector)) { + continue; + } + HttpConnectionFactory httpConnectionFactory = + ((ServerConnector) connector).getConnectionFactory(HttpConnectionFactory.class); + if (httpConnectionFactory != null) { + httpConnectionFactory.getHttpConfiguration().setMaxUnconsumedRequestContentReads( + Config.jetty_server_max_unconsumed_request_content_reads); + } + } + } + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/LoadAction.java b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/LoadAction.java index 5af0f0e4990f77..a7b8d17843813c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/LoadAction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/LoadAction.java @@ -23,7 +23,6 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Table; import org.apache.doris.cloud.qe.ComputeGroupException; -import org.apache.doris.cluster.ClusterNamespace; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; import org.apache.doris.common.DdlException; @@ -38,6 +37,7 @@ import org.apache.doris.httpv2.entity.RestBaseResult; import org.apache.doris.httpv2.exception.UnauthorizedException; import org.apache.doris.httpv2.rest.manager.HttpUtils; +import org.apache.doris.httpv2.util.StreamLoadRedirectDrainUtil; import org.apache.doris.load.FailMsg; import org.apache.doris.load.StreamLoadHandler; import org.apache.doris.load.loadv2.IngestionLoadJob; @@ -77,7 +77,6 @@ import java.io.IOException; import java.net.InetAddress; -import java.net.URI; import java.util.Enumeration; import java.util.HashMap; import java.util.LinkedList; @@ -210,8 +209,7 @@ public Object streamLoadWithSql(HttpServletRequest request, HttpServletResponse LOG.info("redirect load action to destination={}, label: {}", redirectAddr.toString(), label); - RedirectView redirectView = redirectTo(request, redirectAddr); - return redirectView; + return createRedirectResponse(request, response, redirectAddr, true, null, null, label); } catch (Exception e) { return new RestBaseResult(e.getMessage()); } @@ -334,11 +332,10 @@ private Object executeWithoutPassword(HttpServletRequest request, redirectAddr.toString(), isStreamLoad, dbName, tableName, label); } - RedirectView redirectView = redirectTo(request, redirectAddr); - return redirectView; + return createRedirectResponse(request, response, redirectAddr, isStreamLoad, dbName, tableName, label); } catch (StreamLoadForwardException e) { // Special handling for stream load forwarding - return e.getRedirectView(); + return createRedirectResponse(request, response, e.getRedirectView(), isStreamLoad, db, table, label); } catch (Exception e) { LOG.warn("load failed, stream: {}, db: {}, tbl: {}, label: {}, err: {}", isStreamLoad, db, table, label, e.getMessage()); @@ -672,24 +669,7 @@ private Object executeWithClusterToken(HttpServletRequest request, String db, + "stream: {}, db: {}, tbl: {}, label: {}", redirectAddr.toString(), isStreamLoad, dbName, tableName, label); - URI urlObj = null; - URI resultUriObj = null; - String urlStr = request.getRequestURI(); - String userInfo = null; - - try { - urlObj = new URI(urlStr); - resultUriObj = new URI("http", userInfo, redirectAddr.getHostname(), - redirectAddr.getPort(), urlObj.getPath(), "", null); - } catch (Exception e) { - throw new RuntimeException(e); - } - String redirectUrl = resultUriObj.toASCIIString(); - if (!Strings.isNullOrEmpty(request.getQueryString())) { - redirectUrl += request.getQueryString(); - } - LOG.info("Redirect url: {}", "http://" + redirectAddr.getHostname() + ":" - + redirectAddr.getPort() + urlObj.getPath()); + String redirectUrl = buildRedirectUrl(request, redirectAddr); RedirectView redirectView = new RedirectView(redirectUrl); redirectView.setContentType("text/html;charset=utf-8"); redirectView.setStatusCode(org.springframework.http.HttpStatus.TEMPORARY_REDIRECT); @@ -714,6 +694,47 @@ private String getAllHeaders(HttpServletRequest request) { return headers.toString(); } + private Object createRedirectResponse(HttpServletRequest request, HttpServletResponse response, + TNetworkAddress redirectAddr, boolean isStreamLoad, String dbName, String tableName, String label) + throws IOException { + String redirectUrl = buildRedirectUrl(request, redirectAddr); + if (!shouldUseBoundedDrainForStreamLoad(isStreamLoad)) { + return redirectTo(request, redirectAddr); + } + writeTemporaryRedirect(response, redirectUrl); + drainStreamLoadRequestBodyAfterRedirect(request, redirectAddr.toString(), dbName, tableName, label); + return null; + } + + private Object createRedirectResponse(HttpServletRequest request, HttpServletResponse response, + RedirectView redirectView, boolean isStreamLoad, String dbName, String tableName, String label) + throws IOException { + if (!shouldUseBoundedDrainForStreamLoad(isStreamLoad)) { + return redirectView; + } + writeTemporaryRedirect(response, redirectView.getUrl()); + drainStreamLoadRequestBodyAfterRedirect(request, redirectView.getUrl(), dbName, tableName, label); + return null; + } + + private boolean shouldUseBoundedDrainForStreamLoad(boolean isStreamLoad) { + return isStreamLoad && Config.stream_load_redirect_bounded_drain_max_bytes > 0; + } + + private void drainStreamLoadRequestBodyAfterRedirect(HttpServletRequest request, String redirectTarget, + String dbName, String tableName, String label) { + long drainLimit = Config.stream_load_redirect_bounded_drain_max_bytes; + LOG.info("write stream load redirect and start bounded drain, target: {}, db: {}, tbl: {}, label: {}," + + " max_drain_bytes: {}", + redirectTarget, dbName, tableName, label, drainLimit); + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect(request, drainLimit); + LOG.info("finish bounded drain after stream load redirect, target: {}, db: {}, tbl: {}, label: {}," + + " drained_bytes: {}, elapsed_ms: {}, exit_reason: {}", + redirectTarget, dbName, tableName, label, drainResult.getDrainedBytes(), + drainResult.getElapsedMillis(), drainResult.getExitReason()); + } + private Backend selectBackendForGroupCommit(String clusterName, HttpServletRequest req, long tableId) throws LoadException { ConnectContext ctx = new ConnectContext(); @@ -959,35 +980,13 @@ public Object updateIngestionLoad(HttpServletRequest request, HttpServletRespons */ private RedirectView redirectToStreamLoadForward(HttpServletRequest request, TNetworkAddress addr, String forwardTarget) { - URI urlObj = null; - URI resultUriObj = null; - String urlStr = request.getRequestURI(); - String userInfo = null; - String modifiedPath = null; - - if (!Strings.isNullOrEmpty(request.getHeader("Authorization"))) { - ActionAuthorizationInfo authInfo = getAuthorizationInfo(request); - userInfo = ClusterNamespace.getNameFromFullName(authInfo.fullUserName) - + ":" + authInfo.password; - } - try { - urlObj = new URI(urlStr); - // Replace _stream_load with _stream_load_forward in the path - modifiedPath = urlObj.getPath().replace("/_stream_load", "/_stream_load_forward"); - resultUriObj = new URI("http", userInfo, addr.getHostname(), - addr.getPort(), modifiedPath, "", null); - } catch (Exception e) { - throw new RuntimeException(e); - } - String redirectUrl = resultUriObj.toASCIIString(); - - // Add forward_to parameter (note: toASCIIString() already includes '?' due to empty query) + String modifiedPath = request.getRequestURI().replace("/_stream_load", "/_stream_load_forward"); String queryString = request.getQueryString(); + String redirectQuery = "forward_to=" + forwardTarget; if (!Strings.isNullOrEmpty(queryString)) { - redirectUrl += queryString + "&forward_to=" + forwardTarget; - } else { - redirectUrl += "forward_to=" + forwardTarget; + redirectQuery = queryString + "&" + redirectQuery; } + String redirectUrl = buildRedirectUrl(request, addr, modifiedPath, redirectQuery); LOG.info("Redirect stream load forward url: {}, forward_to: {}", "http://" + addr.getHostname() + ":" + addr.getPort() + modifiedPath, forwardTarget); diff --git a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/RestBaseController.java b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/RestBaseController.java index 13d3fbc60a3a01..8c59ad0a1271ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/RestBaseController.java +++ b/fe/fe-core/src/main/java/org/apache/doris/httpv2/rest/RestBaseController.java @@ -73,10 +73,13 @@ public ActionAuthorizationInfo executeCheckPassword(HttpServletRequest request, return authInfo; } - public RedirectView redirectTo(HttpServletRequest request, TNetworkAddress addr) { - URI urlObj = null; + protected String buildRedirectUrl(HttpServletRequest request, TNetworkAddress addr) { + return buildRedirectUrl(request, addr, request.getRequestURI(), request.getQueryString()); + } + + protected String buildRedirectUrl(HttpServletRequest request, TNetworkAddress addr, String requestPath, + String queryString) { URI resultUriObj = null; - String urlStr = request.getRequestURI(); String userInfo = null; if (!Strings.isNullOrEmpty(request.getHeader("Authorization"))) { ActionAuthorizationInfo authInfo = getAuthorizationInfo(request); @@ -84,18 +87,29 @@ public RedirectView redirectTo(HttpServletRequest request, TNetworkAddress addr) + ":" + authInfo.password; } try { - urlObj = new URI(urlStr); resultUriObj = new URI("http", userInfo, addr.getHostname(), - addr.getPort(), urlObj.getPath(), "", null); + addr.getPort(), requestPath, null, null); } catch (Exception e) { throw new RuntimeException(e); } String redirectUrl = resultUriObj.toASCIIString(); - if (!Strings.isNullOrEmpty(request.getQueryString())) { - redirectUrl += request.getQueryString(); + if (!Strings.isNullOrEmpty(queryString)) { + redirectUrl += "?" + queryString; } LOG.info("Redirect url: {}", "http://" + addr.getHostname() + ":" - + addr.getPort() + urlObj.getPath()); + + addr.getPort() + requestPath); + return redirectUrl; + } + + protected void writeTemporaryRedirect(HttpServletResponse response, String redirectUrl) throws IOException { + response.setContentType("text/html;charset=utf-8"); + response.setStatus(HttpStatus.TEMPORARY_REDIRECT.value()); + response.setHeader("Location", redirectUrl); + response.flushBuffer(); + } + + public RedirectView redirectTo(HttpServletRequest request, TNetworkAddress addr) { + String redirectUrl = buildRedirectUrl(request, addr); RedirectView redirectView = new RedirectView(redirectUrl); redirectView.setContentType("text/html;charset=utf-8"); redirectView.setStatusCode(org.springframework.http.HttpStatus.TEMPORARY_REDIRECT); diff --git a/fe/fe-core/src/main/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtil.java b/fe/fe-core/src/main/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtil.java new file mode 100644 index 00000000000000..758a92df2f4c9f --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtil.java @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.httpv2.util; + +import com.google.common.base.Preconditions; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; + +public final class StreamLoadRedirectDrainUtil { + private static final Logger LOG = LogManager.getLogger(StreamLoadRedirectDrainUtil.class); + + private static final int BUFFER_SIZE = 8 * 1024; + private static final int IDLE_SLEEP_MS = 5; + private static final int MAX_IDLE_LOOPS = 3; + + private StreamLoadRedirectDrainUtil() { + } + + public static DrainResult drainRequestBodyAfterRedirect(HttpServletRequest request, long maxBytes) { + try { + return drainRequestBodyAfterRedirect(request.getInputStream(), maxBytes); + } catch (IOException e) { + LOG.warn("failed to get request input stream for stream load redirect drain", e); + return new DrainResult(0, 0, ExitReason.ERROR); + } + } + + static DrainResult drainRequestBodyAfterRedirect(ServletInputStream inputStream, long maxBytes) { + Preconditions.checkArgument(maxBytes > 0, "maxBytes must be positive"); + + long startNanos = System.nanoTime(); + long drainedBytes = 0; + int idleLoops = 0; + byte[] buffer = new byte[(int) Math.min(BUFFER_SIZE, maxBytes)]; + + try { + while (drainedBytes < maxBytes) { + int availableBytes = inputStream.available(); + if (availableBytes <= 0) { + idleLoops++; + if (idleLoops >= MAX_IDLE_LOOPS) { + return new DrainResult(drainedBytes, elapsedMillis(startNanos), ExitReason.IDLE_TIMEOUT); + } + if (!sleepForIdleWindow()) { + return new DrainResult(drainedBytes, elapsedMillis(startNanos), ExitReason.ERROR); + } + continue; + } + + idleLoops = 0; + int readLimit = (int) Math.min(Math.min(maxBytes - drainedBytes, buffer.length), availableBytes); + int readBytes = inputStream.read(buffer, 0, readLimit); + if (readBytes < 0) { + return new DrainResult(drainedBytes, elapsedMillis(startNanos), ExitReason.EOF); + } + if (readBytes == 0) { + continue; + } + drainedBytes += readBytes; + } + return new DrainResult(drainedBytes, elapsedMillis(startNanos), ExitReason.MAX_BYTES); + } catch (IOException e) { + LOG.warn("failed while draining request body after stream load redirect", e); + return new DrainResult(drainedBytes, elapsedMillis(startNanos), ExitReason.ERROR); + } + } + + private static boolean sleepForIdleWindow() { + try { + Thread.sleep(IDLE_SLEEP_MS); + return true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return false; + } + } + + private static long elapsedMillis(long startNanos) { + return (System.nanoTime() - startNanos) / 1_000_000; + } + + public enum ExitReason { + EOF, + MAX_BYTES, + IDLE_TIMEOUT, + ERROR + } + + public static final class DrainResult { + private final long drainedBytes; + private final long elapsedMillis; + private final ExitReason exitReason; + + public DrainResult(long drainedBytes, long elapsedMillis, ExitReason exitReason) { + this.drainedBytes = drainedBytes; + this.elapsedMillis = elapsedMillis; + this.exitReason = exitReason; + } + + public long getDrainedBytes() { + return drainedBytes; + } + + public long getElapsedMillis() { + return elapsedMillis; + } + + public ExitReason getExitReason() { + return exitReason; + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/httpv2/rest/LoadActionTest.java b/fe/fe-core/src/test/java/org/apache/doris/httpv2/rest/LoadActionTest.java new file mode 100644 index 00000000000000..ca024628d1c570 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/httpv2/rest/LoadActionTest.java @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.httpv2.rest; + +import org.apache.doris.common.Config; +import org.apache.doris.thrift.TNetworkAddress; + +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.http.HttpStatus; +import org.springframework.web.servlet.view.RedirectView; + +import java.lang.reflect.Method; + +public class LoadActionTest { + + @AfterEach + public void tearDown() { + Config.stream_load_redirect_bounded_drain_max_bytes = 0; + } + + @Test + public void testCreateRedirectResponseReturnsRedirectViewWhenBoundedDrainDisabled() throws Exception { + Config.stream_load_redirect_bounded_drain_max_bytes = 0; + LoadAction loadAction = new LoadAction(); + HttpServletRequest request = mockStreamLoadRequest(); + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + Object result = invokeCreateRedirectResponse(loadAction, request, response, + new TNetworkAddress("be-host", 8040), true, "db1", "tbl1", "label1"); + + Assertions.assertTrue(result instanceof RedirectView); + RedirectView redirectView = (RedirectView) result; + Assertions.assertEquals("http://be-host:8040/api/db1/tbl1/_stream_load?foo=bar", redirectView.getUrl()); + Assertions.assertEquals(HttpStatus.TEMPORARY_REDIRECT, redirectView.getStatusCode()); + Mockito.verifyNoInteractions(response); + } + + @Test + public void testCreateRedirectResponseWrites307WhenBoundedDrainEnabled() throws Exception { + Config.stream_load_redirect_bounded_drain_max_bytes = 8; + LoadAction loadAction = new LoadAction(); + HttpServletRequest request = mockStreamLoadRequest(); + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + Object result = invokeCreateRedirectResponse(loadAction, request, response, + new TNetworkAddress("be-host", 8040), true, "db1", "tbl1", "label1"); + + Assertions.assertNull(result); + Mockito.verify(response).setContentType("text/html;charset=utf-8"); + Mockito.verify(response).setStatus(HttpStatus.TEMPORARY_REDIRECT.value()); + Mockito.verify(response).setHeader("Location", "http://be-host:8040/api/db1/tbl1/_stream_load?foo=bar"); + Mockito.verify(response).flushBuffer(); + } + + @Test + public void testCreateRedirectResponseKeepsNonStreamLoadBehavior() throws Exception { + Config.stream_load_redirect_bounded_drain_max_bytes = 8; + LoadAction loadAction = new LoadAction(); + HttpServletRequest request = mockStreamLoadRequest(); + HttpServletResponse response = Mockito.mock(HttpServletResponse.class); + + Object result = invokeCreateRedirectResponse(loadAction, request, response, + new TNetworkAddress("be-host", 8040), false, "db1", "tbl1", "label1"); + + Assertions.assertTrue(result instanceof RedirectView); + Mockito.verifyNoInteractions(response); + } + + private Object invokeCreateRedirectResponse(LoadAction loadAction, HttpServletRequest request, + HttpServletResponse response, TNetworkAddress redirectAddr, boolean isStreamLoad, String dbName, + String tableName, String label) throws Exception { + Method method = LoadAction.class.getDeclaredMethod("createRedirectResponse", + HttpServletRequest.class, HttpServletResponse.class, TNetworkAddress.class, + boolean.class, String.class, String.class, String.class); + method.setAccessible(true); + return method.invoke(loadAction, request, response, redirectAddr, isStreamLoad, dbName, tableName, label); + } + + private HttpServletRequest mockStreamLoadRequest() throws Exception { + HttpServletRequest request = Mockito.mock(HttpServletRequest.class); + Mockito.when(request.getRequestURI()).thenReturn("/api/db1/tbl1/_stream_load"); + Mockito.when(request.getQueryString()).thenReturn("foo=bar"); + Mockito.when(request.getHeader("Authorization")).thenReturn(null); + Mockito.when(request.getInputStream()).thenReturn(new IdleServletInputStream()); + return request; + } + + private static class IdleServletInputStream extends ServletInputStream { + @Override + public int read() { + return -1; + } + + @Override + public int available() { + return 0; + } + + @Override + public boolean isFinished() { + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtilTest.java b/fe/fe-core/src/test/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtilTest.java new file mode 100644 index 00000000000000..aed0df28afa517 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/httpv2/util/StreamLoadRedirectDrainUtilTest.java @@ -0,0 +1,194 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.httpv2.util; + +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletInputStream; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Queue; + +public class StreamLoadRedirectDrainUtilTest { + + @Test + public void testDrainRequestBodyWithinMaxBytes() { + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect( + new QueueAvailableServletInputStream("hello".getBytes(), 5, 0, 0, 0), 16); + + Assertions.assertEquals(5, drainResult.getDrainedBytes()); + Assertions.assertEquals(StreamLoadRedirectDrainUtil.ExitReason.IDLE_TIMEOUT, drainResult.getExitReason()); + } + + @Test + public void testDrainRequestBodyStopsAtMaxBytes() { + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect( + new QueueAvailableServletInputStream("hello world".getBytes(), 11), 5); + + Assertions.assertEquals(5, drainResult.getDrainedBytes()); + Assertions.assertEquals(StreamLoadRedirectDrainUtil.ExitReason.MAX_BYTES, drainResult.getExitReason()); + } + + @Test + public void testDrainRequestBodyIdleTimeout() { + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect( + new QueueAvailableServletInputStream(new byte[0], 0, 0, 0, 0), 8); + + Assertions.assertEquals(0, drainResult.getDrainedBytes()); + Assertions.assertEquals(StreamLoadRedirectDrainUtil.ExitReason.IDLE_TIMEOUT, drainResult.getExitReason()); + } + + @Test + public void testDrainRequestBodyReadError() { + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect(new ErrorServletInputStream(), 8); + + Assertions.assertEquals(0, drainResult.getDrainedBytes()); + Assertions.assertEquals(StreamLoadRedirectDrainUtil.ExitReason.ERROR, drainResult.getExitReason()); + } + + @Test + public void testDrainRequestBodyEof() { + StreamLoadRedirectDrainUtil.DrainResult drainResult = + StreamLoadRedirectDrainUtil.drainRequestBodyAfterRedirect(new EofServletInputStream(), 8); + + Assertions.assertEquals(0, drainResult.getDrainedBytes()); + Assertions.assertEquals(StreamLoadRedirectDrainUtil.ExitReason.EOF, drainResult.getExitReason()); + } + + private static class QueueAvailableServletInputStream extends ServletInputStream { + private final byte[] data; + private final Queue availableValues = new ArrayDeque<>(); + private int offset = 0; + + QueueAvailableServletInputStream(byte[] data, int... availableValues) { + this.data = data; + for (int availableValue : availableValues) { + this.availableValues.add(availableValue); + } + } + + @Override + public int read() { + if (offset >= data.length) { + return -1; + } + return data[offset++] & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) { + if (offset >= data.length) { + return -1; + } + int readBytes = Math.min(len, data.length - offset); + System.arraycopy(data, offset, b, off, readBytes); + offset += readBytes; + return readBytes; + } + + @Override + public int available() { + if (!availableValues.isEmpty()) { + return availableValues.poll(); + } + return Math.max(0, data.length - offset); + } + + @Override + public boolean isFinished() { + return offset >= data.length; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + } + } + + private static class ErrorServletInputStream extends ServletInputStream { + @Override + public int read() throws IOException { + throw new IOException("read error"); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + throw new IOException("read error"); + } + + @Override + public int available() { + return 1; + } + + @Override + public boolean isFinished() { + return false; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + } + } + + private static class EofServletInputStream extends ServletInputStream { + @Override + public int read() { + return -1; + } + + @Override + public int read(byte[] b, int off, int len) { + return -1; + } + + @Override + public int available() { + return 1; + } + + @Override + public boolean isFinished() { + return true; + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void setReadListener(ReadListener readListener) { + } + } +}