diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java b/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java index 275b2ddfeef..7051eb338d3 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFutures.java @@ -21,6 +21,7 @@ import com.datastax.oss.driver.api.core.DriverExecutionException; import com.datastax.oss.driver.shaded.guava.common.base.Preconditions; import com.datastax.oss.driver.shaded.guava.common.base.Throwables; +import java.time.Duration; import java.util.List; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -28,6 +29,8 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; @@ -164,6 +167,47 @@ public static T getUninterruptibly(CompletionStage stage) { } } + /** + * Get the result of a future uninterruptibly, with a timeout. + * + * @param stage the completion stage to wait for + * @param timeout the maximum time to wait + * @return the result value + * @throws DriverExecutionException if the future completed exceptionally + * @throws DriverExecutionException wrapping TimeoutException if the wait timed out + */ + public static T getUninterruptibly(CompletionStage stage, Duration timeout) { + boolean interrupted = false; + try { + long remainingNanos = timeout.toNanos(); + long deadline = System.nanoTime() + remainingNanos; + while (true) { + try { + return stage.toCompletableFuture().get(remainingNanos, TimeUnit.NANOSECONDS); + } catch (InterruptedException e) { + interrupted = true; + remainingNanos = deadline - System.nanoTime(); + if (remainingNanos <= 0) { + throw new DriverExecutionException(new TimeoutException("Timed out after interrupt")); + } + } catch (TimeoutException e) { + throw new DriverExecutionException(e); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof DriverException) { + throw ((DriverException) cause).copy(); + } + Throwables.throwIfUnchecked(cause); + throw new DriverExecutionException(cause); + } + } + } finally { + if (interrupted) { + Thread.currentThread().interrupt(); + } + } + } + /** * Executes a function on the calling thread and returns result in a {@link CompletableFuture}. * diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFuturesTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFuturesTest.java index 04f96f185fd..2e21441b565 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFuturesTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/util/concurrent/CompletableFuturesTest.java @@ -18,12 +18,16 @@ package com.datastax.oss.driver.internal.core.util.concurrent; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.Assert.fail; +import com.datastax.oss.driver.api.core.DriverExecutionException; +import java.time.Duration; import java.util.Arrays; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import org.junit.Test; public class CompletableFuturesTest { @@ -45,4 +49,28 @@ public void should_not_suppress_identical_exceptions() throws Exception { assertThat(e.getCause()).isEqualTo(error); } } + + @Test + public void should_get_uninterruptibly_with_timeout_on_completed_future() { + CompletableFuture future = CompletableFuture.completedFuture("result"); + String result = CompletableFutures.getUninterruptibly(future, Duration.ofSeconds(1)); + assertThat(result).isEqualTo("result"); + } + + @Test + public void should_timeout_on_incomplete_future() { + CompletableFuture future = new CompletableFuture<>(); + assertThatThrownBy(() -> CompletableFutures.getUninterruptibly(future, Duration.ofMillis(100))) + .isInstanceOf(DriverExecutionException.class) + .hasCauseInstanceOf(TimeoutException.class); + } + + @Test + public void should_propagate_exception_with_timeout() { + CompletableFuture future = new CompletableFuture<>(); + RuntimeException error = new RuntimeException("test error"); + future.completeExceptionally(error); + assertThatThrownBy(() -> CompletableFutures.getUninterruptibly(future, Duration.ofSeconds(1))) + .isEqualTo(error); + } } diff --git a/pom.xml b/pom.xml index 321f3fec2d5..ae3b2358d00 100644 --- a/pom.xml +++ b/pom.xml @@ -100,7 +100,7 @@ ${skipTests} false false - +