diff --git a/experimental/apps-mcp/lib/middlewares/warehouse.go b/experimental/apps-mcp/lib/middlewares/warehouse.go index 9bf1b0a071..dd56a65828 100644 --- a/experimental/apps-mcp/lib/middlewares/warehouse.go +++ b/experimental/apps-mcp/lib/middlewares/warehouse.go @@ -2,15 +2,12 @@ package middlewares import ( "context" - "errors" "fmt" - "net/url" - "sort" "sync" "github.com/databricks/cli/experimental/apps-mcp/lib/session" + "github.com/databricks/cli/libs/databrickscfg/cfgpickers" "github.com/databricks/cli/libs/env" - "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/service/sql" ) @@ -83,13 +80,14 @@ func GetWarehouseID(ctx context.Context) (string, error) { } func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { + w, err := GetDatabricksClient(ctx) + if err != nil { + return nil, fmt.Errorf("get databricks client: %w", err) + } + // first resolve DATABRICKS_WAREHOUSE_ID env variable warehouseID := env.Get(ctx, "DATABRICKS_WAREHOUSE_ID") if warehouseID != "" { - w, err := GetDatabricksClient(ctx) - if err != nil { - return nil, fmt.Errorf("get databricks client: %w", err) - } warehouse, err := w.Warehouses.Get(ctx, sql.GetWarehouseRequest{ Id: warehouseID, }) @@ -103,48 +101,5 @@ func getDefaultWarehouse(ctx context.Context) (*sql.EndpointInfo, error) { }, nil } - apiClient, err := GetApiClient(ctx) - if err != nil { - return nil, err - } - - apiPath := "/api/2.0/sql/warehouses" - params := url.Values{} - params.Add("skip_cannot_use", "true") - fullPath := fmt.Sprintf("%s?%s", apiPath, params.Encode()) - - var response sql.ListWarehousesResponse - err = apiClient.Do(ctx, "GET", fullPath, httpclient.WithResponseUnmarshal(&response)) - if err != nil { - return nil, err - } - - priorities := map[sql.State]int{ - sql.StateRunning: 1, - sql.StateStarting: 2, - sql.StateStopped: 3, - sql.StateStopping: 4, - sql.StateDeleted: 99, - sql.StateDeleting: 99, - } - - warehouses := response.Warehouses - sort.Slice(warehouses, func(i, j int) bool { - return priorities[warehouses[i].State] < priorities[warehouses[j].State] - }) - - if len(warehouses) == 0 { - return nil, errNoWarehouses() - } - - firstWarehouse := warehouses[0] - if firstWarehouse.State == sql.StateDeleted || firstWarehouse.State == sql.StateDeleting { - return nil, errNoWarehouses() - } - - return &firstWarehouse, nil -} - -func errNoWarehouses() error { - return errors.New("no warehouse found. You can explicitly set the warehouse ID using the DATABRICKS_WAREHOUSE_ID environment variable") + return cfgpickers.GetDefaultWarehouse(ctx, w) }