Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,7 @@ public PostgresArtifactService(String dbUrl, String dbUser, String dbPassword) {
@Override
public Single<Integer> saveArtifact(
String appName, String userId, String sessionId, String filename, Part artifact) {
return Single.fromCallable(
() -> {
try {
// Extract data from Part
byte[] data = extractBytesFromPart(artifact);
String mimeType = extractMimeTypeFromPart(artifact);

// Save to database without metadata (metadata = null)
// Applications should use saveArtifact(..., metadata) if they need custom metadata
return dbHelper.saveArtifact(
appName, userId, sessionId, filename, data, mimeType, null);
} catch (SQLException e) {
throw new RuntimeException("Failed to save artifact: " + e.getMessage(), e);
}
})
.subscribeOn(Schedulers.io());
return saveArtifact(appName, userId, sessionId, filename, artifact, null, null);
}

/**
Expand Down Expand Up @@ -147,16 +132,51 @@ public Single<Integer> saveArtifact(
String filename,
Part artifact,
String metadata) {
return saveArtifact(appName, userId, sessionId, filename, artifact, metadata, null);
}

/**
* Save an artifact with custom metadata and invocation ID.
*
* <p>This overloaded method allows the caller to provide both metadata and the invocation ID that
* produced this artifact. The invocation ID links the artifact to the specific agent invocation
* for traceability, debugging, cost attribution, and rollback cleanup.
*
* <p>Example usage:
*
* <pre>{@code
* String metadata = "{\"projectId\":\"ABC\",\"cost\":0.005}";
* String invocationId = invocationContext.invocationId();
* artifactService.saveArtifact(appName, userId, sessionId, filename, part, metadata, invocationId);
* }</pre>
*
* @param appName the application name
* @param userId the user ID
* @param sessionId the session ID
* @param filename the artifact filename
* @param artifact the artifact as a Part object
* @param metadata custom metadata JSON string (can be null)
* @param invocationId the invocation ID that produced this artifact (can be null)
* @return a Single emitting the version number of the saved artifact
*/
public Single<Integer> saveArtifact(
String appName,
String userId,
String sessionId,
String filename,
Part artifact,
String metadata,
String invocationId) {
return Single.fromCallable(
() -> {
try {
// Extract data from Part
byte[] data = extractBytesFromPart(artifact);
String mimeType = extractMimeTypeFromPart(artifact);

// Save to database with caller-provided metadata
// Save to database with caller-provided metadata and invocation ID
return dbHelper.saveArtifact(
appName, userId, sessionId, filename, data, mimeType, metadata);
appName, userId, sessionId, filename, data, mimeType, metadata, invocationId);
} catch (SQLException e) {
throw new RuntimeException("Failed to save artifact: " + e.getMessage(), e);
}
Expand All @@ -167,13 +187,34 @@ public Single<Integer> saveArtifact(
@Override
public Maybe<Part> loadArtifact(
String appName, String userId, String sessionId, String filename, Optional<Integer> version) {
return loadArtifact(appName, userId, sessionId, filename, version, null);
}

/**
* Load an artifact by version or latest, optionally filtered by invocation ID.
*
* @param appName the application name
* @param userId the user ID
* @param sessionId the session ID
* @param filename the artifact filename
* @param version the version number, or empty for latest
* @param invocationId the invocation ID to filter by, or null for no filter
* @return a Maybe emitting the Part if found
*/
public Maybe<Part> loadArtifact(
String appName,
String userId,
String sessionId,
String filename,
Optional<Integer> version,
String invocationId) {
return Maybe.fromCallable(
() -> {
try {
// Load from database
// Load from database with optional invocation ID filter
ArtifactData artifactData =
dbHelper.loadArtifact(
appName, userId, sessionId, filename, version.orElse(null));
appName, userId, sessionId, filename, version.orElse(null), invocationId);

if (artifactData == null) {
return null;
Expand Down
140 changes: 106 additions & 34 deletions core/src/main/java/com/google/adk/store/PostgresArtifactStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public int saveArtifact(
byte[] data,
String mimeType)
throws SQLException {
return saveArtifact(appName, userId, sessionId, filename, data, mimeType, null);
return saveArtifact(appName, userId, sessionId, filename, data, mimeType, null, null);
}

/**
Expand All @@ -247,14 +247,46 @@ public int saveArtifact(
String mimeType,
String metadata)
throws SQLException {
return saveArtifact(appName, userId, sessionId, filename, data, mimeType, metadata, null);
}

/**
* Save artifact to database with metadata and invocation ID. Returns the assigned version number.
*
* <p>The invocation ID links the artifact to the specific agent invocation that produced it,
* enabling traceability, debugging, cost attribution, and cleanup of artifacts from
* failed/rolled-back invocations.
*
* @param appName the application name
* @param userId the user ID
* @param sessionId the session ID
* @param filename the artifact filename
* @param data the artifact binary data
* @param mimeType the MIME type
* @param metadata the metadata JSON string (can be null)
* @param invocationId the invocation ID that produced this artifact (can be null)
* @return the version number assigned to this artifact
* @throws SQLException if save operation fails
*/
public int saveArtifact(
String appName,
String userId,
String sessionId,
String filename,
byte[] data,
String mimeType,
String metadata,
String invocationId)
throws SQLException {
logger.debug(
"Saving artifact: app={}, user={}, session={}, file={}, size={}KB, mime={}",
"Saving artifact: app={}, user={}, session={}, file={}, size={}KB, mime={}, invocationId={}",
appName,
userId,
sessionId,
filename,
data.length / 1024,
mimeType);
mimeType,
invocationId);

Connection conn = null;
try {
Expand All @@ -267,8 +299,8 @@ public int saveArtifact(

String sql =
String.format(
"INSERT INTO %s (app_name, user_id, session_id, filename, version, mime_type, data, metadata) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?::jsonb)",
"INSERT INTO %s (app_name, user_id, session_id, filename, version, mime_type, data, metadata, invocation_id) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?::jsonb, ?)",
tableName);

try (PreparedStatement pstmt = conn.prepareStatement(sql)) {
Expand All @@ -280,6 +312,7 @@ public int saveArtifact(
pstmt.setString(6, mimeType);
pstmt.setBytes(7, data);
pstmt.setString(8, metadata);
pstmt.setString(9, invocationId);

int rowsAffected = pstmt.executeUpdate();

Expand All @@ -288,13 +321,14 @@ public int saveArtifact(
conn.commit();

logger.info(
"✅ Artifact saved: app={}, user={}, session={}, file={}, version={}, size={}KB",
"✅ Artifact saved: app={}, user={}, session={}, file={}, version={}, size={}KB, invocationId={}",
appName,
userId,
sessionId,
filename,
nextVersion,
data.length / 1024);
data.length / 1024,
invocationId);
return nextVersion;
} else {
conn.rollback();
Expand Down Expand Up @@ -408,40 +442,68 @@ private int getNextVersion(
public ArtifactData loadArtifact(
String appName, String userId, String sessionId, String filename, Integer version)
throws SQLException {
return loadArtifact(appName, userId, sessionId, filename, version, null);
}

/**
* Load artifact by version or latest, optionally filtered by invocation ID. Returns ArtifactData
* object or null if not found.
*
* @param appName the application name
* @param userId the user ID
* @param sessionId the session ID
* @param filename the artifact filename
* @param version the version number, or null for latest
* @param invocationId the invocation ID to filter by, or null for no filter
* @return ArtifactData object or null if not found
* @throws SQLException if load operation fails
*/
public ArtifactData loadArtifact(
String appName,
String userId,
String sessionId,
String filename,
Integer version,
String invocationId)
throws SQLException {
logger.debug(
"Loading artifact: app={}, user={}, session={}, file={}, version={}",
"Loading artifact: app={}, user={}, session={}, file={}, version={}, invocationId={}",
appName,
userId,
sessionId,
filename,
version != null ? version : "latest");
version != null ? version : "latest",
invocationId != null ? invocationId : "any");

StringBuilder sql = new StringBuilder();
sql.append("SELECT data, mime_type, version, created_at, metadata, invocation_id FROM ")
.append(tableName)
.append(" WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ?");

String sql;
if (version != null) {
// Load specific version
sql =
String.format(
"SELECT data, mime_type, version, created_at, metadata FROM %s "
+ "WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ? AND version = ?",
tableName);
} else {
// Load latest version
sql =
String.format(
"SELECT data, mime_type, version, created_at, metadata FROM %s "
+ "WHERE app_name = ? AND user_id = ? AND session_id = ? AND filename = ? "
+ "ORDER BY version DESC LIMIT 1",
tableName);
sql.append(" AND version = ?");
}
if (invocationId != null) {
sql.append(" AND invocation_id = ?");
}

if (version == null) {
sql.append(" ORDER BY version DESC LIMIT 1");
}

try (Connection conn = getConnection();
PreparedStatement pstmt = conn.prepareStatement(sql)) {
pstmt.setString(1, appName);
pstmt.setString(2, userId);
pstmt.setString(3, sessionId);
pstmt.setString(4, filename);
PreparedStatement pstmt = conn.prepareStatement(sql.toString())) {
int paramIdx = 1;
pstmt.setString(paramIdx++, appName);
pstmt.setString(paramIdx++, userId);
pstmt.setString(paramIdx++, sessionId);
pstmt.setString(paramIdx++, filename);

if (version != null) {
pstmt.setInt(5, version);
pstmt.setInt(paramIdx++, version);
}
if (invocationId != null) {
pstmt.setString(paramIdx++, invocationId);
}

try (ResultSet rs = pstmt.executeQuery()) {
Expand All @@ -451,17 +513,20 @@ public ArtifactData loadArtifact(
int loadedVersion = rs.getInt("version");
Timestamp createdAt = rs.getTimestamp("created_at");
String metadata = rs.getString("metadata");
String resultInvocationId = rs.getString("invocation_id");

logger.info(
"✅ Artifact loaded: app={}, user={}, session={}, file={}, version={}, size={}KB",
"✅ Artifact loaded: app={}, user={}, session={}, file={}, version={}, size={}KB, invocationId={}",
appName,
userId,
sessionId,
filename,
loadedVersion,
data.length / 1024);
data.length / 1024,
resultInvocationId);

return new ArtifactData(data, mimeType, loadedVersion, createdAt, metadata);
return new ArtifactData(
data, mimeType, loadedVersion, createdAt, metadata, resultInvocationId);
} else {
logger.warn(
"⚠️ Artifact not found: app={}, user={}, session={}, file={}, version={}",
Expand Down Expand Up @@ -668,14 +733,21 @@ public static class ArtifactData {
public final int version;
public final Timestamp createdAt;
public final String metadata;
public final String invocationId;

public ArtifactData(
byte[] data, String mimeType, int version, Timestamp createdAt, String metadata) {
byte[] data,
String mimeType,
int version,
Timestamp createdAt,
String metadata,
String invocationId) {
this.data = data;
this.mimeType = mimeType;
this.version = version;
this.createdAt = createdAt;
this.metadata = metadata;
this.invocationId = invocationId;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public void testLoadArtifact_LatestVersion() throws Exception {

ArtifactData artifactData =
new ArtifactData(
contentBytes, "text/plain", 1, new Timestamp(System.currentTimeMillis()), null);
contentBytes, "text/plain", 1, new Timestamp(System.currentTimeMillis()), null, null);

when(mockStore.loadArtifact(eq(appName), eq(userId), eq(sessionId), eq(filename), isNull()))
.thenReturn(artifactData);
Expand Down Expand Up @@ -200,7 +200,12 @@ public void testLoadArtifact_SpecificVersion() throws Exception {

ArtifactData artifactData =
new ArtifactData(
contentBytes, "text/plain", version, new Timestamp(System.currentTimeMillis()), null);
contentBytes,
"text/plain",
version,
new Timestamp(System.currentTimeMillis()),
null,
null);

when(mockStore.loadArtifact(eq(appName), eq(userId), eq(sessionId), eq(filename), eq(version)))
.thenReturn(artifactData);
Expand Down
Loading