diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java index f099cb64c8..f78ebb0245 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java @@ -38,6 +38,7 @@ import org.apache.arrow.adbc.core.AdbcStatement; import org.apache.arrow.adbc.core.AdbcStatusCode; import org.apache.arrow.adbc.core.BulkIngestMode; +import org.apache.arrow.adbc.driver.flightsql.oauth.FlightSqlOAuthCredentialWriter; import org.apache.arrow.adbc.sql.SqlQuirks; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightCallHeaders; @@ -58,6 +59,10 @@ import org.checkerframework.checker.nullness.qual.Nullable; public class FlightSqlConnection implements AdbcConnection { + private static final String AUTH_HEADER_CONFLICT_ERROR = + "[Flight SQL] Authentication conflict: Use either Authorization header or OAuth options, " + + "or username/password parameters"; + private final BufferAllocator allocator; private final AtomicInteger counter = new AtomicInteger(0); private final FlightSqlClientWithCallOptions client; @@ -107,9 +112,18 @@ public class FlightSqlConnection implements AdbcConnection { .build( loc -> { FlightClient client = buildClient(loc); - client.handshake(callOptions); - return new FlightSqlClientWithCallOptions( - new FlightSqlClient(client), callOptions); + try { + client.handshake(callOptions); + return new FlightSqlClientWithCallOptions( + new FlightSqlClient(client), callOptions); + } catch (RuntimeException ex) { + try { + client.close(); + } catch (Exception closeEx) { + ex.addSuppressed(closeEx); + } + throw ex; + } }); this.clientCache.put(location, this.client); } @@ -262,54 +276,91 @@ private FlightClient createInitialConnection( } } - // Build the client using the above properties. - final FlightClient client = buildClient(location); - // Add user-specified headers. ArrayList options = new ArrayList<>(); final FlightCallHeaders callHeaders = new FlightCallHeaders(); - for (Map.Entry parameter : parameters.entrySet()) { - if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { - String userHeaderName = - parameter - .getKey() - .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); - - if (parameter.getValue() instanceof String) { - callHeaders.insert(userHeaderName, (String) parameter.getValue()); - } else if (parameter.getValue() instanceof byte[]) { - callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); - } else { - throw new AdbcException( - String.format( - "Header values must be String or byte[]. The header failing was %s.", - parameter.getKey()), - null, - AdbcStatusCode.INVALID_ARGUMENT, - null, - 0); + String authorizationHeader = null; + String username = null; + String password = null; + String oauthFlow = null; + if (parameters != null) { + for (Map.Entry parameter : parameters.entrySet()) { + if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) { + String userHeaderName = + parameter + .getKey() + .substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length()); + + if (parameter.getValue() instanceof String) { + callHeaders.insert(userHeaderName, (String) parameter.getValue()); + } else if (parameter.getValue() instanceof byte[]) { + callHeaders.insert(userHeaderName, (byte[]) parameter.getValue()); + } else { + throw new AdbcException( + String.format( + "Header values must be String or byte[]. The header failing was %s.", + parameter.getKey()), + null, + AdbcStatusCode.INVALID_ARGUMENT, + null, + 0); + } } } + + authorizationHeader = FlightSqlConnectionProperties.AUTHORIZATION_HEADER.get(parameters); + username = AdbcDriver.PARAM_USERNAME.get(parameters); + password = AdbcDriver.PARAM_PASSWORD.get(parameters); + oauthFlow = FlightSqlConnectionProperties.OAUTH_FLOW.get(parameters); + } + + if (authorizationHeader != null) { + callHeaders.insert("authorization", authorizationHeader); } options.add(new HeaderCallOption(callHeaders)); - // Test the connection. - String username = AdbcDriver.PARAM_USERNAME.get(parameters); - String password = AdbcDriver.PARAM_PASSWORD.get(parameters); - if (username != null && password != null) { - Optional bearerToken = - client.authenticateBasicToken(username, password); - options.add( - bearerToken.orElse( - new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); - this.callOptions = options.toArray(new CallOption[0]); - } else { - this.callOptions = options.toArray(new CallOption[0]); - client.handshake(this.callOptions); + final boolean hasAuthorizationHeader = authorizationHeader != null; + final boolean hasUsernamePassword = username != null || password != null; + final boolean hasOauth = oauthFlow != null; + + if ((hasAuthorizationHeader && (hasUsernamePassword || hasOauth)) + || (hasUsernamePassword && hasOauth)) { + throw AdbcException.invalidArgument(AUTH_HEADER_CONFLICT_ERROR); } - return client; + // Build the client using the above properties. + final FlightClient client = buildClient(location); + + try { + // Test the connection. + if (hasOauth) { + final FlightSqlOAuthCredentialWriter oauthCredentialWriter = + FlightSqlOAuthCredentialWriter.create(parameters); + oauthCredentialWriter.prefetchToken(); + options.add(new CredentialCallOption(oauthCredentialWriter)); + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } else if (username != null && password != null) { + Optional bearerToken = + client.authenticateBasicToken(username, password); + options.add( + bearerToken.orElse( + new CredentialCallOption(new BasicAuthCredentialWriter(username, password)))); + this.callOptions = options.toArray(new CallOption[0]); + } else { + this.callOptions = options.toArray(new CallOption[0]); + client.handshake(this.callOptions); + } + return client; + } catch (AdbcException | RuntimeException ex) { + try { + client.close(); + } catch (Exception closeEx) { + ex.addSuppressed(closeEx); + } + throw ex; + } } /** Returns a yet-to-be authenticated FlightClient */ diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java index 4ab1955a1b..a6ed5f2650 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionProperties.java @@ -22,10 +22,33 @@ /** Defines connection options that are used by the FlightSql driver. */ public interface FlightSqlConnectionProperties { + TypedKey AUTHORIZATION_HEADER = + new TypedKey<>("adbc.flight.sql.authorization_header", String.class); TypedKey MTLS_CERT_CHAIN = new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class); TypedKey MTLS_PRIVATE_KEY = new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class); + TypedKey OAUTH_FLOW = new TypedKey<>("adbc.flight.sql.oauth.flow", String.class); + TypedKey OAUTH_TOKEN_URI = + new TypedKey<>("adbc.flight.sql.oauth.token_uri", String.class); + TypedKey OAUTH_CLIENT_ID = + new TypedKey<>("adbc.flight.sql.oauth.client_id", String.class); + TypedKey OAUTH_CLIENT_SECRET = + new TypedKey<>("adbc.flight.sql.oauth.client_secret", String.class); + TypedKey OAUTH_SCOPE = new TypedKey<>("adbc.flight.sql.oauth.scope", String.class); + TypedKey OAUTH_RESOURCE = new TypedKey<>("adbc.flight.sql.oauth.resource", String.class); + TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN = + new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token", String.class); + TypedKey OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token_type", String.class); + TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN = + new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token", String.class); + TypedKey OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token_type", String.class); + TypedKey OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE = + new TypedKey<>("adbc.flight.sql.oauth.exchange.requested_token_type", String.class); + TypedKey OAUTH_EXCHANGE_AUD = + new TypedKey<>("adbc.flight.sql.oauth.exchange.aud", String.class); TypedKey TLS_OVERRIDE_HOSTNAME = new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class); TypedKey TLS_SKIP_VERIFY = diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java index af8221d6b0..706f69ff9f 100644 --- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDatabase.java @@ -64,6 +64,13 @@ public AdbcConnection connect() throws AdbcException { adbcException.addSuppressed(e); } throw adbcException; + } catch (AdbcException ex) { + try { + AutoCloseables.close(connectionAllocator); + } catch (Exception e) { + ex.addSuppressed(e); + } + throw ex; } catch (Exception ex) { AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex); try { diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java new file mode 100644 index 0000000000..a69cda9758 --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/FlightSqlOAuthCredentialWriter.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.SQLException; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.TypedKey; +import org.apache.arrow.adbc.driver.flightsql.FlightSqlConnectionProperties; +import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProvider; +import org.apache.arrow.driver.jdbc.client.oauth.OAuthTokenProviders; +import org.apache.arrow.flight.CallHeaders; +import org.checkerframework.checker.nullness.qual.Nullable; + +public final class FlightSqlOAuthCredentialWriter implements Consumer { + private final OAuthTokenProvider tokenProvider; + + private FlightSqlOAuthCredentialWriter(OAuthTokenProvider tokenProvider) { + this.tokenProvider = tokenProvider; + } + + public static FlightSqlOAuthCredentialWriter create(Map parameters) + throws AdbcException { + final String flowValue = requireOption(parameters, FlightSqlConnectionProperties.OAUTH_FLOW); + final OAuthFlowType flowType = OAuthFlowType.fromValue(flowValue); + if (flowType == null) { + throw AdbcException.notImplemented("[Flight SQL] oauth flow not implemented: " + flowValue); + } + + switch (flowType) { + case CLIENT_CREDENTIALS: + return new FlightSqlOAuthCredentialWriter(createClientCredentialsProvider(parameters)); + case TOKEN_EXCHANGE: + return new FlightSqlOAuthCredentialWriter(createTokenExchangeProvider(parameters)); + default: + throw AdbcException.notImplemented("[Flight SQL] oauth flow not implemented: " + flowType); + } + } + + public void prefetchToken() throws AdbcException { + currentAuthorizationValue(); + } + + @Override + public void accept(CallHeaders headers) { + try { + headers.insert("authorization", currentAuthorizationValue()); + } catch (AdbcException e) { + throw new IllegalStateException(e.getMessage(), e); + } + } + + private String currentAuthorizationValue() throws AdbcException { + try { + return "Bearer " + tokenProvider.getValidToken(); + } catch (SQLException e) { + throw AdbcException.io("[Flight SQL] OAuth token request failed: " + e.getMessage()) + .withCause(e); + } + } + + private static OAuthTokenProvider createClientCredentialsProvider(Map parameters) + throws AdbcException { + try { + final OAuthTokenProviders.ClientCredentialsBuilder builder = + OAuthTokenProviders.clientCredentials() + .tokenUri(requireOption(parameters, FlightSqlConnectionProperties.OAUTH_TOKEN_URI)) + .clientId(requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_ID)) + .clientSecret( + requireOption(parameters, FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET)); + + final String scope = FlightSqlConnectionProperties.OAUTH_SCOPE.get(parameters); + if (scope != null) { + builder.scope(scope); + } + return builder.build(); + } catch (RuntimeException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] Invalid OAuth client credentials configuration: " + e.getMessage()) + .withCause(e); + } + } + + private static OAuthTokenProvider createTokenExchangeProvider(Map parameters) + throws AdbcException { + try { + final OAuthTokenProviders.TokenExchangeBuilder builder = + OAuthTokenProviders.tokenExchange() + .tokenUri(requireOption(parameters, FlightSqlConnectionProperties.OAUTH_TOKEN_URI)) + .subjectToken( + requireOption( + parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN)) + .subjectTokenType( + requireOption( + parameters, FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE)); + + final String actorToken = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.get(parameters); + final String actorTokenType = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.get(parameters); + if ((actorToken == null) != (actorTokenType == null)) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires " + + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey() + + " when " + + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey() + + " is provided"); + } + if (actorToken != null) { + builder.actorToken(actorToken).actorTokenType(Objects.requireNonNull(actorTokenType)); + } + + final String clientId = FlightSqlConnectionProperties.OAUTH_CLIENT_ID.get(parameters); + final String clientSecret = FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.get(parameters); + if ((clientId == null) != (clientSecret == null)) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires both " + + FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey() + + " and " + + FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey() + + " when client credentials are provided"); + } + if (clientId != null) { + builder.clientCredentials(clientId, Objects.requireNonNull(clientSecret)); + } + + final String audience = FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.get(parameters); + if (audience != null) { + builder.audience(audience); + } + + final String requestedTokenType = + FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.get(parameters); + if (requestedTokenType != null) { + builder.requestedTokenType(requestedTokenType); + } + + final String resource = FlightSqlConnectionProperties.OAUTH_RESOURCE.get(parameters); + if (resource != null) { + try { + builder.resource(new URI(resource)); + } catch (URISyntaxException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] token exchange grant requires a valid URI for " + + FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey()) + .withCause(e); + } + } + + final @Nullable String scope = FlightSqlConnectionProperties.OAUTH_SCOPE.get(parameters); + if (scope != null) { + builder.scope(scope); + } + + return builder.build(); + } catch (AdbcException e) { + throw e; + } catch (RuntimeException e) { + throw AdbcException.invalidArgument( + "[Flight SQL] Invalid OAuth token exchange configuration: " + e.getMessage()) + .withCause(e); + } + } + + private static String requireOption(Map parameters, TypedKey option) + throws AdbcException { + final String value = option.get(parameters); + if (value == null) { + throw AdbcException.invalidArgument("[Flight SQL] OAuth flow requires " + option.getKey()); + } + return value; + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthFlowType.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthFlowType.java new file mode 100644 index 0000000000..2d54828708 --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthFlowType.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +import org.checkerframework.checker.nullness.qual.Nullable; + +public enum OAuthFlowType { + CLIENT_CREDENTIALS("client_credentials"), + TOKEN_EXCHANGE("token_exchange"); + + private final String value; + + OAuthFlowType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + static @Nullable OAuthFlowType fromValue(String value) { + for (OAuthFlowType flowType : values()) { + if (flowType.value.equals(value)) { + return flowType; + } + } + return null; + } +} diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthTokenType.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthTokenType.java new file mode 100644 index 0000000000..e05d359dc3 --- /dev/null +++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/oauth/OAuthTokenType.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adbc.driver.flightsql.oauth; + +public enum OAuthTokenType { + ACCESS_TOKEN("urn:ietf:params:oauth:token-type:access_token"), + REFRESH_TOKEN("urn:ietf:params:oauth:token-type:refresh_token"), + JWT("urn:ietf:params:oauth:token-type:jwt"), + ID_TOKEN("urn:ietf:params:oauth:token-type:id_token"), + SAML1("urn:ietf:params:oauth:token-type:saml1"), + SAML2("urn:ietf:params:oauth:token-type:saml2"); + + private final String value; + + OAuthTokenType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java new file mode 100644 index 0000000000..afd1d4097e --- /dev/null +++ b/java/driver/flight-sql/src/test/java/org/apache/arrow/adbc/driver/flightsql/OAuthTest.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.adbc.driver.flightsql; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.nimbusds.oauth2.sdk.http.HTTPRequest; +import com.sun.net.httpserver.Headers; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import com.sun.net.httpserver.HttpsConfigurator; +import com.sun.net.httpserver.HttpsServer; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManagerFactory; +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcDriver; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcInfoCode; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.driver.flightsql.oauth.OAuthFlowType; +import org.apache.arrow.adbc.driver.flightsql.oauth.OAuthTokenType; +import org.apache.arrow.adbc.drivermanager.AdbcDriverManager; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class OAuthTest { + private static final String CLIENT_ID = "adbc-client"; + private static final String CLIENT_SECRET = "adbc-secret"; + private static final String CLIENT_ACCESS_TOKEN = "client-credentials-token"; + private static final String EXCHANGE_ACCESS_TOKEN = "token-exchange-token"; + private static final char[] TRUST_STORE_PASSWORD = "changeit".toCharArray(); + + private BufferAllocator allocator; + private Map params; + private FlightServer server; + private HttpServer tokenServer; + private AdbcDatabase database; + private AdbcConnection connection; + private HeaderValidator.Factory headerValidatorFactory; + private TokenHandler tokenHandler; + private String tokenServerScheme; + private String previousTrustStore; + private String previousTrustStorePassword; + private String previousTrustStoreType; + private SSLSocketFactory previousDefaultSslSocketFactory; + private HostnameVerifier previousDefaultHostnameVerifier; + private final List tempPaths = new ArrayList<>(); + + @BeforeEach + public void setUp() throws IOException { + allocator = new RootAllocator(Long.MAX_VALUE); + params = new HashMap<>(); + headerValidatorFactory = new HeaderValidator.Factory(); + server = + FlightServer.builder() + .allocator(allocator) + .middleware(HeaderValidator.KEY, headerValidatorFactory) + .location(Location.forGrpcInsecure("localhost", 0)) + .producer(new MockFlightSqlProducer()) + .build(); + server.start(); + + tokenHandler = new TokenHandler(); + previousTrustStore = System.getProperty("javax.net.ssl.trustStore"); + previousTrustStorePassword = System.getProperty("javax.net.ssl.trustStorePassword"); + previousTrustStoreType = System.getProperty("javax.net.ssl.trustStoreType"); + previousDefaultSslSocketFactory = HTTPRequest.getDefaultSSLSocketFactory(); + previousDefaultHostnameVerifier = HTTPRequest.getDefaultHostnameVerifier(); + startHttpTokenServer(); + + params.put( + AdbcDriver.PARAM_URI.getKey(), String.format("grpc+tcp://localhost:%d", server.getPort())); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection, database, server, allocator); + if (tokenServer != null) { + tokenServer.stop(0); + } + restoreJvmTrustStoreConfiguration(); + for (Path path : tempPaths) { + Files.deleteIfExists(path); + } + connection = null; + database = null; + server = null; + allocator = null; + tokenServer = null; + } + + @Test + public void testClientCredentialsFlow() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + OAuthFlowType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "scope-a scope-b"); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization")); + + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type")); + assertEquals("scope-a scope-b", tokenHandler.formValue(0, "scope")); + assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic ")); + } + + @Test + public void testTokenExchangeFlow() throws Exception { + tokenHandler.accessToken = EXCHANGE_ACCESS_TOKEN; + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), OAuthFlowType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(), + OAuthTokenType.ACCESS_TOKEN.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE.getKey(), + OAuthTokenType.JWT.getValue()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE.getKey(), + OAuthTokenType.ACCESS_TOKEN.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_AUD.getKey(), "flight-service"); + params.put( + FlightSqlConnectionProperties.OAUTH_RESOURCE.getKey(), "https://resource.example.com"); + params.put(FlightSqlConnectionProperties.OAUTH_SCOPE.getKey(), "profile email"); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + EXCHANGE_ACCESS_TOKEN, headers.get("authorization")); + + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals( + "urn:ietf:params:oauth:grant-type:token-exchange", tokenHandler.formValue(0, "grant_type")); + assertEquals("subject-token", tokenHandler.formValue(0, "subject_token")); + assertEquals( + OAuthTokenType.ACCESS_TOKEN.getValue(), tokenHandler.formValue(0, "subject_token_type")); + assertEquals("actor-token", tokenHandler.formValue(0, "actor_token")); + assertEquals(OAuthTokenType.JWT.getValue(), tokenHandler.formValue(0, "actor_token_type")); + assertEquals( + OAuthTokenType.ACCESS_TOKEN.getValue(), tokenHandler.formValue(0, "requested_token_type")); + assertEquals("flight-service", tokenHandler.formValue(0, "audience")); + assertEquals("https://resource.example.com", tokenHandler.formValue(0, "resource")); + assertEquals("profile email", tokenHandler.formValue(0, "scope")); + assertTrue(tokenHandler.authorizationHeaders.get(0).startsWith("Basic ")); + } + + @Test + public void testClientCredentialsFlowWithHttpsTokenEndpointWithoutTrustStore() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + startHttpsTokenServer(); + clearJvmTrustStoreConfiguration(); + + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + OAuthFlowType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.IO, adbcException.getStatus()); + } + + @Test + public void testClientCredentialsFlowWithHttpsTokenEndpointUsesJvmTrustStore() throws Exception { + tokenHandler.accessToken = CLIENT_ACCESS_TOKEN; + startHttpsTokenServer(); + configureJvmTrustStore(createTrustStorePath()); + + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + OAuthFlowType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + connect(); + requestServerMetadata(); + + CallHeaders headers = headerValidatorFactory.getHeadersReceivedAtRequest(0); + assertEquals("Bearer " + CLIENT_ACCESS_TOKEN, headers.get("authorization")); + assertEquals(1, tokenHandler.requestBodies.size()); + assertEquals("client_credentials", tokenHandler.formValue(0, "grant_type")); + } + + @Test + public void testAuthorizationHeaderConflictsWithOauth() { + params.put( + FlightSqlConnectionProperties.AUTHORIZATION_HEADER.getKey(), "Bearer existing-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + OAuthFlowType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_SECRET.getKey(), CLIENT_SECRET); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testMissingRequiredParamsClientCredentials() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), + OAuthFlowType.CLIENT_CREDENTIALS.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put(FlightSqlConnectionProperties.OAUTH_CLIENT_ID.getKey(), CLIENT_ID); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testMissingRequiredParamsTokenExchange() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), OAuthFlowType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testActorTokenRequiresActorTokenType() { + params.put( + FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), OAuthFlowType.TOKEN_EXCHANGE.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN.getKey(), "subject-token"); + params.put( + FlightSqlConnectionProperties.OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE.getKey(), + OAuthTokenType.ACCESS_TOKEN.getValue()); + params.put(FlightSqlConnectionProperties.OAUTH_EXCHANGE_ACTOR_TOKEN.getKey(), "actor-token"); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.INVALID_ARGUMENT, adbcException.getStatus()); + } + + @Test + public void testInvalidOauthFlow() { + params.put(FlightSqlConnectionProperties.OAUTH_FLOW.getKey(), "invalid-flow"); + params.put(FlightSqlConnectionProperties.OAUTH_TOKEN_URI.getKey(), tokenUri()); + + AdbcException adbcException = assertThrows(AdbcException.class, this::connect); + assertEquals(AdbcStatusCode.NOT_IMPLEMENTED, adbcException.getStatus()); + } + + private void connect() throws Exception { + database = + AdbcDriverManager.getInstance() + .connect(FlightSqlDriverFactory.class.getCanonicalName(), allocator, params); + connection = database.connect(); + } + + private void requestServerMetadata() throws Exception { + try (ArrowReader reader = connection.getInfo(new int[] {AdbcInfoCode.VENDOR_NAME.getValue()})) { + while (reader.loadNextBatch()) { + // Only interested in triggering an authenticated RPC. + } + } catch (Exception ex) { + // MockFlightSqlProducer does not implement the full SQL metadata surface. + } + } + + private String tokenUri() { + return String.format( + "%s://localhost:%d/token", tokenServerScheme, tokenServer.getAddress().getPort()); + } + + private void startHttpTokenServer() throws IOException { + if (tokenServer != null) { + tokenServer.stop(0); + } + tokenServerScheme = "http"; + tokenServer = HttpServer.create(new InetSocketAddress("localhost", 0), 0); + tokenServer.createContext("/token", tokenHandler); + tokenServer.start(); + } + + private void startHttpsTokenServer() throws Exception { + if (tokenServer != null) { + tokenServer.stop(0); + } + tokenServerScheme = "https"; + final SSLContext sslContext = createServerSslContext(); + final HttpsServer httpsServer = HttpsServer.create(new InetSocketAddress("localhost", 0), 0); + httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext)); + httpsServer.createContext("/token", tokenHandler); + httpsServer.start(); + tokenServer = httpsServer; + } + + private void configureJvmTrustStore(Path trustStorePath) throws Exception { + System.setProperty("javax.net.ssl.trustStore", trustStorePath.toString()); + System.setProperty("javax.net.ssl.trustStorePassword", new String(TRUST_STORE_PASSWORD)); + System.setProperty("javax.net.ssl.trustStoreType", "PKCS12"); + refreshOAuthHttpsDefaults(); + } + + private void clearJvmTrustStoreConfiguration() throws Exception { + System.clearProperty("javax.net.ssl.trustStore"); + System.clearProperty("javax.net.ssl.trustStorePassword"); + System.clearProperty("javax.net.ssl.trustStoreType"); + refreshOAuthHttpsDefaults(); + } + + private void restoreJvmTrustStoreConfiguration() { + restoreSystemProperty("javax.net.ssl.trustStore", previousTrustStore); + restoreSystemProperty("javax.net.ssl.trustStorePassword", previousTrustStorePassword); + restoreSystemProperty("javax.net.ssl.trustStoreType", previousTrustStoreType); + HTTPRequest.setDefaultSSLSocketFactory(previousDefaultSslSocketFactory); + HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier()); + } + + private void refreshOAuthHttpsDefaults() throws Exception { + final TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, trustManagerFactory.getTrustManagers(), new SecureRandom()); + HTTPRequest.setDefaultSSLSocketFactory(sslContext.getSocketFactory()); + HTTPRequest.setDefaultHostnameVerifier(oauthHostnameVerifier()); + } + + private Path createTrustStorePath() throws Exception { + final KeyStore trustStore = KeyStore.getInstance("PKCS12"); + trustStore.load(null, null); + trustStore.setCertificateEntry("root", readCertificate(flightDataPath("root-ca.pem"))); + + final Path trustStorePath = Files.createTempFile("oauth-truststore", ".p12"); + tempPaths.add(trustStorePath); + try (OutputStream output = Files.newOutputStream(trustStorePath)) { + trustStore.store(output, TRUST_STORE_PASSWORD); + } + return trustStorePath; + } + + private SSLContext createServerSslContext() throws Exception { + final Certificate certificate = readCertificate(flightDataPath("cert0.pem")); + final PrivateKey privateKey = readPrivateKey(flightDataPath("cert0.pkcs1")); + final KeyStore keyStore = KeyStore.getInstance("PKCS12"); + keyStore.load(null, null); + keyStore.setKeyEntry( + "token-server", privateKey, TRUST_STORE_PASSWORD, new Certificate[] {certificate}); + + final KeyManagerFactory keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + keyManagerFactory.init(keyStore, TRUST_STORE_PASSWORD); + + final SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(keyManagerFactory.getKeyManagers(), null, new SecureRandom()); + return sslContext; + } + + private static X509Certificate readCertificate(Path path) throws Exception { + try (InputStream input = Files.newInputStream(path)) { + final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); + return (X509Certificate) certificateFactory.generateCertificate(input); + } + } + + private static PrivateKey readPrivateKey(Path path) throws Exception { + final String pem = Files.readString(path, StandardCharsets.US_ASCII); + final String base64 = + pem.replace("-----BEGIN PRIVATE KEY-----", "") + .replace("-----END PRIVATE KEY-----", "") + .replaceAll("\\s+", ""); + final byte[] der = Base64.getDecoder().decode(base64); + return KeyFactory.getInstance("RSA").generatePrivate(new PKCS8EncodedKeySpec(der)); + } + + private static Path flightDataPath(String filename) { + final String dataRoot = System.getProperty("arrow.test.dataRoot"); + if (dataRoot != null) { + return Paths.get(dataRoot).resolve("flight").resolve(filename); + } + return Paths.get("testing", "data", "flight", filename).toAbsolutePath().normalize(); + } + + private static void restoreSystemProperty(String key, String value) { + if (value == null) { + System.clearProperty(key); + } else { + System.setProperty(key, value); + } + } + + private HostnameVerifier oauthHostnameVerifier() { + if (previousDefaultHostnameVerifier != null) { + return previousDefaultHostnameVerifier; + } + return HttpsURLConnection.getDefaultHostnameVerifier(); + } + + private static final class TokenHandler implements HttpHandler { + private final List requestBodies = new ArrayList<>(); + private final List authorizationHeaders = new ArrayList<>(); + private String accessToken = CLIENT_ACCESS_TOKEN; + + @Override + public void handle(HttpExchange exchange) throws IOException { + final String body = + new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8); + requestBodies.add(body); + authorizationHeaders.add(exchange.getRequestHeaders().getFirst("Authorization")); + + final byte[] responseBytes = + ("{\"access_token\":\"" + + accessToken + + "\",\"token_type\":\"Bearer\",\"expires_in\":3600}") + .getBytes(StandardCharsets.UTF_8); + final Headers responseHeaders = exchange.getResponseHeaders(); + responseHeaders.add("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, responseBytes.length); + try (OutputStream output = exchange.getResponseBody()) { + output.write(responseBytes); + } + } + + private String formValue(int requestIndex, String key) { + return decodeForm(requestBodies.get(requestIndex)).get(key); + } + + private static Map decodeForm(String body) { + final Map values = new LinkedHashMap<>(); + if (body.isEmpty()) { + return values; + } + for (String pair : body.split("&")) { + final String[] parts = pair.split("=", 2); + final String name = URLDecoder.decode(parts[0], StandardCharsets.UTF_8); + final String value = + parts.length > 1 ? URLDecoder.decode(parts[1], StandardCharsets.UTF_8) : ""; + values.put(name, value); + } + return values; + } + } +}