Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.arrow.adbc.core.AdbcStatement;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.core.BulkIngestMode;
import org.apache.arrow.adbc.driver.flightsql.oauth.FlightSqlOAuthCredentialWriter;
import org.apache.arrow.adbc.sql.SqlQuirks;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightCallHeaders;
Expand All @@ -58,6 +59,10 @@
import org.checkerframework.checker.nullness.qual.Nullable;

public class FlightSqlConnection implements AdbcConnection {
private static final String AUTH_HEADER_CONFLICT_ERROR =
"[Flight SQL] Authentication conflict: Use either Authorization header or OAuth options, "
+ "or username/password parameters";

private final BufferAllocator allocator;
private final AtomicInteger counter = new AtomicInteger(0);
private final FlightSqlClientWithCallOptions client;
Expand Down Expand Up @@ -107,9 +112,18 @@ public class FlightSqlConnection implements AdbcConnection {
.build(
loc -> {
FlightClient client = buildClient(loc);
client.handshake(callOptions);
return new FlightSqlClientWithCallOptions(
new FlightSqlClient(client), callOptions);
try {
client.handshake(callOptions);
return new FlightSqlClientWithCallOptions(
new FlightSqlClient(client), callOptions);
} catch (RuntimeException ex) {
try {
client.close();
} catch (Exception closeEx) {
ex.addSuppressed(closeEx);
}
throw ex;
}
});
this.clientCache.put(location, this.client);
}
Expand Down Expand Up @@ -262,54 +276,91 @@ private FlightClient createInitialConnection(
}
}

// Build the client using the above properties.
final FlightClient client = buildClient(location);

// Add user-specified headers.
ArrayList<CallOption> options = new ArrayList<>();
final FlightCallHeaders callHeaders = new FlightCallHeaders();
for (Map.Entry<String, Object> parameter : parameters.entrySet()) {
if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) {
String userHeaderName =
parameter
.getKey()
.substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length());

if (parameter.getValue() instanceof String) {
callHeaders.insert(userHeaderName, (String) parameter.getValue());
} else if (parameter.getValue() instanceof byte[]) {
callHeaders.insert(userHeaderName, (byte[]) parameter.getValue());
} else {
throw new AdbcException(
String.format(
"Header values must be String or byte[]. The header failing was %s.",
parameter.getKey()),
null,
AdbcStatusCode.INVALID_ARGUMENT,
null,
0);
String authorizationHeader = null;
String username = null;
String password = null;
String oauthFlow = null;
if (parameters != null) {
for (Map.Entry<String, Object> parameter : parameters.entrySet()) {
if (parameter.getKey().startsWith(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX)) {
String userHeaderName =
parameter
.getKey()
.substring(FlightSqlConnectionProperties.RPC_CALL_HEADER_PREFIX.length());

if (parameter.getValue() instanceof String) {
callHeaders.insert(userHeaderName, (String) parameter.getValue());
} else if (parameter.getValue() instanceof byte[]) {
callHeaders.insert(userHeaderName, (byte[]) parameter.getValue());
} else {
throw new AdbcException(
String.format(
"Header values must be String or byte[]. The header failing was %s.",
parameter.getKey()),
null,
AdbcStatusCode.INVALID_ARGUMENT,
null,
0);
}
Comment on lines +282 to +307
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use ArrowFlightSqlClientHandler.Builder to create the flight client with all the configurations?
I think it would simplify the configuration, creating and managing the flight client

}
}

authorizationHeader = FlightSqlConnectionProperties.AUTHORIZATION_HEADER.get(parameters);
username = AdbcDriver.PARAM_USERNAME.get(parameters);
password = AdbcDriver.PARAM_PASSWORD.get(parameters);
oauthFlow = FlightSqlConnectionProperties.OAUTH_FLOW.get(parameters);
}

if (authorizationHeader != null) {
callHeaders.insert("authorization", authorizationHeader);
}

options.add(new HeaderCallOption(callHeaders));

// Test the connection.
String username = AdbcDriver.PARAM_USERNAME.get(parameters);
String password = AdbcDriver.PARAM_PASSWORD.get(parameters);
if (username != null && password != null) {
Optional<CredentialCallOption> bearerToken =
client.authenticateBasicToken(username, password);
options.add(
bearerToken.orElse(
new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
this.callOptions = options.toArray(new CallOption[0]);
} else {
this.callOptions = options.toArray(new CallOption[0]);
client.handshake(this.callOptions);
final boolean hasAuthorizationHeader = authorizationHeader != null;
final boolean hasUsernamePassword = username != null || password != null;
final boolean hasOauth = oauthFlow != null;

if ((hasAuthorizationHeader && (hasUsernamePassword || hasOauth))
|| (hasUsernamePassword && hasOauth)) {
throw AdbcException.invalidArgument(AUTH_HEADER_CONFLICT_ERROR);
}

return client;
// Build the client using the above properties.
final FlightClient client = buildClient(location);

try {
// Test the connection.
if (hasOauth) {
final FlightSqlOAuthCredentialWriter oauthCredentialWriter =
FlightSqlOAuthCredentialWriter.create(parameters);
oauthCredentialWriter.prefetchToken();
options.add(new CredentialCallOption(oauthCredentialWriter));
this.callOptions = options.toArray(new CallOption[0]);
client.handshake(this.callOptions);
} else if (username != null && password != null) {
Optional<CredentialCallOption> bearerToken =
client.authenticateBasicToken(username, password);
options.add(
bearerToken.orElse(
new CredentialCallOption(new BasicAuthCredentialWriter(username, password))));
this.callOptions = options.toArray(new CallOption[0]);
} else {
this.callOptions = options.toArray(new CallOption[0]);
client.handshake(this.callOptions);
}
return client;
} catch (AdbcException | RuntimeException ex) {
try {
client.close();
} catch (Exception closeEx) {
ex.addSuppressed(closeEx);
}
throw ex;
}
}

/** Returns a yet-to-be authenticated FlightClient */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,33 @@

/** Defines connection options that are used by the FlightSql driver. */
public interface FlightSqlConnectionProperties {
TypedKey<String> AUTHORIZATION_HEADER =
new TypedKey<>("adbc.flight.sql.authorization_header", String.class);
TypedKey<InputStream> MTLS_CERT_CHAIN =
new TypedKey<>("adbc.flight.sql.client_option.mtls_cert_chain", InputStream.class);
TypedKey<InputStream> MTLS_PRIVATE_KEY =
new TypedKey<>("adbc.flight.sql.client_option.mtls_private_key", InputStream.class);
TypedKey<String> OAUTH_FLOW = new TypedKey<>("adbc.flight.sql.oauth.flow", String.class);
TypedKey<String> OAUTH_TOKEN_URI =
new TypedKey<>("adbc.flight.sql.oauth.token_uri", String.class);
TypedKey<String> OAUTH_CLIENT_ID =
new TypedKey<>("adbc.flight.sql.oauth.client_id", String.class);
TypedKey<String> OAUTH_CLIENT_SECRET =
new TypedKey<>("adbc.flight.sql.oauth.client_secret", String.class);
TypedKey<String> OAUTH_SCOPE = new TypedKey<>("adbc.flight.sql.oauth.scope", String.class);
TypedKey<String> OAUTH_RESOURCE = new TypedKey<>("adbc.flight.sql.oauth.resource", String.class);
TypedKey<String> OAUTH_EXCHANGE_SUBJECT_TOKEN =
new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token", String.class);
TypedKey<String> OAUTH_EXCHANGE_SUBJECT_TOKEN_TYPE =
new TypedKey<>("adbc.flight.sql.oauth.exchange.subject_token_type", String.class);
TypedKey<String> OAUTH_EXCHANGE_ACTOR_TOKEN =
new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token", String.class);
TypedKey<String> OAUTH_EXCHANGE_ACTOR_TOKEN_TYPE =
new TypedKey<>("adbc.flight.sql.oauth.exchange.actor_token_type", String.class);
TypedKey<String> OAUTH_EXCHANGE_REQUESTED_TOKEN_TYPE =
new TypedKey<>("adbc.flight.sql.oauth.exchange.requested_token_type", String.class);
TypedKey<String> OAUTH_EXCHANGE_AUD =
new TypedKey<>("adbc.flight.sql.oauth.exchange.aud", String.class);
TypedKey<String> TLS_OVERRIDE_HOSTNAME =
new TypedKey<>("adbc.flight.sql.client_option.tls_override_hostname", String.class);
TypedKey<Boolean> TLS_SKIP_VERIFY =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ public AdbcConnection connect() throws AdbcException {
adbcException.addSuppressed(e);
}
throw adbcException;
} catch (AdbcException ex) {
try {
AutoCloseables.close(connectionAllocator);
} catch (Exception e) {
ex.addSuppressed(e);
}
throw ex;
} catch (Exception ex) {
AdbcException adbcException = FlightSqlDriverUtil.fromGeneralException(ex);
try {
Expand Down
Loading
Loading