Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions internal/toolsets/vulnerability/cluster_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package vulnerability

import (
"context"
"fmt"

v1 "github.com/stackrox/rox/generated/api/v1"
"google.golang.org/grpc"
)

// resolveClusterID resolves a cluster name to its ID.
// Returns error if cluster name is not found or if API call fails.
func resolveClusterID(ctx context.Context, conn *grpc.ClientConn,
clusterID string, clusterName string) (string, error) {
// Cluster ID has priority.
if clusterID != "" {
return clusterID, nil
}

if clusterName == "" {
return "", nil
}

client := v1.NewClustersServiceClient(conn)

// Use query to filter by cluster name server-side
query := fmt.Sprintf("Cluster:%q", clusterName)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it full or partial match? Should we do a check before returning the [0] element if name really matches?


resp, err := client.GetClusters(ctx, &v1.GetClustersRequest{
Query: query,
})
if err != nil {
return "", fmt.Errorf("failed to fetch clusters: %w", err)
}

clusters := resp.GetClusters()
if len(clusters) == 0 {
return "", fmt.Errorf("cluster with name %q not found", clusterName)
}

// Return the first matching cluster's ID
return clusters[0].GetId(), nil
}
143 changes: 143 additions & 0 deletions internal/toolsets/vulnerability/cluster_resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package vulnerability

import (
"context"
"errors"
"net"
"testing"

"github.com/stackrox/rox/generated/storage"
"github.com/stackrox/stackrox-mcp/internal/toolsets/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
)

func getBufferConnection(t *testing.T, listener *bufconn.Listener) *grpc.ClientConn {
t.Helper()

// Create a gRPC client connection to the mock server
conn, err := grpc.NewClient(
"passthrough://buffer",
grpc.WithLocalDNSResolution(),
grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) {
return listener.Dial()
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
require.NoError(t, err)

return conn
}

func TestResolveClusterID_Success(t *testing.T) {
tests := map[string]struct {
clusterID string
clusterName string
mockClusters []*storage.Cluster
mockError error
expectedID string
expectedQuery string
}{
"only cluster ID": {
clusterID: "only-cluster-id",
clusterName: "",
mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
expectedID: "only-cluster-id",
},
"cluster ID has priority": {
clusterID: "cluster-with-priority",
clusterName: "production",
mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
expectedID: "cluster-with-priority",
},
"empty cluster name returns empty ID": {
clusterID: "",
clusterName: "",
mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
expectedID: "",
},
"cluster name found returns correct ID": {
clusterID: "",
clusterName: "production",
mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
expectedID: "cluster-1",
expectedQuery: `Cluster:"production"`,
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
mockService := mock.NewClustersServiceMock(testCase.mockClusters, testCase.mockError)

grpcServer, listener := mock.SetupClusterServer(mockService)
defer grpcServer.Stop()

conn := getBufferConnection(t, listener)

defer func() { _ = conn.Close() }()

clusterID, err := resolveClusterID(
context.Background(),
conn,
testCase.clusterID,
testCase.clusterName,
)

require.NoError(t, err)
assert.Equal(t, testCase.expectedID, clusterID)
assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery())
})
}
}

func TestResolveClusterID_Failure(t *testing.T) {
tests := map[string]struct {
clusterName string
mockClusters []*storage.Cluster
mockError error
expectedErrText string
expectedQuery string
}{
"cluster name not found returns error": {
clusterName: "nonexistent",
mockClusters: []*storage.Cluster{},
expectedErrText: `cluster with name "nonexistent" not found`,
expectedQuery: `Cluster:"nonexistent"`,
},
"API error propagation": {
clusterName: "production",
mockError: errors.New("API connection failed"),
expectedErrText: "failed to fetch clusters:",
expectedQuery: `Cluster:"production"`,
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
mockService := mock.NewClustersServiceMock(testCase.mockClusters, testCase.mockError)

grpcServer, listener := mock.SetupClusterServer(mockService)
defer grpcServer.Stop()

conn := getBufferConnection(t, listener)

defer func() { _ = conn.Close() }()

clusterID, err := resolveClusterID(
context.Background(),
conn,
"",
testCase.clusterName,
)

require.Error(t, err)
assert.Empty(t, clusterID)
assert.Contains(t, err.Error(), testCase.expectedErrText)

assert.Equal(t, testCase.expectedQuery, mockService.GetLastCallQuery())
})
}
}
38 changes: 27 additions & 11 deletions internal/toolsets/vulnerability/clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@ import (

// getClustersForCVEInput defines the input parameters for get_clusters_for_cve tool.
type getClustersForCVEInput struct {
CVEName string `json:"cveName"`
FilterClusterID string `json:"filterClusterId,omitempty"`
CVEName string `json:"cveName"`
FilterClusterID string `json:"filterClusterId,omitempty"`
FilterClusterName string `json:"filterClusterName,omitempty"`
}

func (input *getClustersForCVEInput) validate() error {
if input.CVEName == "" {
return errors.New("CVE name is required")
}

if input.FilterClusterID != "" && input.FilterClusterName != "" {
return errors.New("cannot specify both filterClusterId and filterClusterName")
}

return nil
}

Expand Down Expand Up @@ -76,9 +81,7 @@ func (t *getClustersForCVETool) GetTool() *mcp.Tool {
" Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" +
" for comprehensive coverage." +
" 2) When user asks specifically about 'orchestrator', 'Kubernetes components'," +
" or 'control plane': Use ONLY this tool." +
" 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," +
" then call ONLY this tool with filterClusterId.",
" or 'control plane': Use ONLY this tool.",
InputSchema: getClustersForCVEInputSchema(),
}
}
Expand All @@ -97,11 +100,13 @@ func getClustersForCVEInputSchema() *jsonschema.Schema {

schema.Properties["cveName"].Description = "CVE name to filter clusters (e.g., CVE-2021-44228)"
schema.Properties["filterClusterId"].Description =
"Optional cluster ID (cluster ID only, not cluster name) to verify if CVE is detected in a specific cluster." +
" Only use this parameter when the user's query explicitly mentions a specific cluster name." +
" When checking if a CVE exists at all, call without this parameter to check all clusters at once." +
" To resolve cluster names to IDs, use list_clusters tool first." +
" If the cluster doesn't exist, respond that the CVE is not detected in that cluster (since it doesn't exist)."
"Optional cluster ID to verify if CVE is detected in a specific cluster." +
" Cannot be used together with filterClusterName." +
" When checking if a CVE exists at all, call without this parameter to check all clusters at once."
schema.Properties["filterClusterName"].Description =
"Optional cluster name to verify if CVE is detected in a specific cluster." +
" Cannot be used together with filterClusterId." +
" When checking if a CVE exists at all, call without this parameter to check all clusters at once."

return schema
}
Expand Down Expand Up @@ -143,7 +148,18 @@ func (t *getClustersForCVETool) handle(

clustersClient := v1.NewClustersServiceClient(conn)

query := buildClusterQuery(input)
// Resolve cluster name to ID if provided
resolvedClusterID, err := resolveClusterID(callCtx, conn, input.FilterClusterID, input.FilterClusterName)
if err != nil {
return nil, nil, err
}

// Build query using the resolved cluster ID
queryInput := getClustersForCVEInput{
CVEName: input.CVEName,
FilterClusterID: resolvedClusterID,
}
query := buildClusterQuery(queryInput)

resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{
Query: query,
Expand Down
90 changes: 90 additions & 0 deletions internal/toolsets/vulnerability/clusters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func TestGetClustersForCVETool_RegisterWith(t *testing.T) {
}

// Unit tests for input validate method.
//
//nolint:dupl // Duplication to `TestNodeInputValidate` is detected. They use different input types.
func TestClusterInputValidate(t *testing.T) {
tests := map[string]struct {
input getClustersForCVEInput
Expand All @@ -78,6 +80,29 @@ func TestClusterInputValidate(t *testing.T) {
expectError: true,
errorMsg: "CVE name is required",
},
"both cluster ID and name provided": {
input: getClustersForCVEInput{
CVEName: "CVE-2021-44228",
FilterClusterID: "cluster-123",
FilterClusterName: "production",
},
expectError: true,
errorMsg: "cannot specify both filterClusterId and filterClusterName",
},
"only cluster ID provided": {
input: getClustersForCVEInput{
CVEName: "CVE-2021-44228",
FilterClusterID: "cluster-123",
},
expectError: false,
},
"only cluster name provided": {
input: getClustersForCVEInput{
CVEName: "CVE-2021-44228",
FilterClusterName: "production",
},
expectError: false,
},
}

for testName, testCase := range tests {
Expand Down Expand Up @@ -297,3 +322,68 @@ func TestClusterHandle_WithFilters(t *testing.T) {
})
}
}

func TestClusterHandle_WithValidClusterNameFilter(t *testing.T) {
tests := map[string]struct {
clusterName string
returnedClusters []*storage.Cluster
expectedQuery string
}{
"cluster name found": {
clusterName: "production",
returnedClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-1"`,
},
"empty cluster name": {
clusterName: "",
returnedClusters: []*storage.Cluster{},
expectedQuery: `CVE:"CVE-2021-44228"`,
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
mockService := mock.NewClustersServiceMock(testCase.returnedClusters, nil)

grpcServer, listener := mock.SetupClusterServer(mockService)
defer grpcServer.Stop()

tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool)
require.True(t, ok)

input := getClustersForCVEInput{
CVEName: "CVE-2021-44228",
FilterClusterName: testCase.clusterName,
}

result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input)

require.NoError(t, err)
require.NotNil(t, output)
assert.Nil(t, result)
assert.Contains(t, mockService.GetLastCallQuery(), testCase.expectedQuery)
})
}
}

func TestClusterHandle_WithNotValidClusterNameFilter(t *testing.T) {
mockService := mock.NewClustersServiceMock([]*storage.Cluster{}, nil)

grpcServer, listener := mock.SetupClusterServer(mockService)
defer grpcServer.Stop()

tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool)
require.True(t, ok)

input := getClustersForCVEInput{
CVEName: "CVE-2021-44228",
FilterClusterName: "nonexistent",
}

result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input)

require.Error(t, err)
assert.Contains(t, err.Error(), `cluster with name "nonexistent" not found`)
assert.Nil(t, result)
assert.Nil(t, output)
}
Loading
Loading