diff --git a/internal/api/catalog_version_test.go b/internal/api/catalog_version_test.go new file mode 100644 index 0000000..b1c02ab --- /dev/null +++ b/internal/api/catalog_version_test.go @@ -0,0 +1,139 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetLatestCatalogVersion_Success(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1/license/catalog/aws/version", r.URL.Path) + assert.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "emulator_type": "aws", + "version": "4.14.0", + }) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + version, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.NoError(t, err) + assert.Equal(t, "4.14.0", version) +} + +func TestGetLatestCatalogVersion_NonOKStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + _, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.Error(t, err) + assert.Contains(t, err.Error(), "status 400") +} + +func TestGetLatestCatalogVersion_EmptyVersion(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "emulator_type": "aws", + "version": "", + }) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + _, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.Error(t, err) + assert.Contains(t, err.Error(), "incomplete catalog response") +} + +func TestGetLatestCatalogVersion_MissingVersion(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "emulator_type": "aws", + }) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + _, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.Error(t, err) + assert.Contains(t, err.Error(), "incomplete catalog response") +} + +func TestGetLatestCatalogVersion_EmptyEmulatorType(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "emulator_type": "", + "version": "4.14.0", + }) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + _, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.Error(t, err) + assert.Contains(t, err.Error(), "incomplete catalog response") +} + +func TestGetLatestCatalogVersion_MismatchedEmulatorType(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "emulator_type": "azure", + "version": "4.14.0", + }) + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + _, err := client.GetLatestCatalogVersion(context.Background(), "aws") + + require.Error(t, err) + assert.Contains(t, err.Error(), "unexpected emulator_type") +} + +func TestGetLatestCatalogVersion_Timeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // hang until request context is cancelled + <-r.Context().Done() + })) + defer srv.Close() + + client := NewPlatformClient(srv.URL) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err := client.GetLatestCatalogVersion(ctx, "aws") + + require.Error(t, err) +} + +func TestGetLatestCatalogVersion_ServerDown(t *testing.T) { + // use a URL with no server behind it + client := NewPlatformClient("http://127.0.0.1:1") + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + _, err := client.GetLatestCatalogVersion(ctx, "aws") + + require.Error(t, err) +} diff --git a/internal/api/client.go b/internal/api/client.go index b1b078a..8596a95 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -1,5 +1,7 @@ package api +//go:generate mockgen -source=client.go -destination=mock_platform_api.go -package=api + import ( "bytes" "context" @@ -7,6 +9,7 @@ import ( "fmt" "log" "net/http" + "net/url" "time" "github.com/localstack/lstk/internal/version" @@ -20,6 +23,7 @@ type PlatformAPI interface { ExchangeAuthRequest(ctx context.Context, id, exchangeToken string) (string, error) GetLicenseToken(ctx context.Context, bearerToken string) (string, error) GetLicense(ctx context.Context, req *LicenseRequest) error + GetLatestCatalogVersion(ctx context.Context, emulatorType string) (string, error) } type AuthRequest struct { @@ -238,3 +242,45 @@ func (c *PlatformClient) GetLicense(ctx context.Context, licReq *LicenseRequest) return fmt.Errorf("license request failed with status %d", resp.StatusCode) } } + +type catalogVersionResponse struct { + EmulatorType string `json:"emulator_type"` + Version string `json:"version"` +} + +func (c *PlatformClient) GetLatestCatalogVersion(ctx context.Context, emulatorType string) (string, error) { + reqURL := fmt.Sprintf("%s/v1/license/catalog/%s/version", c.baseURL, url.PathEscape(emulatorType)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to get catalog version: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Printf("failed to close response body: %v", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to get catalog version: status %d", resp.StatusCode) + } + + var versionResp catalogVersionResponse + if err := json.NewDecoder(resp.Body).Decode(&versionResp); err != nil { + return "", fmt.Errorf("failed to decode response: %w", err) + } + + if versionResp.EmulatorType == "" || versionResp.Version == "" { + return "", fmt.Errorf("incomplete catalog response: emulator_type=%q version=%q", versionResp.EmulatorType, versionResp.Version) + } + + if versionResp.EmulatorType != emulatorType { + return "", fmt.Errorf("unexpected emulator_type: got=%q want=%q", versionResp.EmulatorType, emulatorType) + } + + return versionResp.Version, nil +} diff --git a/internal/api/mock_platform_api.go b/internal/api/mock_platform_api.go new file mode 100644 index 0000000..82fa9bd --- /dev/null +++ b/internal/api/mock_platform_api.go @@ -0,0 +1,130 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client.go +// +// Generated by this command: +// +// mockgen -source=client.go -destination=mock_platform_api.go -package=api +// + +// Package api is a generated GoMock package. +package api + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockPlatformAPI is a mock of PlatformAPI interface. +type MockPlatformAPI struct { + ctrl *gomock.Controller + recorder *MockPlatformAPIMockRecorder + isgomock struct{} +} + +// MockPlatformAPIMockRecorder is the mock recorder for MockPlatformAPI. +type MockPlatformAPIMockRecorder struct { + mock *MockPlatformAPI +} + +// NewMockPlatformAPI creates a new mock instance. +func NewMockPlatformAPI(ctrl *gomock.Controller) *MockPlatformAPI { + mock := &MockPlatformAPI{ctrl: ctrl} + mock.recorder = &MockPlatformAPIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPlatformAPI) EXPECT() *MockPlatformAPIMockRecorder { + return m.recorder +} + +// CheckAuthRequestConfirmed mocks base method. +func (m *MockPlatformAPI) CheckAuthRequestConfirmed(ctx context.Context, id, exchangeToken string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckAuthRequestConfirmed", ctx, id, exchangeToken) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckAuthRequestConfirmed indicates an expected call of CheckAuthRequestConfirmed. +func (mr *MockPlatformAPIMockRecorder) CheckAuthRequestConfirmed(ctx, id, exchangeToken any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckAuthRequestConfirmed", reflect.TypeOf((*MockPlatformAPI)(nil).CheckAuthRequestConfirmed), ctx, id, exchangeToken) +} + +// CreateAuthRequest mocks base method. +func (m *MockPlatformAPI) CreateAuthRequest(ctx context.Context) (*AuthRequest, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateAuthRequest", ctx) + ret0, _ := ret[0].(*AuthRequest) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateAuthRequest indicates an expected call of CreateAuthRequest. +func (mr *MockPlatformAPIMockRecorder) CreateAuthRequest(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAuthRequest", reflect.TypeOf((*MockPlatformAPI)(nil).CreateAuthRequest), ctx) +} + +// ExchangeAuthRequest mocks base method. +func (m *MockPlatformAPI) ExchangeAuthRequest(ctx context.Context, id, exchangeToken string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExchangeAuthRequest", ctx, id, exchangeToken) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ExchangeAuthRequest indicates an expected call of ExchangeAuthRequest. +func (mr *MockPlatformAPIMockRecorder) ExchangeAuthRequest(ctx, id, exchangeToken any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthRequest", reflect.TypeOf((*MockPlatformAPI)(nil).ExchangeAuthRequest), ctx, id, exchangeToken) +} + +// GetLatestCatalogVersion mocks base method. +func (m *MockPlatformAPI) GetLatestCatalogVersion(ctx context.Context, emulatorType string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLatestCatalogVersion", ctx, emulatorType) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLatestCatalogVersion indicates an expected call of GetLatestCatalogVersion. +func (mr *MockPlatformAPIMockRecorder) GetLatestCatalogVersion(ctx, emulatorType any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestCatalogVersion", reflect.TypeOf((*MockPlatformAPI)(nil).GetLatestCatalogVersion), ctx, emulatorType) +} + +// GetLicense mocks base method. +func (m *MockPlatformAPI) GetLicense(ctx context.Context, req *LicenseRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLicense", ctx, req) + ret0, _ := ret[0].(error) + return ret0 +} + +// GetLicense indicates an expected call of GetLicense. +func (mr *MockPlatformAPIMockRecorder) GetLicense(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLicense", reflect.TypeOf((*MockPlatformAPI)(nil).GetLicense), ctx, req) +} + +// GetLicenseToken mocks base method. +func (m *MockPlatformAPI) GetLicenseToken(ctx context.Context, bearerToken string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLicenseToken", ctx, bearerToken) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLicenseToken indicates an expected call of GetLicenseToken. +func (mr *MockPlatformAPIMockRecorder) GetLicenseToken(ctx, bearerToken any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLicenseToken", reflect.TypeOf((*MockPlatformAPI)(nil).GetLicenseToken), ctx, bearerToken) +} diff --git a/internal/container/start.go b/internal/container/start.go index 60a14e2..2c6b704 100644 --- a/internal/container/start.go +++ b/internal/container/start.go @@ -7,6 +7,7 @@ import ( "os" stdruntime "runtime" "slices" + "strings" "time" "github.com/containerd/errdefs" @@ -78,13 +79,14 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts Start } env := append(resolvedEnv, "LOCALSTACK_AUTH_TOKEN="+token) containers[i] = runtime.ContainerConfig{ - Image: image, - Name: c.Name(), - Port: c.Port, - HealthPath: healthPath, - Env: env, - Tag: c.Tag, - ProductName: productName, + Image: image, + Name: c.Name(), + Port: c.Port, + HealthPath: healthPath, + Env: env, + Tag: c.Tag, + ProductName: productName, + EmulatorType: string(c.Type), } } @@ -96,13 +98,18 @@ func Start(ctx context.Context, rt runtime.Runtime, sink output.Sink, opts Start return nil } - // TODO validate license for tag "latest" without resolving the actual image version, - // and avoid pulling all images first + containers = resolveContainerVersions(ctx, opts.PlatformClient, containers) + if err := pullImages(ctx, rt, sink, containers); err != nil { return err } - if err := validateLicenses(ctx, rt, sink, opts.PlatformClient, containers, token); err != nil { + containers, err = resolveVersionsFromImages(ctx, rt, containers) + if err != nil { + return err + } + + if err := validateLicenses(ctx, sink, opts.PlatformClient, containers, token); err != nil { return err } @@ -171,9 +178,9 @@ func pullImages(ctx context.Context, rt runtime.Runtime, sink output.Sink, conta return nil } -func validateLicenses(ctx context.Context, rt runtime.Runtime, sink output.Sink, platformClient api.PlatformAPI, containers []runtime.ContainerConfig, token string) error { +func validateLicenses(ctx context.Context, sink output.Sink, platformClient api.PlatformAPI, containers []runtime.ContainerConfig, token string) error { for _, c := range containers { - if err := validateLicense(ctx, rt, sink, platformClient, c, token); err != nil { + if err := validateLicense(ctx, sink, platformClient, c, token); err != nil { return err } } @@ -234,15 +241,55 @@ func emitPortInUseError(sink output.Sink, port string) { }) } -func validateLicense(ctx context.Context, rt runtime.Runtime, sink output.Sink, platformClient api.PlatformAPI, containerConfig runtime.ContainerConfig, token string) error { - version := containerConfig.Tag - if version == "" || version == "latest" { - actualVersion, err := rt.GetImageVersion(ctx, containerConfig.Image) +// resolveContainerVersions replaces "latest" image tags with a specific version +// resolved from the catalog API, so the subsequent pull targets a pinned version. +// If the API is unreachable for a given container, its original image reference is preserved. +func resolveContainerVersions(ctx context.Context, platformClient api.PlatformAPI, containers []runtime.ContainerConfig) []runtime.ContainerConfig { + resolved := make([]runtime.ContainerConfig, len(containers)) + copy(resolved, containers) + for i, c := range resolved { + if c.Tag != "" && c.Tag != "latest" { + continue + } + + apiCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + v, err := platformClient.GetLatestCatalogVersion(apiCtx, c.EmulatorType) + cancel() + + if err != nil || v == "" { + continue + } + + resolved[i].Tag = v + if idx := strings.LastIndex(c.Image, ":"); idx != -1 { + resolved[i].Image = c.Image[:idx+1] + v + } else { + resolved[i].Image = c.Image + ":" + v + } + } + return resolved +} + +// resolveVersionsFromImages inspects pulled images to resolve any remaining "latest" tags +// that the pre-pull catalog API call could not resolve (e.g. due to network unavailability). +func resolveVersionsFromImages(ctx context.Context, rt runtime.Runtime, containers []runtime.ContainerConfig) ([]runtime.ContainerConfig, error) { + resolved := make([]runtime.ContainerConfig, len(containers)) + copy(resolved, containers) + for i, c := range resolved { + if c.Tag != "" && c.Tag != "latest" { + continue + } + v, err := rt.GetImageVersion(ctx, c.Image) if err != nil { - return fmt.Errorf("could not resolve version from image %s: %w", containerConfig.Image, err) + return nil, fmt.Errorf("could not resolve version from image %s: %w", c.Image, err) } - version = actualVersion + resolved[i].Tag = v } + return resolved, nil +} + +func validateLicense(ctx context.Context, sink output.Sink, platformClient api.PlatformAPI, containerConfig runtime.ContainerConfig, token string) error { + version := containerConfig.Tag output.EmitStatus(sink, "validating license", containerConfig.Name, version) diff --git a/internal/container/start_test.go b/internal/container/start_test.go index f422baa..e557d8c 100644 --- a/internal/container/start_test.go +++ b/internal/container/start_test.go @@ -6,6 +6,7 @@ import ( "io" "testing" + "github.com/localstack/lstk/internal/api" "github.com/localstack/lstk/internal/output" "github.com/localstack/lstk/internal/runtime" "github.com/stretchr/testify/assert" @@ -13,6 +14,91 @@ import ( "go.uber.org/mock/gomock" ) +func TestResolveContainerVersions_PinnedTagIsUnchanged(t *testing.T) { + ctrl := gomock.NewController(t) + mockPlatform := api.NewMockPlatformAPI(ctrl) + // API must not be called for pinned tags + containers := []runtime.ContainerConfig{ + {Tag: "3.8.1", Image: "localstack/localstack-pro:3.8.1", EmulatorType: "aws"}, + } + + result := resolveContainerVersions(context.Background(), mockPlatform, containers) + + assert.Equal(t, "3.8.1", result[0].Tag) + assert.Equal(t, "localstack/localstack-pro:3.8.1", result[0].Image) +} + +func TestResolveContainerVersions_ResolvesLatestToSpecificVersion(t *testing.T) { + ctrl := gomock.NewController(t) + mockPlatform := api.NewMockPlatformAPI(ctrl) + mockPlatform.EXPECT().GetLatestCatalogVersion(gomock.Any(), "aws").Return("3.8.1", nil) + containers := []runtime.ContainerConfig{ + {Tag: "latest", Image: "localstack/localstack-pro:latest", EmulatorType: "aws"}, + } + + result := resolveContainerVersions(context.Background(), mockPlatform, containers) + + assert.Equal(t, "3.8.1", result[0].Tag) + assert.Equal(t, "localstack/localstack-pro:3.8.1", result[0].Image) +} + +func TestResolveContainerVersions_KeepsLatestWhenAPIFails(t *testing.T) { + ctrl := gomock.NewController(t) + mockPlatform := api.NewMockPlatformAPI(ctrl) + mockPlatform.EXPECT().GetLatestCatalogVersion(gomock.Any(), "aws").Return("", errors.New("api down")) + containers := []runtime.ContainerConfig{ + {Tag: "latest", Image: "localstack/localstack-pro:latest", EmulatorType: "aws"}, + } + + result := resolveContainerVersions(context.Background(), mockPlatform, containers) + + assert.Equal(t, "latest", result[0].Tag) + assert.Equal(t, "localstack/localstack-pro:latest", result[0].Image) +} + +func TestResolveVersionsFromImages_PinnedTagIsUnchanged(t *testing.T) { + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + // GetImageVersion must not be called for pinned tags + containers := []runtime.ContainerConfig{ + {Tag: "3.8.1", Image: "localstack/localstack-pro:3.8.1"}, + } + + result, err := resolveVersionsFromImages(context.Background(), mockRT, containers) + + require.NoError(t, err) + assert.Equal(t, "3.8.1", result[0].Tag) +} + +func TestResolveVersionsFromImages_ResolvesLatestViaImageInspection(t *testing.T) { + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().GetImageVersion(gomock.Any(), "localstack/localstack-pro:latest").Return("3.8.1", nil) + containers := []runtime.ContainerConfig{ + {Tag: "latest", Image: "localstack/localstack-pro:latest"}, + } + + result, err := resolveVersionsFromImages(context.Background(), mockRT, containers) + + require.NoError(t, err) + assert.Equal(t, "3.8.1", result[0].Tag) +} + +func TestResolveVersionsFromImages_ReturnsErrorWhenImageInspectionFails(t *testing.T) { + ctrl := gomock.NewController(t) + mockRT := runtime.NewMockRuntime(ctrl) + mockRT.EXPECT().GetImageVersion(gomock.Any(), "localstack/localstack-pro:latest"). + Return("", errors.New("image not found")) + containers := []runtime.ContainerConfig{ + {Tag: "latest", Image: "localstack/localstack-pro:latest"}, + } + + _, err := resolveVersionsFromImages(context.Background(), mockRT, containers) + + require.Error(t, err) + assert.Contains(t, err.Error(), "image not found") +} + func TestStart_ReturnsEarlyIfRuntimeUnhealthy(t *testing.T) { ctrl := gomock.NewController(t) mockRT := runtime.NewMockRuntime(ctrl) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 7a54b5d..cfe70b2 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -8,13 +8,14 @@ import ( ) type ContainerConfig struct { - Image string - Name string - Port string - HealthPath string - Env []string // e.g., ["KEY=value", "FOO=bar"] - Tag string - ProductName string + Image string + Name string + Port string + HealthPath string + Env []string // e.g., ["KEY=value", "FOO=bar"] + Tag string + ProductName string + EmulatorType string } type PullProgress struct {