diff --git a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java index a47e13fb1..fa34b64ea 100644 --- a/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/PostgresArtifactService.java @@ -100,22 +100,7 @@ public PostgresArtifactService(String dbUrl, String dbUser, String dbPassword) { @Override public Single saveArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { - return Single.fromCallable( - () -> { - try { - // Extract data from Part - byte[] data = extractBytesFromPart(artifact); - String mimeType = extractMimeTypeFromPart(artifact); - - // Save to database without metadata (metadata = null) - // Applications should use saveArtifact(..., metadata) if they need custom metadata - return dbHelper.saveArtifact( - appName, userId, sessionId, filename, data, mimeType, null); - } catch (SQLException e) { - throw new RuntimeException("Failed to save artifact: " + e.getMessage(), e); - } - }) - .subscribeOn(Schedulers.io()); + return saveArtifact(appName, userId, sessionId, filename, artifact, null, null); } /** @@ -147,6 +132,41 @@ public Single saveArtifact( String filename, Part artifact, String metadata) { + return saveArtifact(appName, userId, sessionId, filename, artifact, metadata, null); + } + + /** + * Save an artifact with custom metadata and invocation ID. + * + *

This overloaded method allows the caller to provide both metadata and the invocation ID that + * produced this artifact. The invocation ID links the artifact to the specific agent invocation + * for traceability, debugging, cost attribution, and rollback cleanup. + * + *

Example usage: + * + *

{@code
+   * String metadata = "{\"projectId\":\"ABC\",\"cost\":0.005}";
+   * String invocationId = invocationContext.invocationId();
+   * artifactService.saveArtifact(appName, userId, sessionId, filename, part, metadata, invocationId);
+   * }
+ * + * @param appName the application name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the artifact filename + * @param artifact the artifact as a Part object + * @param metadata custom metadata JSON string (can be null) + * @param invocationId the invocation ID that produced this artifact (can be null) + * @return a Single emitting the version number of the saved artifact + */ + public Single saveArtifact( + String appName, + String userId, + String sessionId, + String filename, + Part artifact, + String metadata, + String invocationId) { return Single.fromCallable( () -> { try { @@ -154,9 +174,9 @@ public Single saveArtifact( byte[] data = extractBytesFromPart(artifact); String mimeType = extractMimeTypeFromPart(artifact); - // Save to database with caller-provided metadata + // Save to database with caller-provided metadata and invocation ID return dbHelper.saveArtifact( - appName, userId, sessionId, filename, data, mimeType, metadata); + appName, userId, sessionId, filename, data, mimeType, metadata, invocationId); } catch (SQLException e) { throw new RuntimeException("Failed to save artifact: " + e.getMessage(), e); } @@ -167,13 +187,34 @@ public Single saveArtifact( @Override public Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version) { + return loadArtifact(appName, userId, sessionId, filename, version, null); + } + + /** + * Load an artifact by version or latest, optionally filtered by invocation ID. + * + * @param appName the application name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the artifact filename + * @param version the version number, or empty for latest + * @param invocationId the invocation ID to filter by, or null for no filter + * @return a Maybe emitting the Part if found + */ + public Maybe loadArtifact( + String appName, + String userId, + String sessionId, + String filename, + Optional version, + String invocationId) { return Maybe.fromCallable( () -> { try { - // Load from database + // Load from database with optional invocation ID filter ArtifactData artifactData = dbHelper.loadArtifact( - appName, userId, sessionId, filename, version.orElse(null)); + appName, userId, sessionId, filename, version.orElse(null), invocationId); if (artifactData == null) { return null; diff --git a/core/src/main/java/com/google/adk/store/PostgresArtifactStore.java b/core/src/main/java/com/google/adk/store/PostgresArtifactStore.java index d31be8676..629aa5d82 100644 --- a/core/src/main/java/com/google/adk/store/PostgresArtifactStore.java +++ b/core/src/main/java/com/google/adk/store/PostgresArtifactStore.java @@ -222,7 +222,7 @@ public int saveArtifact( byte[] data, String mimeType) throws SQLException { - return saveArtifact(appName, userId, sessionId, filename, data, mimeType, null); + return saveArtifact(appName, userId, sessionId, filename, data, mimeType, null, null); } /** @@ -247,14 +247,46 @@ public int saveArtifact( String mimeType, String metadata) throws SQLException { + return saveArtifact(appName, userId, sessionId, filename, data, mimeType, metadata, null); + } + + /** + * Save artifact to database with metadata and invocation ID. Returns the assigned version number. + * + *

The invocation ID links the artifact to the specific agent invocation that produced it, + * enabling traceability, debugging, cost attribution, and cleanup of artifacts from + * failed/rolled-back invocations. + * + * @param appName the application name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the artifact filename + * @param data the artifact binary data + * @param mimeType the MIME type + * @param metadata the metadata JSON string (can be null) + * @param invocationId the invocation ID that produced this artifact (can be null) + * @return the version number assigned to this artifact + * @throws SQLException if save operation fails + */ + public int saveArtifact( + String appName, + String userId, + String sessionId, + String filename, + byte[] data, + String mimeType, + String metadata, + String invocationId) + throws SQLException { logger.debug( - "Saving artifact: app={}, user={}, session={}, file={}, size={}KB, mime={}", + "Saving artifact: app={}, user={}, session={}, file={}, size={}KB, mime={}, invocationId={}", appName, userId, sessionId, filename, data.length / 1024, - mimeType); + mimeType, + invocationId); Connection conn = null; try { @@ -267,8 +299,8 @@ public int saveArtifact( String sql = String.format( - "INSERT INTO %s (app_name, user_id, session_id, filename, version, mime_type, data, metadata) " - + "VALUES (?, ?, ?, ?, ?, ?, ?, ?::jsonb)", + "INSERT INTO %s (app_name, user_id, session_id, filename, version, mime_type, data, metadata, invocation_id) " + + "VALUES (?, ?, ?, ?, ?, ?, ?, ?::jsonb, ?)", tableName); try (PreparedStatement pstmt = conn.prepareStatement(sql)) { @@ -280,6 +312,7 @@ public int saveArtifact( pstmt.setString(6, mimeType); pstmt.setBytes(7, data); pstmt.setString(8, metadata); + pstmt.setString(9, invocationId); int rowsAffected = pstmt.executeUpdate(); @@ -288,13 +321,14 @@ public int saveArtifact( conn.commit(); logger.info( - "✅ Artifact saved: app={}, user={}, session={}, file={}, version={}, size={}KB", + "✅ Artifact saved: app={}, user={}, session={}, file={}, version={}, size={}KB, invocationId={}", appName, userId, sessionId, filename, nextVersion, - data.length / 1024); + data.length / 1024, + invocationId); return nextVersion; } else { conn.rollback(); @@ -408,40 +442,68 @@ private int getNextVersion( public ArtifactData loadArtifact( String appName, String userId, String sessionId, String filename, Integer version) throws SQLException { + return loadArtifact(appName, userId, sessionId, filename, version, null); + } + + /** + * Load artifact by version or latest, optionally filtered by invocation ID. Returns ArtifactData + * object or null if not found. + * + * @param appName the application name + * @param userId the user ID + * @param sessionId the session ID + * @param filename the artifact filename + * @param version the version number, or null for latest + * @param invocationId the invocation ID to filter by, or null for no filter + * @return ArtifactData object or null if not found + * @throws SQLException if load operation fails + */ + public ArtifactData loadArtifact( + String appName, + String userId, + String sessionId, + String filename, + Integer version, + String invocationId) + throws SQLException { logger.debug( - "Loading artifact: app={}, user={}, session={}, file={}, version={}", + "Loading artifact: app={}, user={}, session={}, file={}, version={}, invocationId={}", appName, userId, sessionId, filename, - version != null ? version : "latest"); + version != null ? version : "latest", + invocationId != null ? invocationId : "any"); + + StringBuilder sql = new StringBuilder(); + sql.append("SELECT data, mime_type, version, created_at, metadata, invocation_id FROM ") + .append(tableName) + .append(" WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ?"); - String sql; if (version != null) { - // Load specific version - sql = - String.format( - "SELECT data, mime_type, version, created_at, metadata FROM %s " - + "WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ? AND version = ?", - tableName); - } else { - // Load latest version - sql = - String.format( - "SELECT data, mime_type, version, created_at, metadata FROM %s " - + "WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ? " - + "ORDER BY version DESC LIMIT 1", - tableName); + sql.append(" AND version = ?"); + } + if (invocationId != null) { + sql.append(" AND invocation_id = ?"); + } + + if (version == null) { + sql.append(" ORDER BY version DESC LIMIT 1"); } try (Connection conn = getConnection(); - PreparedStatement pstmt = conn.prepareStatement(sql)) { - pstmt.setString(1, appName); - pstmt.setString(2, userId); - pstmt.setString(3, sessionId); - pstmt.setString(4, filename); + PreparedStatement pstmt = conn.prepareStatement(sql.toString())) { + int paramIdx = 1; + pstmt.setString(paramIdx++, appName); + pstmt.setString(paramIdx++, userId); + pstmt.setString(paramIdx++, sessionId); + pstmt.setString(paramIdx++, filename); + if (version != null) { - pstmt.setInt(5, version); + pstmt.setInt(paramIdx++, version); + } + if (invocationId != null) { + pstmt.setString(paramIdx++, invocationId); } try (ResultSet rs = pstmt.executeQuery()) { @@ -451,17 +513,20 @@ public ArtifactData loadArtifact( int loadedVersion = rs.getInt("version"); Timestamp createdAt = rs.getTimestamp("created_at"); String metadata = rs.getString("metadata"); + String resultInvocationId = rs.getString("invocation_id"); logger.info( - "✅ Artifact loaded: app={}, user={}, session={}, file={}, version={}, size={}KB", + "✅ Artifact loaded: app={}, user={}, session={}, file={}, version={}, size={}KB, invocationId={}", appName, userId, sessionId, filename, loadedVersion, - data.length / 1024); + data.length / 1024, + resultInvocationId); - return new ArtifactData(data, mimeType, loadedVersion, createdAt, metadata); + return new ArtifactData( + data, mimeType, loadedVersion, createdAt, metadata, resultInvocationId); } else { logger.warn( "⚠️ Artifact not found: app={}, user={}, session={}, file={}, version={}", @@ -668,14 +733,21 @@ public static class ArtifactData { public final int version; public final Timestamp createdAt; public final String metadata; + public final String invocationId; public ArtifactData( - byte[] data, String mimeType, int version, Timestamp createdAt, String metadata) { + byte[] data, + String mimeType, + int version, + Timestamp createdAt, + String metadata, + String invocationId) { this.data = data; this.mimeType = mimeType; this.version = version; this.createdAt = createdAt; this.metadata = metadata; + this.invocationId = invocationId; } } } diff --git a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java index 509f72d9f..870f984bd 100644 --- a/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java +++ b/core/src/test/java/com/google/adk/artifacts/PostgresArtifactServiceTest.java @@ -169,7 +169,7 @@ public void testLoadArtifact_LatestVersion() throws Exception { ArtifactData artifactData = new ArtifactData( - contentBytes, "text/plain", 1, new Timestamp(System.currentTimeMillis()), null); + contentBytes, "text/plain", 1, new Timestamp(System.currentTimeMillis()), null, null); when(mockStore.loadArtifact(eq(appName), eq(userId), eq(sessionId), eq(filename), isNull())) .thenReturn(artifactData); @@ -200,7 +200,12 @@ public void testLoadArtifact_SpecificVersion() throws Exception { ArtifactData artifactData = new ArtifactData( - contentBytes, "text/plain", version, new Timestamp(System.currentTimeMillis()), null); + contentBytes, + "text/plain", + version, + new Timestamp(System.currentTimeMillis()), + null, + null); when(mockStore.loadArtifact(eq(appName), eq(userId), eq(sessionId), eq(filename), eq(version))) .thenReturn(artifactData);