11 changes: 11 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# AI Agents Style Guide

When modifying Java code in this repository, please adhere to the following guidelines:

## Code Style

- **Avoid Fully Qualified Names (FQNs)**: Do not use fully qualified class names directly in code (e.g., `java.util.Base64`). Always use standard `import` statements at the top of the file to import the necessary classes and refer to them by their simple names.
- **Dependency Injection**: Favor Dependency Injection (Dagger) over static singletons or utility classes. Create injectable services with `@Singleton` and `@Inject` constructors where applicable.
- **Constants**: Extract reused string literals and configuration keys into `public static final String` constants rather than hardcoding them multiple times.
- **Validation**: Validate settings and inputs, throwing standard exception types (e.g., `IllegalArgumentException`) for malformed inputs such as GCP resource names.
- **Documentation**: Provide examples for complex configuration values or properties (like GCP KMS keys or URIs) inside JavaDocs.
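
As a minimal sketch of the Constants, Validation, and Documentation guidelines taken together, a KMS key name check might look like the following. The `KmsKeyNames` helper is hypothetical and not part of this repository; unlike the `EncryptionKey` setting in this change, it also rejects `null`.

```java
import java.util.regex.Pattern;

/** Hypothetical helper, for illustration only; not part of this repository. */
final class KmsKeyNames {

  /** Human-readable shape of a Cloud KMS CryptoKey resource name, reused in error messages. */
  static final String KEY_NAME_FORMAT =
      "projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}";

  // Compiled once and reused, rather than re-parsing the regex on every call.
  private static final Pattern KEY_NAME_PATTERN = Pattern.compile(
      "projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+");

  private KmsKeyNames() {}

  /**
   * Returns {@code keyName} unchanged if it is well-formed.
   *
   * <p>Example: {@code projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key}
   *
   * @throws IllegalArgumentException if the name is null or malformed
   */
  static String requireValid(String keyName) {
    if (keyName == null || !KEY_NAME_PATTERN.matcher(keyName).matches()) {
      throw new IllegalArgumentException("KMS key name must match the format: " + KEY_NAME_FORMAT);
    }
    return keyName;
  }
}
```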
48 changes: 48 additions & 0 deletions docs/cmek-encryption.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Customer-Managed Encryption Keys (CMEK) Support

## Overview
This design document describes the support for encrypting potentially sensitive job parameters using Customer-Managed Encryption Keys (CMEK) via Google Cloud KMS within the pipeline framework.

The primary use case is to ensure that sensitive job parameters are encrypted with customer-specific keys, even when the pipelines are running in shared task queues or when arbitrary content for the task is stored in the datastore records for the job/pipeline.

## Design

### Job Settings

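A new job setting, `EncryptionKey`, has been added. It is a `StringValuedSetting` that holds the full Google Cloud KMS key name, in the form `projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}`. The constructor throws an `IllegalArgumentException` if a non-null value does not match this format.
When creating a pipeline or a job, developers can pass the `EncryptionKey` setting: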

```java
JobSetting[] settings = new JobSetting[] {
  new JobSetting.EncryptionKey("projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key")
};
```

This setting is stored within the JobRecord's `QueueSettings` and passed along to any spawned pipeline tasks.
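
How the key reaches spawned tasks can be sketched with a stripped-down stand-in for `QueueSettings`. The `QueueSettingsSketch` class below is hypothetical and models only two fields; it assumes, based on the `merge` change in this PR, that a `null` key is filled in from the settings being merged.

```java
/** Hypothetical, stripped-down stand-in for QueueSettings; illustration only. */
final class QueueSettingsSketch {
  String onQueue;
  String encryptionKey;

  QueueSettingsSketch(String onQueue, String encryptionKey) {
    this.onQueue = onQueue;
    this.encryptionKey = encryptionKey;
  }

  /** Fills every null field from {@code other}; values that are already set are kept. */
  QueueSettingsSketch merge(QueueSettingsSketch other) {
    if (onQueue == null) {
      onQueue = other.onQueue;
    }
    if (encryptionKey == null) {
      encryptionKey = other.encryptionKey;
    }
    return this;
  }
}
```

Under this assumption, a child with no key of its own picks up the parent's key, while a child that sets its own key keeps it.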

### Encryption at Enqueue Time

When a task is converted for enqueueing (via `PipelineTask.toTaskSpec()`), the framework checks whether an `EncryptionKey` is present in the task's `QueueSettings`.

If present, the framework:
1. Translates the task's properties into a JSON object.
2. Encrypts the UTF-8 bytes of that JSON with the configured Cloud KMS key, using the injectable `KmsService`.
3. Base64-encodes the ciphertext.
4. Sends the Base64 string as a single `_encrypted_payload` parameter in the `POST` request payload, in place of the individual task parameters.
5. Adds an HTTP header, `X-Pipeline-EncryptionKey`, carrying the KMS key name so the receiver knows which key to use for decryption.
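
The steps above can be sketched without any KMS dependency. In this sketch, `EncryptedEnqueueSketch`, `Encryptor`, and `Envelope` are hypothetical names, the `Encryptor` interface stands in for the real KMS call, and a tiny hand-written JSON writer replaces the `org.json` `JSONObject` that the PR uses, so that the example depends only on the JDK.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical sketch of the enqueue-time steps; not the framework's actual code. */
final class EncryptedEnqueueSketch {

  static final String ENCRYPTION_KEY_HEADER = "X-Pipeline-EncryptionKey";
  static final String ENCRYPTED_PAYLOAD_PARAM = "_encrypted_payload";

  /** Stand-in for the KMS call; a real implementation would delegate to Cloud KMS. */
  interface Encryptor {
    byte[] encrypt(String keyName, byte[] plaintext);
  }

  /** Minimal holder for what would become the task's headers and POST parameters. */
  static final class Envelope {
    final Map<String, String> headers = new LinkedHashMap<>();
    final Map<String, String> params = new LinkedHashMap<>();
  }

  static Envelope seal(Map<String, String> taskParams, String keyName, Encryptor encryptor) {
    // Step 1: flat string map -> JSON object text.
    String json = toJson(taskParams);
    // Step 2: encrypt the UTF-8 bytes under the named key.
    byte[] ciphertext = encryptor.encrypt(keyName, json.getBytes(StandardCharsets.UTF_8));
    Envelope envelope = new Envelope();
    // Steps 3 and 4: Base64-encode and send as the only POST parameter.
    envelope.params.put(ENCRYPTED_PAYLOAD_PARAM, Base64.getEncoder().encodeToString(ciphertext));
    // Step 5: tell the receiver which key to decrypt with.
    envelope.headers.put(ENCRYPTION_KEY_HEADER, keyName);
    return envelope;
  }

  /** Tiny JSON writer for flat string maps (assumption: keys and values are non-null). */
  static String toJson(Map<String, String> map) {
    StringBuilder sb = new StringBuilder("{");
    for (Map.Entry<String, String> e : map.entrySet()) {
      if (sb.length() > 1) {
        sb.append(',');
      }
      sb.append(quote(e.getKey())).append(':').append(quote(e.getValue()));
    }
    return sb.append('}').toString();
  }

  private static String quote(String s) {
    StringBuilder sb = new StringBuilder("\"");
    for (char c : s.toCharArray()) {
      if (c == '"' || c == '\\') {
        sb.append('\\').append(c);
      } else if (c < 0x20) {
        sb.append(String.format("\\u%04x", (int) c));
      } else {
        sb.append(c);
      }
    }
    return sb.append('"').toString();
  }
}
```

Note that the key name travels in a plain header; only the task parameters themselves are encrypted.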

### Decryption at Task Execution Time

When `TaskHandler` receives the task:
1. It checks for the `X-Pipeline-EncryptionKey` header.
2. If both the header and the `_encrypted_payload` parameter are present, it Base64-decodes the payload and decrypts it with the specified key via `KmsService`. Otherwise, it reads the request parameters as plain task parameters.
3. It parses the decrypted JSON back into standard task parameters, and pipeline execution proceeds normally, with no changes to internal logic.
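
The receive-side branch can be sketched in the same spirit. `EncryptedReceiveSketch` and `Decryptor` are hypothetical names, the `Decryptor` interface stands in for the KMS call, and JSON parsing is left out; the method simply returns the decrypted JSON text, or an empty `Optional` to signal that the caller should fall back to the plain request parameters.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;

/** Hypothetical sketch of the receive-side branch; not the framework's actual code. */
final class EncryptedReceiveSketch {

  /** Stand-in for the KMS call; a real implementation would delegate to Cloud KMS. */
  interface Decryptor {
    byte[] decrypt(String keyName, byte[] ciphertext);
  }

  /**
   * Returns the decrypted JSON text when both the key header and the encrypted parameter
   * are present; returns empty otherwise, so the caller reads plain parameters instead.
   */
  static Optional<String> decryptPayload(
      String keyHeader, String encryptedParam, Decryptor decryptor) {
    if (keyHeader == null || encryptedParam == null) {
      return Optional.empty();
    }
    // Base64.Decoder.decode throws IllegalArgumentException on malformed input.
    byte[] ciphertext = Base64.getDecoder().decode(encryptedParam);
    byte[] plaintext = decryptor.decrypt(keyHeader, ciphertext);
    return Optional.of(new String(plaintext, StandardCharsets.UTF_8));
  }
}
```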

## Datastore Consideration

The current CMEK support encrypts only the task payload (POST parameters) sent to Cloud Tasks or App Engine Task Queues, so any task parameters placed in the task queues are encrypted under the provided CMEK.
To encrypt parameters saved to the Datastore, developers can implement an `EncryptionSerializationStrategy` or explicitly encrypt their `Value`s before passing them into the pipeline, so that data at rest in the Datastore is also protected by CMEK.

## Dependencies

The implementation depends on the `google-cloud-kms` library for KMS operations.
4 changes: 4 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-storage</artifactId>
</dependency>
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-kms</artifactId>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,22 @@ public DatastoreNamespace(String datastoreNameSpace) {
}
}

/**
* A setting specifying a CMEK for encrypting task parameters.
* Example: "projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key"
*/
final class EncryptionKey extends StringValuedSetting {
@Serial
private static final long serialVersionUID = -2L;

public EncryptionKey(String encryptionKey) {
super(encryptionKey);
if (encryptionKey != null && !encryptionKey.matches("projects/[^/]+/locations/[^/]+/keyRings/[^/]+/cryptoKeys/[^/]+")) {
throw new IllegalArgumentException("EncryptionKey must match the format: projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}");
}
}
}


static <E extends StringValuedSetting> Optional<String> getSettingValue(Class<E> clazz, JobSetting[] settings) {
return Arrays.stream(settings)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ public final class QueueSettings implements Cloneable {
*/
private Long delayInSeconds;

/**
* KMS Key Name for CMEK encryption of task payload
* Example: "projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key"
*/
private String encryptionKey;

/**
* Merge will override any {@code null} setting with a matching setting from {@code other}.
* Note, delay value is not being merged.
Expand All @@ -49,6 +55,9 @@ public QueueSettings merge(QueueSettings other) {
if (onQueue == null) {
onQueue = other.getOnQueue();
}
if (encryptionKey == null) {
encryptionKey = other.getEncryptionKey();
}
return this;
}

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.appengine.tools.pipeline.impl.QueueSettings;
import com.google.appengine.tools.pipeline.impl.servlets.TaskHandler;
import com.google.appengine.tools.pipeline.impl.tasks.PipelineTask;
import com.google.appengine.tools.pipeline.impl.util.KmsService;
import com.google.apphosting.api.ApiProxy;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
Expand Down Expand Up @@ -59,19 +60,22 @@ public class AppEngineTaskQueue implements PipelineTaskQueue {

final AppEngineEnvironment environment;
final AppEngineServicesService servicesService;
final KmsService kmsService;

final String taskHandlerUrl;

public AppEngineTaskQueue(AppEngineServicesService appEngineServicesService) {
this.environment = new AppEngineStandardGen2();
this.servicesService = appEngineServicesService;
this.kmsService = null;
this.taskHandlerUrl = TaskHandler.handleTaskUrl();
}

@Inject
public AppEngineTaskQueue(AppEngineEnvironment environment, AppEngineServicesService servicesService) {
public AppEngineTaskQueue(AppEngineEnvironment environment, AppEngineServicesService servicesService, KmsService kmsService) {
this.environment = environment;
this.servicesService = servicesService;
this.kmsService = kmsService;
this.taskHandlerUrl = TaskHandler.handleTaskUrl();
}

Expand Down Expand Up @@ -159,7 +163,7 @@ public Collection<TaskReference> enqueue(final Collection<PipelineTask> pipeline
public Multimap<String, TaskSpec> asTaskSpecs(Collection<PipelineTask> pipelineTasks) {
Multimap<String, TaskSpec> taskSpecs = HashMultimap.create();
pipelineTasks.forEach( pipelineTask -> {
taskSpecs.put(getQueueForTask(pipelineTask), pipelineTask.toTaskSpec(servicesService, taskHandlerUrl));
taskSpecs.put(getQueueForTask(pipelineTask), pipelineTask.toTaskSpec(servicesService, taskHandlerUrl, kmsService));
});
return taskSpecs;
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.ByteString;
import com.google.protobuf.Timestamp;
import com.google.appengine.tools.pipeline.impl.util.KmsService;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
Expand Down Expand Up @@ -57,6 +58,9 @@ public String getPropertyName() {
@NonNull
AppEngineServicesService appEngineServicesService;

@NonNull
KmsService kmsService;

// GAE location -> Cloud Tasks location name
Cache<String, String> locationCache =
CacheBuilder.newBuilder().initialCapacity(1).build();
Expand All @@ -76,7 +80,7 @@ public Collection<TaskReference> enqueue(Collection<PipelineTask> pipelineTasks)
.map(tasksForQueue -> {
Stream<TaskSpec> specs = tasksForQueue.getValue().stream()
.map(pipelineTask -> {
return pipelineTask.toTaskSpec(appEngineServicesService, TaskHandler.handleTaskUrl());
return pipelineTask.toTaskSpec(appEngineServicesService, TaskHandler.handleTaskUrl(), kmsService);
});
return enqueue(tasksForQueue.getKey(), specs.collect(Collectors.toList()));
})
Expand All @@ -89,7 +93,7 @@ public Multimap<String, TaskSpec> asTaskSpecs(Collection<PipelineTask> pipelineT
Multimap<String, TaskSpec> taskSpecs = HashMultimap.create();
pipelineTasks
.forEach(pipelineTask -> {
taskSpecs.put(getQueueForTask(pipelineTask), pipelineTask.toTaskSpec(appEngineServicesService, TaskHandler.handleTaskUrl()));
taskSpecs.put(getQueueForTask(pipelineTask), pipelineTask.toTaskSpec(appEngineServicesService, TaskHandler.handleTaskUrl(), kmsService));
});
return taskSpecs;
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ public enum InflationType {
private static final String CHILD_GRAPH_GUID_PROPERTY = "childGraphGuid";
private static final String STATUS_CONSOLE_URL = "statusConsoleUrl";
public static final String ROOT_JOB_DISPLAY_NAME = "rootJobDisplayName";

public static final String IS_ROOT_JOB_PROPERTY = "isRootJob";
private static final String ENCRYPTION_KEY_PROPERTY = "encryptionKey";

/**
* projectId for job; must be set
Expand Down Expand Up @@ -287,6 +287,7 @@ public JobRecord(Entity entity) {
queueSettings.setOnService(EntityUtils.getString(entity, ON_SERVICE_PROPERTY));
queueSettings.setOnServiceVersion(EntityUtils.getString(entity, ON_SERVICE_VERSION_PROPERTY));
queueSettings.setOnQueue(EntityUtils.getString(entity, ON_QUEUE_PROPERTY));
queueSettings.setEncryptionKey(EntityUtils.getString(entity, ENCRYPTION_KEY_PROPERTY));

statusConsoleUrl = EntityUtils.getString(entity, STATUS_CONSOLE_URL);
rootJobDisplayName = EntityUtils.getString(entity, ROOT_JOB_DISPLAY_NAME);
Expand Down Expand Up @@ -356,6 +357,9 @@ public Entity toEntity() {
if (queueSettings.getOnQueue() != null) {
builder.set(ON_QUEUE_PROPERTY, StringValue.newBuilder(queueSettings.getOnQueue()).setExcludeFromIndexes(true).build());
}
if (queueSettings.getEncryptionKey() != null) {
builder.set(ENCRYPTION_KEY_PROPERTY, StringValue.newBuilder(queueSettings.getEncryptionKey()).setExcludeFromIndexes(true).build());
}

if (statusConsoleUrl != null) {
builder.set(STATUS_CONSOLE_URL, StringValue.newBuilder(statusConsoleUrl).setExcludeFromIndexes(true).build());
Expand Down Expand Up @@ -518,6 +522,8 @@ private void applySetting(JobSetting setting) {
statusConsoleUrl = ((StatusConsoleUrl) setting).getValue();
} else if (setting instanceof JobSetting.DatastoreNamespace) {
//ignore; applied in constructor, bc it's final
} else if (setting instanceof JobSetting.EncryptionKey) {
queueSettings.setEncryptionKey(((JobSetting.EncryptionKey) setting).getValue());
} else {
throw new RuntimeException("Unrecognized JobSetting class " + setting.getClass().getName());
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
import java.util.logging.Logger;
import java.util.stream.Stream;

import org.json.JSONObject;
import com.google.appengine.tools.pipeline.impl.util.KmsService;
import java.nio.charset.StandardCharsets;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.java.Log;
Expand All @@ -48,6 +52,7 @@
public class TaskHandler {

final JobRunServiceComponent component;
final KmsService kmsService;

public static final String PATH_COMPONENT = "handleTask";

Expand Down Expand Up @@ -110,8 +115,19 @@ Integer parseTaskRetryCount(HttpServletRequest req) {

private PipelineTask reconstructTask(HttpServletRequest request) {
Properties properties = new Properties();
Streams.stream(request.getParameterNames().asIterator())
.forEach(name -> properties.setProperty(name, request.getParameter(name)));
String encryptionKey = request.getHeader(PipelineTask.ENCRYPTION_KEY_HEADER);
if (encryptionKey != null && request.getParameter("_encrypted_payload") != null) {
String base64Encrypted = request.getParameter("_encrypted_payload");
byte[] encrypted = Base64.getDecoder().decode(base64Encrypted);
byte[] decrypted = kmsService.decrypt(encryptionKey, encrypted);
JSONObject jsonParams = new JSONObject(new String(decrypted, StandardCharsets.UTF_8));
for (String key : jsonParams.keySet()) {
properties.setProperty(key, jsonParams.getString(key));
}
} else {
Streams.stream(request.getParameterNames().asIterator())
.forEach(name -> properties.setProperty(name, request.getParameter(name)));
}

String taskName = parseTaskName(request);
PipelineTask pipelineTask = PipelineTask.fromProperties(taskName, properties);
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Base64;
import java.util.EnumSet;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;

import org.json.JSONObject;
import com.google.appengine.tools.pipeline.impl.util.KmsService;

/**
* A Pipeline Framework task to be executed asynchronously This is the abstract base class for all
* Pipeline task types.
Expand Down Expand Up @@ -55,6 +60,7 @@
public abstract class PipelineTask {

protected static final String TASK_TYPE_PARAMETER = "taskType";
public static final String ENCRYPTION_KEY_HEADER = "X-Pipeline-EncryptionKey";

@Getter @NonNull
private final Type type;
Expand Down Expand Up @@ -200,14 +206,25 @@ public final Properties toProperties() {
}


public PipelineTaskQueue.TaskSpec toTaskSpec(AppEngineServicesService appEngineServicesService, String callback) {
public PipelineTaskQueue.TaskSpec toTaskSpec(AppEngineServicesService appEngineServicesService, String callback, KmsService kmsService) {
PipelineTaskQueue.TaskSpec.TaskSpecBuilder spec = PipelineTaskQueue.TaskSpec.builder()
.name(this.getTaskName())
.callbackPath(callback)
.method(PipelineTaskQueue.TaskSpec.Method.POST);

this.toProperties().entrySet()
.forEach(p -> spec.param((String) p.getKey(), (String) p.getValue()));
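if (this.getQueueSettings().getEncryptionKey() != null) {
// The legacy AppEngineTaskQueue constructor passes a null KmsService; fail with a clear message instead of an NPE.
if (kmsService == null) {
throw new IllegalStateException("EncryptionKey is set, but no KmsService is available to encrypt the task payload");
}
spec.header(ENCRYPTION_KEY_HEADER, this.getQueueSettings().getEncryptionKey());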
JSONObject jsonParams = new JSONObject();
this.toProperties().forEach((k, v) -> jsonParams.put((String) k, (String) v));
byte[] encrypted = kmsService.encrypt(
this.getQueueSettings().getEncryptionKey(),
jsonParams.toString().getBytes(StandardCharsets.UTF_8));
String base64Encrypted = Base64.getEncoder().encodeToString(encrypted);
spec.param("_encrypted_payload", base64Encrypted);
} else {
this.toProperties().entrySet()
.forEach(p -> spec.param((String) p.getKey(), (String) p.getValue()));
}

if (this.getQueueSettings().getDelayInSeconds() != null) {
spec.scheduledExecutionTime(Instant.now().plusSeconds(this.getQueueSettings().getDelayInSeconds()));
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.google.appengine.tools.pipeline.impl.util;

import com.google.cloud.kms.v1.KeyManagementServiceClient;
import com.google.cloud.kms.v1.EncryptResponse;
import com.google.cloud.kms.v1.DecryptResponse;
import com.google.protobuf.ByteString;

import javax.inject.Inject;
import javax.inject.Singleton;
import java.io.IOException;
import java.io.UncheckedIOException;

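/**
* Injectable wrapper around the Cloud KMS client, used to encrypt and decrypt task payloads.
* Key names are full CryptoKey resource names.
* Example: "projects/my-project/locations/global/keyRings/my-keyring/cryptoKeys/my-key"
*/
@Singleton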
public class KmsService {
private final KeyManagementServiceClient client;

@Inject
public KmsService() {
try {
this.client = KeyManagementServiceClient.create();
} catch (IOException e) {
throw new UncheckedIOException("Failed to create KeyManagementServiceClient", e);
}
}

public byte[] encrypt(String keyName, byte[] plaintext) {
EncryptResponse response = client.encrypt(keyName, ByteString.copyFrom(plaintext));
return response.getCiphertext().toByteArray();
}

public byte[] decrypt(String keyName, byte[] ciphertext) {
DecryptResponse response = client.decrypt(keyName, ByteString.copyFrom(ciphertext));
return response.getPlaintext().toByteArray();
}
}