diff --git a/hack/ark/test-e2e.sh b/hack/ark/test-e2e.sh index e24487d3..ae2152c6 100755 --- a/hack/ark/test-e2e.sh +++ b/hack/ark/test-e2e.sh @@ -82,6 +82,7 @@ kubectl apply -f "${root_dir}/hack/ark/conjur-connect-configmap.yaml" # We use a non-existent tag and omit the `--version` flag, to work around a Helm # v4 bug. See: https://github.com/helm/helm/issues/31600 +# TODO: shouldn't need to set config.sendSecretValues because it will default to true in future helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \ --install \ --wait \ @@ -94,6 +95,7 @@ helm upgrade agent "oci://${ARK_CHART}:NON_EXISTENT_TAG@${ARK_CHART_DIGEST}" \ --set config.clusterName="e2e-test-cluster" \ --set config.clusterDescription="A temporary cluster for E2E testing. Contact @wallrj-cyberark." \ --set config.period=60s \ + --set config.sendSecretValues=true \ --set-json "podLabels={\"disco-agent.cyberark.cloud/test-id\": \"${RANDOM}\"}" kubectl rollout status deployments/disco-agent --namespace "${NAMESPACE}" diff --git a/internal/cyberark/client_test.go b/internal/cyberark/client_test.go index 1c220d2d..e3a86b22 100644 --- a/internal/cyberark/client_test.go +++ b/internal/cyberark/client_test.go @@ -32,9 +32,9 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { Secret: "somepassword", } - discoveryClient := servicediscovery.New(httpClient) + discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain) - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context()) if err != nil { t.Fatalf("failed to discover mock services: %v", err) } @@ -44,7 +44,7 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { err = cl.PutSnapshot(ctx, dataupload.Snapshot{ ClusterID: "ffffffff-ffff-ffff-ffff-ffffffffffff", - AgentVersion: version.PreflightVersion, + AgentVersion: version.CYBRVersion, }) require.NoError(t, err) @@ -76,9 +76,9 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) { cfg, err := cyberark.LoadClientConfigFromEnvironment() require.NoError(t, err) - discoveryClient := servicediscovery.New(httpClient) + discoveryClient := servicediscovery.New(httpClient, cfg.Subdomain) - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context()) if err != nil { t.Fatalf("failed to discover services: %v", err) } @@ -88,7 +88,7 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) { err = cl.PutSnapshot(ctx, dataupload.Snapshot{ ClusterID: "ffffffff-ffff-ffff-ffff-ffffffffffff", - AgentVersion: version.PreflightVersion, + AgentVersion: version.CYBRVersion, }) require.NoError(t, err) diff --git a/internal/cyberark/dataupload/dataupload.go b/internal/cyberark/dataupload/dataupload.go index 18fba38e..b4db0fdf 100644 --- a/internal/cyberark/dataupload/dataupload.go +++ b/internal/cyberark/dataupload/dataupload.go @@ -149,7 +149,7 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err req.Header.Set("X-Amz-Tagging", q.Encode()) - version.SetUserAgent(req) + version.SetUserAgentCYBR(req) res, err := c.httpClient.Do(req) if err != nil { @@ -190,7 +190,7 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu request := RetrievePresignedUploadURLRequest{ ClusterID: clusterID, Checksum: checksum, - AgentVersion: version.PreflightVersion, + AgentVersion: version.CYBRVersion, FileSize: fileSize, } @@ -211,7 +211,7 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu return "", "", fmt.Errorf("failed to authenticate request: %s", err) } - version.SetUserAgent(req) + version.SetUserAgentCYBR(req) // Add telemetry headers arkapi.SetTelemetryRequestHeader(req) diff --git a/internal/cyberark/dataupload/dataupload_test.go b/internal/cyberark/dataupload/dataupload_test.go index d78c4bf3..e6631b9e 100644 --- a/internal/cyberark/dataupload/dataupload_test.go +++ b/internal/cyberark/dataupload/dataupload_test.go @@ -37,7 +37,7 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { name: "successful upload", snapshot: dataupload.Snapshot{ ClusterID: "ffffffff-ffff-ffff-ffff-ffffffffffff", - AgentVersion: version.PreflightVersion, + AgentVersion: version.CYBRVersion, }, authenticate: setToken("success-token"), requireFn: func(t *testing.T, err error) { diff --git a/internal/cyberark/dataupload/mock.go b/internal/cyberark/dataupload/mock.go index b2e416e3..e6a183bd 100644 --- a/internal/cyberark/dataupload/mock.go +++ b/internal/cyberark/dataupload/mock.go @@ -101,7 +101,7 @@ func randHex() string { } func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("User-Agent") != version.UserAgent() { + if r.Header.Get("User-Agent") != version.UserAgentCYBR() { http.Error(w, "should set user agent on all requests", http.StatusInternalServerError) return } @@ -132,7 +132,7 @@ func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *h return } - if req.AgentVersion != version.PreflightVersion { + if req.AgentVersion != version.CYBRVersion { http.Error(w, fmt.Sprintf("post body contains unexpected agent version: %s", req.AgentVersion), http.StatusInternalServerError) return } @@ -214,7 +214,7 @@ func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r return } - if r.Header.Get("User-Agent") != version.UserAgent() { + if r.Header.Get("User-Agent") != version.UserAgentCYBR() { http.Error(w, "should set user agent on all requests", http.StatusInternalServerError) return } @@ -249,8 +249,8 @@ func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r return } - if tags.Get("agent_version") != version.PreflightVersion { - http.Error(w, fmt.Sprintf("x-amz-tagging should contain an agent_version tag with value %s", version.PreflightVersion), http.StatusInternalServerError) + if tags.Get("agent_version") != version.CYBRVersion { + http.Error(w, fmt.Sprintf("x-amz-tagging should contain an agent_version tag with value %s but got %s", version.CYBRVersion, tags.Get("agent_version")), http.StatusInternalServerError) return } @@ -308,7 +308,7 @@ func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r err = d.Decode(&snapshot) require.NoError(mds.t, err) assert.Equal(mds.t, successClusterID, snapshot.ClusterID) - assert.Equal(mds.t, version.PreflightVersion, snapshot.AgentVersion) + assert.Equal(mds.t, version.CYBRVersion, snapshot.AgentVersion) // AWS S3 responds with an empty body if the PUT succeeds w.WriteHeader(http.StatusOK) diff --git a/internal/cyberark/identity/cmd/testidentity/main.go b/internal/cyberark/identity/cmd/testidentity/main.go index 916c81ea..0a8df80b 100644 --- a/internal/cyberark/identity/cmd/testidentity/main.go +++ b/internal/cyberark/identity/cmd/testidentity/main.go @@ -50,8 +50,8 @@ func run(ctx context.Context) error { var rootCAs *x509.CertPool httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) - sdClient := servicediscovery.New(httpClient) - services, _, err := sdClient.DiscoverServices(ctx, subdomain) + sdClient := servicediscovery.New(httpClient, subdomain) + services, _, err := sdClient.DiscoverServices(ctx) if err != nil { return fmt.Errorf("while performing service discovery: %s", err) } diff --git a/internal/cyberark/identity/identity_test.go b/internal/cyberark/identity/identity_test.go index 917ba15d..0915f46c 100644 --- a/internal/cyberark/identity/identity_test.go +++ b/internal/cyberark/identity/identity_test.go @@ -53,7 +53,7 @@ func TestLoginUsernamePassword_RealAPI(t *testing.T) { arktesting.SkipIfNoEnv(t) subdomain := os.Getenv("ARK_SUBDOMAIN") httpClient := http.DefaultClient - services, _, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain) + services, _, err := servicediscovery.New(httpClient, subdomain).DiscoverServices(t.Context()) require.NoError(t, err) loginUsernamePasswordTests(t, func(t testing.TB) inputs { diff --git a/internal/cyberark/servicediscovery/discovery.go b/internal/cyberark/servicediscovery/discovery.go index 82394ab3..93598d5c 100644 --- a/internal/cyberark/servicediscovery/discovery.go +++ b/internal/cyberark/servicediscovery/discovery.go @@ -9,6 +9,8 @@ import ( "net/url" "os" "path" + "sync" + "time" arkapi "github.com/jetstack/preflight/internal/cyberark/api" "github.com/jetstack/preflight/pkg/version" @@ -35,21 +37,34 @@ const ( // users to fetch URLs for various APIs available in CyberArk. This client is specialised to // fetch only API endpoints, since only API endpoints are required by the Venafi Kubernetes Agent currently. type Client struct { - client *http.Client - baseURL string + client *http.Client + baseURL string + subdomain string + + cachedResponse *Services + cachedTenantID string + cachedResponseTime time.Time + cachedResponseMutex sync.Mutex } // New creates a new CyberArk Service Discovery client. If the ARK_DISCOVERY_API // environment variable is set, it is used as the base URL for the service // discovery API. Otherwise, the production URL is used. -func New(httpClient *http.Client) *Client { +func New(httpClient *http.Client, subdomain string) *Client { baseURL := os.Getenv("ARK_DISCOVERY_API") if baseURL == "" { baseURL = ProdDiscoveryAPIBaseURL } + client := &Client{ - client: httpClient, - baseURL: baseURL, + client: httpClient, + baseURL: baseURL, + subdomain: subdomain, + + cachedResponse: nil, + cachedTenantID: "", + cachedResponseTime: time.Time{}, + cachedResponseMutex: sync.Mutex{}, } return client @@ -93,17 +108,24 @@ type Services struct { DiscoveryContext ServiceEndpoint } -// DiscoverServices fetches from the service discovery service for a given subdomain +// DiscoverServices fetches from the service discovery service for the configured subdomain // and parses the CyberArk Identity API URL and Inventory API URL. // It also returns the Tenant ID UUID corresponding to the subdomain. -func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, string, error) { +func (c *Client) DiscoverServices(ctx context.Context) (*Services, string, error) { + c.cachedResponseMutex.Lock() + defer c.cachedResponseMutex.Unlock() + + if c.cachedResponse != nil && time.Since(c.cachedResponseTime) < 1*time.Hour { + return c.cachedResponse, c.cachedTenantID, nil + } + u, err := url.Parse(c.baseURL) if err != nil { return nil, "", fmt.Errorf("invalid base URL for service discovery: %w", err) } u.Path = path.Join(u.Path, "api/public/tenant-discovery") - u.RawQuery = url.Values{"bySubdomain": []string{subdomain}}.Encode() + u.RawQuery = url.Values{"bySubdomain": []string{c.subdomain}}.Encode() endpoint := u.String() @@ -127,7 +149,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi // a 404 error is returned with an empty JSON body "{}" if the subdomain is unknown; at the time of writing, we haven't observed // any other errors and so we can't special case them if resp.StatusCode == http.StatusNotFound { - return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain) + return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", c.subdomain) } return nil, "", fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status) @@ -167,8 +189,14 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi } //TODO: Should add a check for discoveryContextAPI too? - return &Services{ + services := &Services{ Identity: ServiceEndpoint{API: identityAPI}, DiscoveryContext: ServiceEndpoint{API: discoveryContextAPI}, - }, discoveryResp.TenantID, nil + } + + c.cachedResponse = services + c.cachedTenantID = discoveryResp.TenantID + c.cachedResponseTime = time.Now() + + return services, discoveryResp.TenantID, nil } diff --git a/internal/cyberark/servicediscovery/discovery_test.go b/internal/cyberark/servicediscovery/discovery_test.go index 00d0fd58..618e63f9 100644 --- a/internal/cyberark/servicediscovery/discovery_test.go +++ b/internal/cyberark/servicediscovery/discovery_test.go @@ -64,9 +64,9 @@ func Test_DiscoverIdentityAPIURL(t *testing.T) { }, }) - client := New(httpClient) + client := New(httpClient, testSpec.subdomain) - services, _, err := client.DiscoverServices(ctx, testSpec.subdomain) + services, _, err := client.DiscoverServices(ctx) if testSpec.expectedError != nil { assert.EqualError(t, err, testSpec.expectedError.Error()) assert.Nil(t, services) diff --git a/internal/envelope/keyfetch/client.go b/internal/envelope/keyfetch/client.go new file mode 100644 index 00000000..132ee743 --- /dev/null +++ b/internal/envelope/keyfetch/client.go @@ -0,0 +1,156 @@ +package keyfetch + +import ( + "context" + "crypto/rsa" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/lestrrat-go/jwx/v3/jwk" + + "github.com/jetstack/preflight/internal/cyberark/servicediscovery" +) + +const ( + // minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor + // to enforce to ensure that a weak key can't accidentally be used + minRSAKeySize = 2048 +) + +// KeyFetcher is an interface for fetching public keys. +type KeyFetcher interface { + // FetchKey retrieves a public key from the key source. + FetchKey(ctx context.Context) (PublicKey, error) +} + +// Compile-time check that Client implements KeyFetcher +var _ KeyFetcher = (*Client)(nil) + +// PublicKey represents an RSA public key retrieved from the key server. +type PublicKey struct { + // KeyID is the unique identifier for this key + KeyID string + + // Key is the actual RSA public key + Key *rsa.PublicKey +} + +// Client fetches public keys from a CyberArk HTTP endpoint that provides keys in JWKS format. +// It can be expanded in future to support other key types and formats, but for now it only supports RSA keys +// and ignored other types. +type Client struct { + discoveryClient *servicediscovery.Client + + // httpClient is the HTTP client used for requests + httpClient *http.Client +} + +// NewClient creates a new key fetching client. +// Uses CyberArk service discovery to derive the JWKS endpoint +func NewClient(discoveryClient *servicediscovery.Client) *Client { + return &Client{ + discoveryClient: discoveryClient, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// FetchKey retrieves the public keys from the configured endpoint. +// It returns a slice of PublicKey structs containing the key material and metadata. +func (c *Client) FetchKey(ctx context.Context) (PublicKey, error) { + services, _, err := c.discoveryClient.DiscoverServices(ctx) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to get services from discovery client: %w", err) + } + + endpoint, err := url.JoinPath(services.DiscoveryContext.API, "discovery-context/jwks") + if err != nil { + return PublicKey{}, fmt.Errorf("failed to construct endpoint URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to fetch keys from %s: %w", endpoint, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return PublicKey{}, fmt.Errorf("unexpected status code %d from %s: %s", resp.StatusCode, endpoint, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to read response body: %w", err) + } + + keySet, err := jwk.Parse(body) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to parse JWKs response: %w", err) + } + + for i := range keySet.Len() { + key, ok := keySet.Key(i) + if !ok { + continue + } + + // Only process RSA keys + if key.KeyType().String() != "RSA" { + continue + } + + var rawKey any + if err := jwk.Export(key, &rawKey); err != nil { + // skip unparseable keys + continue + } + + rsaKey, ok := rawKey.(*rsa.PublicKey) + if !ok { + // only process RSA keys (for now) + continue + } + + if rsaKey.N.BitLen() < minRSAKeySize { + // skip keys that are too small to be secure + continue + } + + kid, ok := key.KeyID() + if !ok { + // skip any keys which don't have an ID + continue + } + + alg, ok := key.Algorithm() + if !ok { + // skip any keys which don't have an algorithm specified + continue + } + + if alg.String() != "RSA-OAEP-256" { + // we only use RSA keys for RSA-OAEP-256 + continue + } + + // return the first valid key we find + return PublicKey{ + KeyID: kid, + Key: rsaKey, + }, nil + } + + return PublicKey{}, fmt.Errorf("no valid RSA keys found at %s", endpoint) +} diff --git a/internal/envelope/keyfetch/client_test.go b/internal/envelope/keyfetch/client_test.go new file mode 100644 index 00000000..4e7584d7 --- /dev/null +++ b/internal/envelope/keyfetch/client_test.go @@ -0,0 +1,247 @@ +package keyfetch + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jetstack/preflight/internal/cyberark/servicediscovery" +) + +// mockDiscoveryClient creates a discovery client that returns the given URL as the API endpoint +func mockDiscoveryClient(t *testing.T, apiURL string) *servicediscovery.Client { + t.Helper() + + // Create a mock discovery server that returns the test server URL + discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response := servicediscovery.DiscoveryResponse{ + TenantID: "test-tenant", + Services: []servicediscovery.Service{ + { + ServiceName: servicediscovery.DiscoveryContextServiceName, + Endpoints: []servicediscovery.ServiceEndpoint{ + { + IsActive: true, + Type: "main", + API: apiURL, + }, + }, + }, + { + ServiceName: servicediscovery.IdentityServiceName, + Endpoints: []servicediscovery.ServiceEndpoint{ + { + IsActive: true, + Type: "main", + API: "https://identity.example.com", + }, + }, + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) + })) + t.Cleanup(discoveryServer.Close) + + // Override the discovery API URL with our mock server + t.Setenv("ARK_DISCOVERY_API", discoveryServer.URL) + + return servicediscovery.New(&http.Client{}, "test-subdomain") +} + +func TestClient_FetchKey(t *testing.T) { + // Sample JWKs response with a valid RSA key + // This is a minimal example with the required fields + jwksResponse := `{ + "keys": [ + { + "kty": "RSA", + "use": "enc", + "kid": "test-key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + t.Run("successful fetch", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Accept")) + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(jwksResponse)) + require.NoError(t, err) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + key, err := client.FetchKey(context.Background()) + + require.NoError(t, err) + + assert.Equal(t, "test-key-1", key.KeyID) + assert.NotNil(t, key.Key) + assert.NotNil(t, key.Key.N) + assert.Greater(t, key.Key.E, 0) + }) + + t.Run("multiple keys", func(t *testing.T) { + multiKeyResponse := `{ + "keys": [ + { + "kty": "RSA", + "kid": "key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + }, + { + "kty": "RSA", + "kid": "key-2", + "alg": "RSA-OAEP-256", + "n": "4J0VE8FK1rSQUBGiLpk4MkPyFApCyCugOfkuH0hiHclxZay96JgyZylH97eqs-ZmWXtv42ynYctIj2ZleaoqVDfMOqZ1GsbccyNAYReDtUYgeUtJEajpfUo1vitoh6OEB6nB0Hau07ELLqcUoxH_zkH5Kwoi_BgxByJDQ1HOut6nyEPTXLTMrAYK_pqL_kzsU0OtrCgSBh6j-11ToqUfxsLupbadRC0t5zrq4-3mZKqxBUz4XB2g3b9d2lH7mOTl5J_E8jcD4tK9DePzjdbkRWonBEJetWl9f2mh_VD1sxJbie1kzM5cdQylXzV_AvhSr58w00qy6XR_QXI10UU16Q", + "e": "AQAB" + } + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(multiKeyResponse)) + require.NoError(t, err) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + key, err := client.FetchKey(context.Background()) + + require.NoError(t, err) + + assert.Equal(t, "key-1", key.KeyID) + }) + + t.Run("filters non-RSA keys", func(t *testing.T) { + mixedKeyResponse := `{ + "keys": [ + { + "kty": "EC", + "kid": "ec-key-1", + "alg": "ES256", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE" + }, + { + "kty": "RSA", + "kid": "rsa-key-1", + "alg": "RSA-OAEP-256", + "n": "vDdioGpDuAEQDd4WRXyWa4sZ5EeS9OPsRrU_jU3PbZdDcANxfh_WSeSvSBKGfGXGC3fIzu0Ernk9VjXcs3LeFdRq2N4nNRZvCzsd_MjBtn7CWgjM_Sk9DXEGn3cHHilcJUJQ4i2YgX9bHu0odNgE6cSVIUEMIC2EGuGk_I7lwroinAAwXpNLLQkV_25kv_QQof2i5f7AocY6QTd0SAo8ZUqFBzanupkeFpl3-Bsz6_zdt_N0x9k5XHQn42Q2oTupTwvXFbE1x8XtCpiaP3_fsQ9dN7t4z6HtwlNUJB2tFfF6PgdKZ9LuJpYjFPYzJQ6Rv28fuc8YHcF7Jittjyzmew", + "e": "AQAB" + } + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(mixedKeyResponse)) + require.NoError(t, err) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + key, err := client.FetchKey(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "rsa-key-1", key.KeyID) + }) + + t.Run("error on non-200 status", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + _, err := client.FetchKey(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected status code 500") + }) + + t.Run("error on invalid JSON", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("invalid json")) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + _, err := client.FetchKey(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse JWKs response") + }) + + t.Run("error on no RSA keys", func(t *testing.T) { + emptyResponse := `{ + "keys": [ + { + "kty": "EC", + "kid": "ec-key-1", + "alg": "ES256", + "crv": "P-256", + "x": "WKn-ZIGevcwGIyyrzFoZNBdaq9_TsqzGl96oc0CWuis", + "y": "y77t-RvAHRKTsSGdIYUfweuOvwrvDD-Q3Hv5J0fSKbE" + } + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(emptyResponse)) + require.NoError(t, err) + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + _, err := client.FetchKey(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "no valid RSA keys found") + }) + + t.Run("context cancellation", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This handler will never respond + <-r.Context().Done() + })) + defer server.Close() + + discoveryClient := mockDiscoveryClient(t, server.URL) + client := NewClient(discoveryClient) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := client.FetchKey(ctx) + + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + }) +} diff --git a/internal/envelope/keyfetch/doc.go b/internal/envelope/keyfetch/doc.go new file mode 100644 index 00000000..2da5b587 --- /dev/null +++ b/internal/envelope/keyfetch/doc.go @@ -0,0 +1,22 @@ +// Package keyfetch provides a client for fetching encryption keys from an HTTP endpoint. +// +// The client retrieves public keys in JSON Web Key Set (JWKs) format from a remote +// server and converts them into usable cryptographic keys for envelope encryption. +// +// Example usage: +// +// client := keyfetch.NewClient("https://keys.example.com/jwks") +// keys, err := client.FetchKeys(context.Background()) +// if err != nil { +// // handle error +// } +// +// // Use the keys for envelope encryption +// for _, key := range keys { +// fmt.Printf("Key ID: %s, Algorithm: %s\n", key.KeyID, key.Algorithm) +// } +// +// This package uses github.com/lestrrat-go/jwx/v3/jwk for JWK parsing and handling. +// +// Currently, keyfetch only supports RSA keys for envelope encryption. +package keyfetch diff --git a/internal/envelope/keyfetch/fake.go b/internal/envelope/keyfetch/fake.go new file mode 100644 index 00000000..d7226b2b --- /dev/null +++ b/internal/envelope/keyfetch/fake.go @@ -0,0 +1,85 @@ +package keyfetch + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "fmt" +) + +// Compile-time check that FakeClient implements KeyFetcher +var _ KeyFetcher = (*FakeClient)(nil) + +// FakeClient is a fake implementation of the key fetcher for testing. +// It can be configured to return specific keys or errors for testing different scenarios. +type FakeClient struct { + // Key is the public key that will be returned by FetchKey. + // If nil, a random key will be generated on the first call. + Key *PublicKey + + // Err is the error that will be returned by FetchKey. + // If both Key and Err are set, Err takes precedence. + Err error + + // FetchKeyCalls tracks how many times FetchKey was called + FetchKeyCalls int +} + +// NewFakeClient creates a new fake client for testing. +func NewFakeClient() *FakeClient { + return &FakeClient{} +} + +// NewFakeClientWithKey creates a new fake client that returns the specified key. +func NewFakeClientWithKey(keyID string, key *rsa.PublicKey) *FakeClient { + return &FakeClient{ + Key: &PublicKey{ + KeyID: keyID, + Key: key, + }, + } +} + +// NewFakeClientWithError creates a new fake client that returns the specified error. +func NewFakeClientWithError(err error) *FakeClient { + return &FakeClient{ + Err: err, + } +} + +// FetchKey implements the key fetching interface for testing. +// It returns the configured key or error, or generates a random key if none is configured. +func (f *FakeClient) FetchKey(ctx context.Context) (PublicKey, error) { + f.FetchKeyCalls++ + + // Check if context is canceled + if ctx.Err() != nil { + return PublicKey{}, ctx.Err() + } + + // If an error is configured, return it + if f.Err != nil { + return PublicKey{}, f.Err + } + + // If a key is configured, return it + if f.Key != nil { + return *f.Key, nil + } + + // Generate a random key for testing + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + if err != nil { + return PublicKey{}, fmt.Errorf("failed to generate test key: %w", err) + } + + generatedKey := PublicKey{ + KeyID: "test-key", + Key: &privateKey.PublicKey, + } + + // Cache the generated key for subsequent calls + f.Key = &generatedKey + + return generatedKey, nil +} diff --git a/internal/envelope/keyfetch/fake_test.go b/internal/envelope/keyfetch/fake_test.go new file mode 100644 index 00000000..9036dc31 --- /dev/null +++ b/internal/envelope/keyfetch/fake_test.go @@ -0,0 +1,89 @@ +package keyfetch + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFakeClient(t *testing.T) { + t.Run("returns generated key by default", func(t *testing.T) { + fake := NewFakeClient() + + key, err := fake.FetchKey(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "test-key", key.KeyID) + assert.NotNil(t, key.Key) + assert.Equal(t, 1, fake.FetchKeyCalls) + + // Subsequent calls return the same key + key2, err := fake.FetchKey(context.Background()) + require.NoError(t, err) + assert.Equal(t, key.KeyID, key2.KeyID) + assert.Equal(t, key.Key, key2.Key) + assert.Equal(t, 2, fake.FetchKeyCalls) + }) + + t.Run("returns configured key", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + require.NoError(t, err) + + fake := NewFakeClientWithKey("custom-key", &privateKey.PublicKey) + + key, err := fake.FetchKey(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "custom-key", key.KeyID) + assert.Equal(t, &privateKey.PublicKey, key.Key) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("returns configured error", func(t *testing.T) { + expectedErr := errors.New("test error") + fake := NewFakeClientWithError(expectedErr) + + _, err := fake.FetchKey(context.Background()) + require.Error(t, err) + + assert.Equal(t, expectedErr, err) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("respects context cancellation", func(t *testing.T) { + fake := NewFakeClient() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := fake.FetchKey(ctx) + require.Error(t, err) + + assert.Equal(t, context.Canceled, err) + assert.Equal(t, 1, fake.FetchKeyCalls) + }) + + t.Run("error takes precedence over key", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, minRSAKeySize) + require.NoError(t, err) + + expectedErr := errors.New("test error") + fake := &FakeClient{ + Key: &PublicKey{ + KeyID: "custom-key", + Key: &privateKey.PublicKey, + }, + Err: expectedErr, + } + + _, err = fake.FetchKey(context.Background()) + require.Error(t, err) + + assert.Equal(t, expectedErr, err) + }) +} diff --git a/internal/envelope/rsa/encryptor.go b/internal/envelope/rsa/encryptor.go index 8cc0e17a..80d87de3 100644 --- a/internal/envelope/rsa/encryptor.go +++ b/internal/envelope/rsa/encryptor.go @@ -1,20 +1,17 @@ package rsa import ( - "crypto/rsa" + "context" "fmt" "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" ) const ( - // minRSAKeySize is the minimum RSA key size in bits; we'd expect that keys will be larger but 2048 is a sane floor - // to enforce to ensure that a weak key can't accidentally be used - minRSAKeySize = 2048 - // EncryptionType is the type identifier for RSA JWE encryption EncryptionType = "JWE-RSA" ) @@ -25,45 +22,33 @@ var _ envelope.Encryptor = (*Encryptor)(nil) // Encryptor provides envelope encryption using RSA-OAEP-256 for key wrapping // and AES-256-GCM for data encryption, outputting JWE Compact Serialization format. type Encryptor struct { - keyID string - publicKey *rsa.PublicKey + fetcher keyfetch.KeyFetcher } -// NewEncryptor creates a new Encryptor with the provided RSA public key. -// The RSA key must be at least minRSAKeySize bits. +// NewEncryptor creates a new Encryptor with the provided key fetcher. // The encryptor will use RSA-OAEP-256 for key encryption and A256GCM for content encryption. -func NewEncryptor(keyID string, publicKey *rsa.PublicKey) (*Encryptor, error) { - if publicKey == nil { - return nil, fmt.Errorf("RSA public key cannot be nil") - } - - // Validate key size - keySize := publicKey.N.BitLen() - if keySize < minRSAKeySize { - return nil, fmt.Errorf("RSA key size must be at least %d bits, got %d bits", minRSAKeySize, keySize) - } - - if len(keyID) == 0 { - return nil, fmt.Errorf("keyID cannot be empty") - } - +func NewEncryptor(fetcher keyfetch.KeyFetcher) (*Encryptor, error) { return &Encryptor{ - keyID: keyID, - publicKey: publicKey, + fetcher: fetcher, }, nil } // Encrypt performs envelope encryption on the provided data. // It returns an EncryptedData struct containing JWE Compact Serialization format and type metadata. // The JWE uses RSA-OAEP-256 for key encryption and A256GCM for content encryption. -func (e *Encryptor) Encrypt(data []byte) (*envelope.EncryptedData, error) { +func (e *Encryptor) Encrypt(ctx context.Context, data []byte) (*envelope.EncryptedData, error) { if len(data) == 0 { return nil, fmt.Errorf("data to encrypt cannot be empty") } + key, err := e.fetcher.FetchKey(ctx) + if err != nil { + return nil, fmt.Errorf("failed to fetch encryption key: %w", err) + } + // Create headers with the key ID headers := jwe.NewHeaders() - if err := headers.Set("kid", e.keyID); err != nil { + if err := headers.Set("kid", key.KeyID); err != nil { return nil, fmt.Errorf("failed to set key ID header: %w", err) } @@ -71,7 +56,7 @@ func (e *Encryptor) Encrypt(data []byte) (*envelope.EncryptedData, error) { // TODO: in go1.26+, consider using secret.Do to wrap this call, since it will generate an AES key encrypted, err := jwe.Encrypt( data, - jwe.WithKey(jwa.RSA_OAEP_256(), e.publicKey, jwe.WithPerRecipientHeaders(headers)), + jwe.WithKey(jwa.RSA_OAEP_256(), key.Key, jwe.WithPerRecipientHeaders(headers)), jwe.WithContentEncryption(jwa.A256GCM()), jwe.WithCompact(), ) diff --git a/internal/envelope/rsa/encryptor_test.go b/internal/envelope/rsa/encryptor_test.go index 6763c1e5..50534fed 100644 --- a/internal/envelope/rsa/encryptor_test.go +++ b/internal/envelope/rsa/encryptor_test.go @@ -3,9 +3,7 @@ package rsa import ( "crypto/rand" "crypto/rsa" - "crypto/x509" "encoding/base64" - "encoding/pem" "strings" "sync" "testing" @@ -13,21 +11,15 @@ import ( "github.com/lestrrat-go/jwx/v3/jwa" "github.com/lestrrat-go/jwx/v3/jwe" "github.com/stretchr/testify/require" -) -const testKeyID = "test-key-id" + "github.com/jetstack/preflight/internal/envelope/keyfetch" +) -// smallRSAKey1024 is a hardcoded 1024-bit RSA public key in PEM format (PKIX) -// used for testing key size validation. This key is intentionally weak and should -// only be used for testing purposes. -// This is hardcoded rather than generated in order to save compute, and also on the -// assumption that future Go releases might restrict the ability to generate such small keys. -const smallRSAKey1024 = `-----BEGIN PUBLIC KEY----- -MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDCNDoCM0OBt4HFxFxyU50FYsuZ -gK+lgel/Jlzb+ghkWpCL1Vk3Au7aet4KxNxQh5dFRxtMU7pe6fC5eZtdL3+0TCUu -XAUVgMhTRn3ZXlEmJXosuiFQ2y4+3nbWL51OxXRf3jsieSVqr4fbceakuOKXp4vX -wgiguV3/XqaysHs1uwIDAQAB ------END PUBLIC KEY-----` +const ( + testKeyID = "test-key-id" + // minRSAKeySize is the minimum RSA key size used for test key generation + minRSAKeySize = 2048 +) var ( testKeyOnce sync.Once @@ -49,67 +41,10 @@ func testKey() *rsa.PrivateKey { return internalTestKey } -func TestNewEncryptor_ValidKeys(t *testing.T) { - tests := []struct { - name string - keySize int - }{ - {"2048 bits", 2048}, - {"3072 bits", 3072}, - {"4096 bits", 4096}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, tt.keySize) - require.NoError(t, err) - - enc, err := NewEncryptor(testKeyID, &key.PublicKey) - require.NoError(t, err) - require.NotNil(t, enc) - }) - } -} - -func TestNewEncryptor_RejectsSmallKeys(t *testing.T) { - // Parse the hardcoded 1024-bit RSA public key from PEM format - block, _ := pem.Decode([]byte(smallRSAKey1024)) - require.NotNil(t, block, "failed to decode PEM block") - - // NB: a future Go update might restrict the ability to parse small keys; - // if that happens, this test will need to be removed or changed. - pubKey, err := x509.ParsePKIXPublicKey(block.Bytes) - require.NoError(t, err, "failed to parse RSA public key") - - rsaPubKey, ok := pubKey.(*rsa.PublicKey) - require.True(t, ok, "key should be an RSA public key") - - enc, err := NewEncryptor(testKeyID, rsaPubKey) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "must be at least 2048 bits") -} - -func TestNewEncryptor_NilKey(t *testing.T) { - enc, err := NewEncryptor(testKeyID, nil) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "cannot be nil") -} - -func TestNewEncryptor_EmptyKeyID(t *testing.T) { - key := testKey() - - enc, err := NewEncryptor("", &key.PublicKey) - require.Error(t, err) - require.Nil(t, enc) - require.Contains(t, err.Error(), "keyID cannot be empty") -} - func TestEncrypt_VariousDataSizes(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) tests := []struct { @@ -127,7 +62,7 @@ func TestEncrypt_VariousDataSizes(t *testing.T) { _, err := rand.Read(data) require.NoError(t, err) - result, err := enc.Encrypt(data) + result, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.NotNil(t, result) require.Equal(t, EncryptionType, result.Type, "Type should be JWE-RSA") @@ -152,31 +87,31 @@ func TestEncrypt_VariousDataSizes(t *testing.T) { } func TestEncrypt_EmptyData(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) - result, err := enc.Encrypt([]byte{}) + result, err := enc.Encrypt(t.Context(), []byte{}) require.Error(t, err) require.Nil(t, result) require.Contains(t, err.Error(), "cannot be empty") } func TestEncrypt_NonDeterministic(t *testing.T) { - key := testKey() + fetcher := keyfetch.NewFakeClient() - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) data := []byte("test data for encryption") // Encrypt the same data twice - result1, err := enc.Encrypt(data) + result1, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result1.Type, "Type should be JWE-RSA") - result2, err := enc.Encrypt(data) + result2, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result2.Type, "Type should be JWE-RSA") @@ -186,12 +121,13 @@ func TestEncrypt_NonDeterministic(t *testing.T) { func TestEncrypt_JWEFormat(t *testing.T) { key := testKey() + fetcher := keyfetch.NewFakeClientWithKey(testKeyID, &key.PublicKey) - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) data := []byte("test data") - result, err := enc.Encrypt(data) + result, err := enc.Encrypt(t.Context(), data) require.NoError(t, err) require.Equal(t, EncryptionType, result.Type, "Type should be JWE-RSA") @@ -203,14 +139,15 @@ func TestEncrypt_JWEFormat(t *testing.T) { func TestEncrypt_DecryptRoundtrip(t *testing.T) { key := testKey() + fetcher := keyfetch.NewFakeClientWithKey(testKeyID, &key.PublicKey) - enc, err := NewEncryptor(testKeyID, &key.PublicKey) + enc, err := NewEncryptor(fetcher) require.NoError(t, err) originalData := []byte("test data for roundtrip encryption and decryption") // Encrypt the data - encrypted, err := enc.Encrypt(originalData) + encrypted, err := enc.Encrypt(t.Context(), originalData) require.NoError(t, err) require.Equal(t, EncryptionType, encrypted.Type, "Type should be JWE-RSA") diff --git a/internal/envelope/rsa/keys_test.go b/internal/envelope/rsa/keys_test.go index 83f86e19..1a138a35 100644 --- a/internal/envelope/rsa/keys_test.go +++ b/internal/envelope/rsa/keys_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/jetstack/preflight/internal/envelope/keyfetch" internalrsa "github.com/jetstack/preflight/internal/envelope/rsa" ) @@ -151,13 +152,14 @@ func TestLoadHardcodedPublicKey_CanBeUsedWithEncryptor(t *testing.T) { require.NotNil(t, key) require.NotEmpty(t, uid) - encryptor, err := internalrsa.NewEncryptor(uid, key) + fetcher := keyfetch.NewFakeClientWithKey(uid, key) + encryptor, err := internalrsa.NewEncryptor(fetcher) require.NoError(t, err) require.NotNil(t, encryptor) // Test that the encryptor can encrypt data testData := []byte("test data for encryption") - encryptedData, err := encryptor.Encrypt(testData) + encryptedData, err := encryptor.Encrypt(t.Context(), testData) require.NoError(t, err) require.NotNil(t, encryptedData) require.NotEmpty(t, encryptedData.Data) diff --git a/internal/envelope/types.go b/internal/envelope/types.go index b458f35d..6618ce6c 100644 --- a/internal/envelope/types.go +++ b/internal/envelope/types.go @@ -1,6 +1,9 @@ package envelope -import "encoding/json" +import ( + "context" + "encoding/json" +) // EncryptedData represents encrypted data along with metadata about the encryption type. type EncryptedData struct { @@ -34,5 +37,5 @@ func (ed *EncryptedData) ToMap() map[string]any { type Encryptor interface { // Encrypt encrypts data using envelope encryption, returning an EncryptedData struct // containing the encrypted payload and encryption type metadata. - Encrypt(data []byte) (*EncryptedData, error) + Encrypt(ctx context.Context, data []byte) (*EncryptedData, error) } diff --git a/pkg/agent/run.go b/pkg/agent/run.go index cba3c0a0..78fa64b3 100644 --- a/pkg/agent/run.go +++ b/pkg/agent/run.go @@ -32,6 +32,7 @@ import ( "github.com/jetstack/preflight/api" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" "github.com/jetstack/preflight/internal/envelope/rsa" "github.com/jetstack/preflight/pkg/client" "github.com/jetstack/preflight/pkg/datagatherer" @@ -164,6 +165,10 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { return fmt.Errorf("failed to create event recorder: %v", err) } + // Check if secret encryption is enabled via environment variable + // When enabled, secret data will be kept for encryption instead of being redacted + encryptSecrets := strings.ToLower(os.Getenv("ARK_SEND_SECRET_VALUES")) == "true" + dataGatherers := map[string]datagatherer.DataGatherer{} // load datagatherer config and boot each one @@ -184,14 +189,10 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { dynDg.ExcludeAnnotKeys = config.ExcludeAnnotationKeysRegex dynDg.ExcludeLabelKeys = config.ExcludeLabelKeysRegex - // Check if secret encryption is enabled via environment variable - // When enabled, secret data will be kept for encryption instead of being redacted - encryptSecrets := strings.ToLower(os.Getenv("ARK_SEND_SECRET_VALUES")) - - if encryptSecrets == "true" { + if encryptSecrets { var err error - dynDg.Encryptor, err = loadEncryptor() + dynDg.Encryptor, err = loadEncryptor(preflightClient) if err != nil { log.Error(err, "Failed to set up encryptor for secrets, secret data will not be sent") } @@ -273,14 +274,15 @@ func Run(cmd *cobra.Command, args []string) (returnErr error) { } // loadEncryptor sets up an encryptor for encrypting secrets. For now, it just loads a hardcoded public key -func loadEncryptor() (envelope.Encryptor, error) { - // TODO(@SgtCoDFish): this will eventually fetch a key from JWKS endpoint when that endpoint is available - key, keyID, err := rsa.LoadHardcodedPublicKey() - if err != nil { - return nil, fmt.Errorf("failed to load public key for secret encryption: %w", err) +func loadEncryptor(preflightClient client.Client) (envelope.Encryptor, error) { + cyberarkClient, ok := preflightClient.(*client.CyberArkClient) + if !ok { + return nil, fmt.Errorf("secret encryption is only supported for CyberArk clients") } - encryptor, err := rsa.NewEncryptor(keyID, key) + fetcher := keyfetch.NewClient(cyberarkClient.DiscoveryClient()) + + encryptor, err := rsa.NewEncryptor(fetcher) if err != nil { return nil, fmt.Errorf("failed to create encryptor for secret encryption: %w", err) } diff --git a/pkg/client/client_cyberark.go b/pkg/client/client_cyberark.go index 394ae144..00b05d28 100644 --- a/pkg/client/client_cyberark.go +++ b/pkg/client/client_cyberark.go @@ -28,6 +28,8 @@ import ( type CyberArkClient struct { configLoader cyberark.ClientConfigLoader httpClient *http.Client + + discoveryClient *servicediscovery.Client } var _ Client = &CyberArkClient{} @@ -41,14 +43,15 @@ var _ Client = &CyberArkClient{} func NewCyberArk(httpClient *http.Client) (*CyberArkClient, error) { configLoader := cyberark.LoadClientConfigFromEnvironment - _, err := configLoader() + cfg, err := configLoader() if err != nil { return nil, err } return &CyberArkClient{ - configLoader: configLoader, - httpClient: httpClient, + configLoader: configLoader, + httpClient: httpClient, + discoveryClient: servicediscovery.New(httpClient, cfg.Subdomain), }, nil } @@ -67,9 +70,7 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin return fmt.Errorf("failed to load config: %w", err) } - discoveryClient := servicediscovery.New(o.httpClient) - - serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(ctx, cfg.Subdomain) + serviceMap, tenantUUID, err := o.discoveryClient.DiscoverServices(ctx) if err != nil { return err } @@ -95,6 +96,10 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin return nil } +func (o *CyberArkClient) DiscoveryClient() *servicediscovery.Client { + return o.discoveryClient +} + // baseSnapshotFromOptions creates a base snapshot with common fields from the provided options. // This includes the cluster name, description, and agent version. // Other fields like ClusterID and K8SVersion need to be populated separately. @@ -102,7 +107,7 @@ func baseSnapshotFromOptions(opts Options) dataupload.Snapshot { return dataupload.Snapshot{ ClusterName: opts.ClusterName, ClusterDescription: opts.ClusterDescription, - AgentVersion: version.PreflightVersion, + AgentVersion: version.CYBRVersion, } } diff --git a/pkg/client/client_cyberark_convertdatareadings_test.go b/pkg/client/client_cyberark_convertdatareadings_test.go index a0fc2c27..d44c650a 100644 --- a/pkg/client/client_cyberark_convertdatareadings_test.go +++ b/pkg/client/client_cyberark_convertdatareadings_test.go @@ -45,7 +45,7 @@ func TestBaseSnapshotFromOptions(t *testing.T) { want: dataupload.Snapshot{ ClusterName: "some-cluster-name", ClusterDescription: "some-cluster-description", - AgentVersion: preflightversion.PreflightVersion, + AgentVersion: preflightversion.CYBRVersion, }, }, } diff --git a/pkg/client/client_cyberark_test.go b/pkg/client/client_cyberark_test.go index 9a963300..a80da178 100644 --- a/pkg/client/client_cyberark_test.go +++ b/pkg/client/client_cyberark_test.go @@ -57,7 +57,7 @@ func TestCyberArkClient_PostDataReadingsWithOptions_RealAPI(t *testing.T) { ctx := klog.NewContext(t.Context(), logger) var rootCAs *x509.CertPool - httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) + httpClient := http_client.NewDefaultClient(version.UserAgentCYBR(), rootCAs) c, err := client.NewCyberArk(httpClient) if err != nil { diff --git a/pkg/datagatherer/k8sdynamic/dynamic.go b/pkg/datagatherer/k8sdynamic/dynamic.go index df490db4..5c82d939 100644 --- a/pkg/datagatherer/k8sdynamic/dynamic.go +++ b/pkg/datagatherer/k8sdynamic/dynamic.go @@ -469,7 +469,7 @@ func (g *DataGathererDynamic) redactList(ctx context.Context, list []*api.Gather // If encryption is enabled, we encrypt the data and preserve it, but we still need to redact later. // If encryption is enabled and _fails_, we MUST still redact the data field to avoid leaking sensitive information. if g.Encryptor != nil { - err := g.encryptDataField(resource) + err := g.encryptDataField(ctx, resource) if err != nil { // WARNING: We CAN NOT return an error here, as that would leak the secret data log := klog.FromContext(ctx).WithName("encryptDataField") @@ -544,7 +544,7 @@ var encryptedDataField = FieldPath{encryptedDataFieldName} // in a new field with the name of [encryptedDataFieldName]. The original `data` field is left unchanged, on the // assumption that it will be redacted after the encryption step. // This function does not check that the given resource is actually a Secret; that is the caller's responsibility. -func (g *DataGathererDynamic) encryptDataField(secret *unstructured.Unstructured) error { +func (g *DataGathererDynamic) encryptDataField(ctx context.Context, secret *unstructured.Unstructured) error { if g.Encryptor == nil { return nil } @@ -569,7 +569,7 @@ func (g *DataGathererDynamic) encryptDataField(secret *unstructured.Unstructured return fmt.Errorf("failed to marshal secret data field for encryption: %w", err) } - encryptedData, err := g.Encryptor.Encrypt(plaintextData) + encryptedData, err := g.Encryptor.Encrypt(ctx, plaintextData) if err != nil { return fmt.Errorf("failed to encrypt secret data during redaction: %w", err) } diff --git a/pkg/datagatherer/k8sdynamic/dynamic_test.go b/pkg/datagatherer/k8sdynamic/dynamic_test.go index 335d6571..9b4651eb 100644 --- a/pkg/datagatherer/k8sdynamic/dynamic_test.go +++ b/pkg/datagatherer/k8sdynamic/dynamic_test.go @@ -1,6 +1,7 @@ package k8sdynamic import ( + "context" "crypto/rand" stdrsa "crypto/rsa" "encoding/base64" @@ -32,6 +33,7 @@ import ( "github.com/jetstack/preflight/api" "github.com/jetstack/preflight/internal/envelope" + "github.com/jetstack/preflight/internal/envelope/keyfetch" "github.com/jetstack/preflight/internal/envelope/rsa" ) @@ -405,7 +407,7 @@ func init() { type failEncryptor struct{} -func (fe *failEncryptor) Encrypt(plaintext []byte) (*envelope.EncryptedData, error) { +func (fe *failEncryptor) Encrypt(_ context.Context, plaintext []byte) (*envelope.EncryptedData, error) { return nil, fmt.Errorf("encryption failed") } @@ -415,7 +417,8 @@ func TestDynamicGatherer_Fetch(t *testing.T) { keyID := "test-key-id" - encryptor, err := rsa.NewEncryptor(keyID, privKey.Public().(*stdrsa.PublicKey)) + fetcher := keyfetch.NewFakeClientWithKey(keyID, privKey.Public().(*stdrsa.PublicKey)) + encryptor, err := rsa.NewEncryptor(fetcher) if err != nil { t.Fatalf("failed to create encryptor: %v", err) } diff --git a/pkg/version/version.go b/pkg/version/version.go index 219c5c5b..b1fdc0cb 100644 --- a/pkg/version/version.go +++ b/pkg/version/version.go @@ -31,3 +31,16 @@ func UserAgent() string { func SetUserAgent(req *http.Request) { req.Header.Set("User-Agent", UserAgent()) } + +// CYBRVersion DO NOT MERGE +var CYBRVersion = "v999.0.0" + +// UserAgentCYBR DO NOT MERGE +func UserAgentCYBR() string { + return fmt.Sprintf("Mozilla/5.0 venafi-kubernetes-agent/%s", CYBRVersion) +} + +// SetUserAgentCYBR DO NOT MERGE +func SetUserAgentCYBR(req *http.Request) { + req.Header.Set("User-Agent", UserAgentCYBR()) +}