diff --git a/core/src/main/java/com/google/adk/models/BedrockBaseLM.java b/core/src/main/java/com/google/adk/models/BedrockBaseLM.java index 9b90ccc45..f7f79aff2 100644 --- a/core/src/main/java/com/google/adk/models/BedrockBaseLM.java +++ b/core/src/main/java/com/google/adk/models/BedrockBaseLM.java @@ -35,6 +35,7 @@ import java.net.URL; import java.util.ArrayList; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -55,10 +56,51 @@ */ public class BedrockBaseLM extends BaseLlm { - // Use a constant for the environment variable name public static final String BEDROCK_ENV_VAR = "BEDROCK_URL"; public String D_URL = null; + /** + * Bearer token env vars (checked in order). BEDROCK_API_KEY (FMIS) preferred - works for both + * Converse + foundation-models in ap-south-1/ap-southeast-1. AWS_BEARER_TOKEN_BEDROCK is legacy. + */ + private static final String[] BEARER_TOKEN_ENV_VARS = { + "BEDROCK_API_KEY", // FMIS key - works for Converse + foundation-models + "AWS_BEARER_TOKEN_BEDROCK", // Legacy runtime token + "BEDROCK_BEARER_TOKEN", + "BEDROCK_TOKEN" + }; + + /** Returns Bearer token from env. Prefers BEDROCK_API_KEY (FMIS) when available. */ + public static String getBearerToken() { + for (String name : BEARER_TOKEN_ENV_VARS) { + String v = System.getenv(name); + if (v != null && !v.isBlank()) return v; + } + return null; + } + + /** Base URL for Converse API. Uses BEDROCK_URL, or BEDROCK_REGION, or default ap-south-1. */ + private static String getBedrockBaseUrl(String overrideUrl) { + if (overrideUrl != null && !overrideUrl.isBlank()) return overrideUrl; + String url = System.getenv(BEDROCK_ENV_VAR); + if (url != null && !url.isBlank()) return url; + String region = System.getenv("BEDROCK_REGION"); + if (region == null || region.isBlank()) region = "ap-south-1"; + return "https://bedrock-runtime." + region + ".amazonaws.com"; + } + + /** + * Returns the full Converse API URL. Handles BEDROCK_URL with or without trailing /model to avoid + * double /model/ in path. + */ + private static String buildConverseUrl(String baseUrl, String model) { + String base = baseUrl == null ? "" : baseUrl.trim().replaceAll("/+$", ""); + if (base.endsWith("/model")) { + return base + "/" + model + "/converse"; + } + return base + "/model/" + model + "/converse"; + } + // Corrected the logger name to use OllamaBaseLM.class private static final Logger logger = LoggerFactory.getLogger(BedrockBaseLM.class); @@ -304,7 +346,30 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre logger.debug("Usage metadata parsing failed (non-critical)", e); } - JSONObject responseQuantum = agentresponse.getJSONObject("output").getJSONObject("message"); + // Bedrock Converse API: output.message, or message/Message at top level. + // Message (capital M) can be a String in error responses (e.g. "Authentication failed"). + JSONObject responseQuantum = extractMessageObject(agentresponse); + if (responseQuantum == null) { + String detail = "Response keys: " + agentresponse.keySet(); + if (agentresponse.has("Output")) { + try { + JSONObject out = agentresponse.getJSONObject("Output"); + detail += ", Output keys: " + out.keySet(); + } catch (Exception e) { + detail += ", Output: (not JSONObject)"; + } + } else if (agentresponse.has("output")) { + try { + JSONObject out = agentresponse.getJSONObject("output"); + detail += ", output keys: " + out.keySet(); + } catch (Exception e) { + detail += ", output: (not JSONObject)"; + } + } + throw new IllegalStateException( + "Unexpected Bedrock response: missing output/Output.message, message, or Message. " + + detail); + } // Check if tool call is required // Tools call @@ -335,6 +400,59 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre return Flowable.just(responseBuilder.build()); } + /** Gets a value from JSONObject using case-insensitive key match. */ + private static Object getKeyIgnoreCase(JSONObject obj, String... keys) { + Iterator it = obj.keys(); + while (it.hasNext()) { + String k = it.next(); + for (String target : keys) { + if (k.equalsIgnoreCase(target)) return obj.get(k); + } + } + return null; + } + + /** + * Extracts the message content object from a Bedrock response. Handles both success (message + * object) and error (Message as String) responses. Supports output/Output (AWS PascalCase) and + * case-insensitive keys for FMIS/global endpoints. + */ + private static JSONObject extractMessageObject(JSONObject response) { + Object msg = null; + String errMsg = null; + JSONObject outputObj = + response.has("output") + ? response.getJSONObject("output") + : response.has("Output") ? response.optJSONObject("Output") : null; + if (outputObj != null) { + msg = getKeyIgnoreCase(outputObj, "message", "Message"); + if (msg == null + && outputObj.has("choices") + && outputObj.getJSONArray("choices").length() > 0) { + JSONObject first = outputObj.getJSONArray("choices").getJSONObject(0); + msg = getKeyIgnoreCase(first, "message", "Message"); + } + if (msg == null && outputObj.has("content")) { + msg = outputObj; + } + } + if (msg == null) { + msg = getKeyIgnoreCase(response, "message", "Message"); + } + if (msg instanceof String) { + errMsg = (String) msg; + } else if (msg instanceof JSONObject) { + return (JSONObject) msg; + } + if (errMsg != null) { + throw new IllegalStateException("Bedrock API error: " + errMsg); + } + if (outputObj != null) { + logger.debug("Bedrock Output keys (extraction failed): {}", outputObj.keySet()); + } + return null; + } + public Flowable generateContentStream(LlmRequest llmRequest) { List contents = llmRequest.contents(); // Last content must be from the user, otherwise the model won't respond. @@ -578,13 +696,31 @@ private Flowable createRobustStreamingResponse( } JSONObject message = null; - if (responseJson.has("output")) { - JSONObject output = responseJson.getJSONObject("output"); - if (output.has("message")) { - message = output.getJSONObject("message"); + Object msgVal = null; + JSONObject outputObj = + responseJson.has("output") + ? responseJson.optJSONObject("output") + : responseJson.has("Output") ? responseJson.optJSONObject("Output") : null; + if (outputObj != null) { + msgVal = getKeyIgnoreCase(outputObj, "message", "Message"); + if (msgVal == null + && outputObj.has("choices") + && outputObj.getJSONArray("choices").length() > 0) { + JSONObject first = outputObj.getJSONArray("choices").getJSONObject(0); + msgVal = getKeyIgnoreCase(first, "message", "Message"); + } + if (msgVal == null && outputObj.has("content")) { + msgVal = outputObj; } - } else if (responseJson.has("message")) { - message = responseJson.getJSONObject("message"); + } + if (msgVal == null) { + msgVal = getKeyIgnoreCase(responseJson, "message", "Message"); + } + if (msgVal instanceof String) { + emitter.onError(new IllegalStateException("Bedrock API error: " + msgVal)); + return reader; + } else if (msgVal instanceof JSONObject) { + message = (JSONObject) msgVal; } // Accumulate all text from this response chunk @@ -757,9 +893,14 @@ private LlmResponse createTextResponse(String text, boolean partial) { public BufferedReader callLLMChatStream(String model, JSONArray messages, JSONArray tools) { try { - String apiUrl = - (D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR)) + "/" + model + "/converse"; - String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK"); + String bearerToken = getBearerToken(); + if (bearerToken == null || bearerToken.isBlank()) { + throw new IllegalStateException( + "Bedrock Bearer token not found. Set one of: AWS_BEARER_TOKEN_BEDROCK, " + + "BEDROCK_BEARER_TOKEN, BEDROCK_API_KEY, BEDROCK_TOKEN (e.g. in .bashrc)"); + } + String baseUrl = getBedrockBaseUrl(D_URL); + String apiUrl = buildConverseUrl(baseUrl, model); System.out.println("Using Bedrock URL: " + apiUrl); JSONObject payload = new JSONObject(); // Model already encoded in path; omit 'model' field to avoid Unexpected field type errors @@ -775,7 +916,7 @@ public BufferedReader callLLMChatStream(String model, JSONArray messages, JSONAr HttpURLConnection connection = (HttpURLConnection) url.openConnection(); connection.setRequestMethod("POST"); connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); - connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK); + connection.setRequestProperty("Authorization", "Bearer " + bearerToken); connection.setDoOutput(true); connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); @@ -885,11 +1026,17 @@ public static Part ollamaContentBlockToPart(JSONObject blockJson) { */ public JSONObject callLLMChat(String model, JSONArray messages, JSONArray tools) { try { + String bearerToken = getBearerToken(); + if (bearerToken == null || bearerToken.isBlank()) { + throw new IllegalStateException( + "Bedrock Bearer token not found. Set one of: AWS_BEARER_TOKEN_BEDROCK, " + + "BEDROCK_BEARER_TOKEN, BEDROCK_API_KEY, BEDROCK_TOKEN (e.g. in .bashrc)"); + } + String baseUrl = getBedrockBaseUrl(D_URL); JSONObject responseJ = new JSONObject(); - String apiUrl = D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR); - String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK"); + String apiUrl = buildConverseUrl(baseUrl, model); JSONObject payload = new JSONObject(); - payload.put("model", model); + // Model already in path; omit from payload to avoid "Unexpected field type" errors payload.put("stream", false); payload.put("messages", messages); if (tools != null) { @@ -901,7 +1048,7 @@ public JSONObject callLLMChat(String model, JSONArray messages, JSONArray tools) HttpURLConnection connection = (HttpURLConnection) url.openConnection(); connection.setRequestMethod("POST"); connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); - connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK); + connection.setRequestProperty("Authorization", "Bearer " + bearerToken); connection.setDoOutput(true); connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); @@ -958,9 +1105,12 @@ public static JSONObject callLLMChat( boolean stream, String prompt, String model, JSONArray messages, JSONArray tools) { JSONObject responseJ = new JSONObject(); try { - String apiUrl = System.getenv(BEDROCK_ENV_VAR); - String AWS_BEARER_TOKEN_BEDROCK = System.getenv(BEDROCK_ENV_VAR); - apiUrl = apiUrl + "/api/chat"; + String bearerToken = getBearerToken(); + if (bearerToken == null || bearerToken.isBlank()) { + throw new IllegalStateException( + "Bedrock Bearer token not found. Set AWS_BEARER_TOKEN_BEDROCK, BEDROCK_BEARER_TOKEN, etc."); + } + String apiUrl = getBedrockBaseUrl(null) + "/api/chat"; JSONObject payload = new JSONObject(); payload.put("model", model); payload.put("stream", false); @@ -978,7 +1128,7 @@ public static JSONObject callLLMChat( // System.out.print("HTTP Connection to Ollama API: " + apiUrl.toString()); connection.setRequestMethod("POST"); connection.setRequestProperty("Content-Type", "application/json"); - connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK); + connection.setRequestProperty("Authorization", "Bearer " + bearerToken); connection.setDoOutput(true); connection.setFixedLengthStreamingMode(jsonString.getBytes().length); try (DataOutputStream outputStream = new DataOutputStream(connection.getOutputStream())) { @@ -1070,8 +1220,14 @@ public Flowable generateContent( return Flowable.create( emitter -> { try { - String apiUrl = D_URL != null ? D_URL : System.getenv(BEDROCK_ENV_VAR); - String AWS_BEARER_TOKEN_BEDROCK = System.getenv("AWS_BEARER_TOKEN_BEDROCK"); + String bearerToken = getBearerToken(); + if (bearerToken == null || bearerToken.isBlank()) { + emitter.onError( + new IllegalStateException( + "Bedrock Bearer token not found. Set AWS_BEARER_TOKEN_BEDROCK, BEDROCK_BEARER_TOKEN, etc.")); + return; + } + String apiUrl = getBedrockBaseUrl(D_URL); JSONObject payload = new JSONObject(); payload.put("messages", messages); if (tools != null) { @@ -1082,7 +1238,7 @@ public Flowable generateContent( HttpURLConnection connection = (HttpURLConnection) url.openConnection(); connection.setRequestMethod("POST"); connection.setRequestProperty("Content-Type", "application/json; charset=UTF-8"); - connection.setRequestProperty("Authorization", "Bearer " + AWS_BEARER_TOKEN_BEDROCK); + connection.setRequestProperty("Authorization", "Bearer " + bearerToken); connection.setDoOutput(true); connection.setFixedLengthStreamingMode(jsonString.getBytes("UTF-8").length); try (OutputStream outputStream = connection.getOutputStream();