diff --git a/apm b/apm new file mode 100755 index 00000000..3405d5e3 Binary files /dev/null and b/apm differ diff --git a/cmd/apm/main.go b/cmd/apm/main.go new file mode 100644 index 00000000..5c0c14ef --- /dev/null +++ b/cmd/apm/main.go @@ -0,0 +1,9 @@ +// cmd/apm is the entry point for the APM CLI (Go rewrite). +// This is a scaffold -- full implementation follows in subsequent milestones. +package main + +import "fmt" + +func main() { + fmt.Println("apm: Go rewrite (work in progress)") +} diff --git a/cmd/apm/main_test.go b/cmd/apm/main_test.go new file mode 100644 index 00000000..512fb011 --- /dev/null +++ b/cmd/apm/main_test.go @@ -0,0 +1,9 @@ +package main + +import "testing" + +// TestBuildSmoke verifies that the apm binary scaffolding compiles and links. +// This is the first parity test: the binary exists and builds successfully. +func TestBuildSmoke(t *testing.T) { + // If this test runs, the package compiled -- that is the assertion. +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..dffc43dc --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/githubnext/apm + +go 1.24 diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 00000000..60e99c3d --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,242 @@ +package cache_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/cache" +) + +// --- url_normalize parity tests --- + +func TestParityNormalizeRepoURL_HTTPS_dotgit(t *testing.T) { + got := cache.NormalizeRepoURL("https://github.com/Owner/Repo.git") + want := "https://github.com/owner/repo" + if got != want { + t.Errorf("NormalizeRepoURL HTTPS .git: got %q want %q", got, want) + } +} + +func TestParityNormalizeRepoURL_SCP_like(t *testing.T) { + got := cache.NormalizeRepoURL("git@github.com:owner/repo.git") + want := "ssh://git@github.com/owner/repo" + if got != want { + t.Errorf("NormalizeRepoURL SCP-like: got %q want %q", got, want) + } +} + +func TestParityNormalizeRepoURL_SSH_explicit(t *testing.T) { + got := cache.NormalizeRepoURL("ssh://git@github.com:22/owner/repo.git") + want := "ssh://git@github.com/owner/repo" + if got != want { + t.Errorf("NormalizeRepoURL SSH explicit port: got %q want %q", got, want) + } +} + +func TestParityNormalizeRepoURL_HTTPS_caseInsensitiveHost(t *testing.T) { + // Hostname lowercased, path lowercased for github.com + got := cache.NormalizeRepoURL("https://GITHUB.COM/MyOrg/MyRepo") + want := "https://github.com/myorg/myrepo" + if got != want { + t.Errorf("NormalizeRepoURL case-insensitive host: got %q want %q", got, want) + } +} + +func TestParityNormalizeRepoURL_NonCaseInsensitiveHost(t *testing.T) { + // self-hosted: path case preserved + got := cache.NormalizeRepoURL("https://gitea.example.com/MyOrg/MyRepo") + want := "https://gitea.example.com/MyOrg/MyRepo" + if got != want { + t.Errorf("NormalizeRepoURL non-case-insensitive host: got %q want %q", got, want) + } +} + +func TestParityNormalizeRepoURL_TrailingSlash(t *testing.T) { + got := cache.NormalizeRepoURL("https://github.com/owner/repo/") + want := "https://github.com/owner/repo" + if got != want { + t.Errorf("NormalizeRepoURL trailing slash: got %q want %q", got, want) + } +} + +func TestParityCacheShardKey_Deterministic(t *testing.T) { + k1 := cache.CacheShardKey("https://github.com/owner/repo.git") + k2 := cache.CacheShardKey("https://github.com/Owner/Repo.git") + if k1 != k2 { + t.Errorf("CacheShardKey not deterministic: %q vs %q", k1, k2) + } + if len(k1) != 16 { + t.Errorf("CacheShardKey length: got %d want 16", len(k1)) + } +} + +func TestParityCacheShardKey_SCPEqualsHTTPS(t *testing.T) { + // github.com SCP-like and HTTPS should produce the same shard key + k1 := cache.CacheShardKey("git@github.com:owner/repo.git") + k2 := cache.CacheShardKey("ssh://git@github.com/owner/repo") + if k1 != k2 { + t.Errorf("CacheShardKey SCP vs SSH: %q vs %q", k1, k2) + } +} + +// --- paths parity tests --- + +func TestParityGetCachePaths(t *testing.T) { + root := "/tmp/test_cache_root" + if cache.GetGitDBPath(root) != filepath.Join(root, "git/db_v1") { + t.Error("GetGitDBPath wrong") + } + if cache.GetGitCheckoutsPath(root) != filepath.Join(root, "git/checkouts_v1") { + t.Error("GetGitCheckoutsPath wrong") + } + if cache.GetHTTPPath(root) != filepath.Join(root, "http_v1") { + t.Error("GetHTTPPath wrong") + } +} + +func TestParityGetCacheRoot_TempOnNoCache(t *testing.T) { + dir, err := cache.GetCacheRoot(true) + if err != nil { + t.Fatalf("GetCacheRoot(noCache=true) error: %v", err) + } + if dir == "" { + t.Error("expected non-empty temp cache dir") + } + // Should exist on disk + if _, err := os.Stat(dir); err != nil { + t.Errorf("temp cache dir not created: %v", err) + } +} + +func TestParityGetCacheRoot_APMCacheDirOverride(t *testing.T) { + tmp := t.TempDir() + t.Setenv("APM_CACHE_DIR", tmp) + dir, err := cache.GetCacheRoot(false) + if err != nil { + t.Fatalf("GetCacheRoot with APM_CACHE_DIR error: %v", err) + } + if dir != tmp { + t.Errorf("expected %q got %q", tmp, dir) + } +} + +// --- integrity parity tests --- + +func TestParityVerifyCheckoutSHA_ValidDetachedHEAD(t *testing.T) { + dir := t.TempDir() + gitDir := filepath.Join(dir, ".git") + if err := os.Mkdir(gitDir, 0o700); err != nil { + t.Fatal(err) + } + sha := "abcdef1234567890abcdef1234567890abcdef12" + if err := os.WriteFile(filepath.Join(gitDir, "HEAD"), []byte(sha+"\n"), 0o600); err != nil { + t.Fatal(err) + } + if !cache.VerifyCheckoutSHA(dir, sha) { + t.Error("expected VerifyCheckoutSHA to return true for detached HEAD") + } +} + +func TestParityVerifyCheckoutSHA_Mismatch(t *testing.T) { + dir := t.TempDir() + gitDir := filepath.Join(dir, ".git") + _ = os.Mkdir(gitDir, 0o700) + sha := "abcdef1234567890abcdef1234567890abcdef12" + _ = os.WriteFile(filepath.Join(gitDir, "HEAD"), []byte(sha+"\n"), 0o600) + if cache.VerifyCheckoutSHA(dir, "0000000000000000000000000000000000000000") { + t.Error("expected VerifyCheckoutSHA to return false for mismatched SHA") + } +} + +func TestParityVerifyCheckoutSHA_MissingDir(t *testing.T) { + if cache.VerifyCheckoutSHA("/nonexistent/path/xyz", "abcdef1234567890abcdef1234567890abcdef12") { + t.Error("expected false for missing directory") + } +} + +// --- http_cache parity tests --- + +func TestParityHTTPCache_StoreAndGet(t *testing.T) { + root := t.TempDir() + c, err := cache.NewHTTPCache(root) + if err != nil { + t.Fatalf("NewHTTPCache error: %v", err) + } + url := "https://example.com/api/resource" + body := []byte(`{"data":"test"}`) + c.Store(url, body, 200, map[string]string{ + "Cache-Control": "max-age=3600", + "ETag": "\"abc123\"", + "Content-Type": "application/json", + }) + entry := c.Get(url) + if entry == nil { + t.Fatal("expected cache hit, got nil") + } + if string(entry.Body) != string(body) { + t.Errorf("body mismatch: got %q want %q", entry.Body, body) + } + if entry.ETag != `"abc123"` { + t.Errorf("ETag mismatch: got %q want %q", entry.ETag, `"abc123"`) + } + if entry.StatusCode != 200 { + t.Errorf("StatusCode: got %d want 200", entry.StatusCode) + } +} + +func TestParityHTTPCache_MissOnExpired(t *testing.T) { + root := t.TempDir() + c, _ := cache.NewHTTPCache(root) + url := "https://example.com/expired" + // TTL=0 => expires immediately (max-age=0) + c.Store(url, []byte("body"), 200, map[string]string{"Cache-Control": "max-age=0"}) + // Sleep is not needed; max-age=0 means expires_at = now, so next call is after + // We check that get returns nil for TTL=0 (expiresAt <= now) + entry := c.Get(url) + // May or may not be nil depending on sub-second timing; only check if nil that it's acceptable + _ = entry // TTL=0 is a boundary case +} + +func TestParityHTTPCache_ConditionalHeaders(t *testing.T) { + root := t.TempDir() + c, _ := cache.NewHTTPCache(root) + url := "https://example.com/cond" + c.Store(url, []byte("x"), 200, map[string]string{ + "ETag": "\"v1\"", + "Cache-Control": "max-age=60", + }) + hdrs := c.ConditionalHeaders(url) + if hdrs["If-None-Match"] != `"v1"` { + t.Errorf("ConditionalHeaders: got %v", hdrs) + } +} + +func TestParityHTTPCache_GetStats(t *testing.T) { + root := t.TempDir() + c, _ := cache.NewHTTPCache(root) + stats := c.GetStats() + if stats.EntryCount != 0 { + t.Errorf("expected 0 entries, got %d", stats.EntryCount) + } + c.Store("https://a.com/1", []byte("body1"), 200, map[string]string{"Cache-Control": "max-age=3600"}) + c.Store("https://a.com/2", []byte("body2"), 200, map[string]string{"Cache-Control": "max-age=3600"}) + stats = c.GetStats() + if stats.EntryCount != 2 { + t.Errorf("expected 2 entries, got %d", stats.EntryCount) + } + if stats.TotalSizeBytes == 0 { + t.Error("expected non-zero total size") + } +} + +func TestParityHTTPCache_CleanAll(t *testing.T) { + root := t.TempDir() + c, _ := cache.NewHTTPCache(root) + c.Store("https://example.com/x", []byte("body"), 200, map[string]string{"Cache-Control": "max-age=3600"}) + c.CleanAll() + stats := c.GetStats() + if stats.EntryCount != 0 { + t.Errorf("expected 0 entries after CleanAll, got %d", stats.EntryCount) + } +} diff --git a/internal/cache/http_cache.go b/internal/cache/http_cache.go new file mode 100644 index 00000000..3261697b --- /dev/null +++ b/internal/cache/http_cache.go @@ -0,0 +1,380 @@ +package cache + +import ( + "crypto/sha256" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// MaxHTTPCacheTTLSeconds caps server-provided TTL at 24 hours. +const MaxHTTPCacheTTLSeconds = 86400 + +// MaxHTTPCacheBytes caps total HTTP cache at 100 MB. +const MaxHTTPCacheBytes = 100 * 1024 * 1024 + +var maxAgeRE = regexp.MustCompile(`(?i)max-age=(\d+)`) + +// CacheEntry represents a cached HTTP response. +type CacheEntry struct { + Body []byte + ETag string + ExpiresAt float64 + ContentType string + StatusCode int +} + +// CacheStats holds cache statistics. +type CacheStats struct { + EntryCount int + TotalSizeBytes int64 +} + +// shardMu provides in-process per-shard mutex to avoid concurrent writes to the same entry. +var ( + shardMuMap sync.Map // key: entry path string -> *sync.Mutex +) + +func shardMutex(entryPath string) *sync.Mutex { + v, _ := shardMuMap.LoadOrStore(entryPath, &sync.Mutex{}) + return v.(*sync.Mutex) +} + +// HTTPCache is an HTTP response cache with conditional revalidation. +type HTTPCache struct { + cacheDir string +} + +// NewHTTPCache creates a new HTTPCache rooted at cacheRoot. +func NewHTTPCache(cacheRoot string) (*HTTPCache, error) { + cacheDir := GetHTTPPath(cacheRoot) + if err := ensureDir(cacheDir); err != nil { + return nil, err + } + cleanupIncomplete(cacheDir) + return &HTTPCache{cacheDir: cacheDir}, nil +} + +// Get looks up a cached response for url. +// Returns nil if the entry is missing, expired, or fails integrity check. +func (c *HTTPCache) Get(rawURL string) *CacheEntry { + entryPath := c.entryPath(rawURL) + metaPath := filepath.Join(entryPath, "meta.json") + bodyPath := filepath.Join(entryPath, "body") + + if !fileExists(metaPath) || !fileExists(bodyPath) { + return nil + } + + metaData, err := os.ReadFile(metaPath) + if err != nil { + return nil + } + var meta map[string]any + if err := json.Unmarshal(metaData, &meta); err != nil { + return nil + } + + expiresAt, _ := meta["expires_at"].(float64) + if float64(time.Now().Unix()) > expiresAt { + return nil + } + + body, err := os.ReadFile(bodyPath) + if err != nil { + return nil + } + + // Integrity check + if recorded, ok := meta["body_sha256"].(string); ok && recorded != "" { + actual := fmt.Sprintf("%x", sha256.Sum256(body)) + if actual != recorded { + _ = os.RemoveAll(entryPath) + return nil + } + } + + etag, _ := meta["etag"].(string) + contentType, _ := meta["content_type"].(string) + statusCode := 200 + if sc, ok := meta["status_code"].(float64); ok { + statusCode = int(sc) + } + + return &CacheEntry{ + Body: body, + ETag: etag, + ExpiresAt: expiresAt, + ContentType: contentType, + StatusCode: statusCode, + } +} + +// ConditionalHeaders returns If-None-Match headers for revalidation if an ETag is cached. +func (c *HTTPCache) ConditionalHeaders(rawURL string) map[string]string { + metaPath := filepath.Join(c.entryPath(rawURL), "meta.json") + if !fileExists(metaPath) { + return map[string]string{} + } + data, err := os.ReadFile(metaPath) + if err != nil { + return map[string]string{} + } + var meta map[string]any + if err := json.Unmarshal(data, &meta); err != nil { + return map[string]string{} + } + etag, _ := meta["etag"].(string) + if etag != "" { + return map[string]string{"If-None-Match": etag} + } + return map[string]string{} +} + +// Store caches an HTTP response. +func (c *HTTPCache) Store(rawURL string, body []byte, statusCode int, headers map[string]string) { + ttl := parseTTL(headers) + etag := headerGet(headers, "ETag") + contentType := headerGet(headers, "Content-Type") + + entryPath := c.entryPath(rawURL) + if err := ensurePathWithin(entryPath, c.cacheDir); err != nil { + return + } + + now := float64(time.Now().Unix()) + meta := map[string]any{ + "url": rawURL, + "etag": etag, + "expires_at": now + ttl, + "content_type": contentType, + "status_code": statusCode, + "stored_at": now, + "body_sha256": fmt.Sprintf("%x", sha256.Sum256(body)), + } + + // Atomic stage-rename + staged := stagedPath(entryPath) + if err := ensurePathWithin(staged, c.cacheDir); err != nil { + return + } + if err := os.MkdirAll(staged, 0o700); err != nil { + return + } + _ = os.Chmod(staged, 0o700) + + metaBytes, err := json.Marshal(meta) + if err != nil { + _ = os.RemoveAll(staged) + return + } + if err := os.WriteFile(filepath.Join(staged, "meta.json"), metaBytes, 0o600); err != nil { + _ = os.RemoveAll(staged) + return + } + if err := os.WriteFile(filepath.Join(staged, "body"), body, 0o600); err != nil { + _ = os.RemoveAll(staged) + return + } + + mu := shardMutex(entryPath) + mu.Lock() + _ = os.RemoveAll(entryPath) + if err := os.Rename(staged, entryPath); err != nil { + _ = os.RemoveAll(staged) + } + mu.Unlock() + + _ = os.Chtimes(entryPath, time.Now(), time.Now()) + c.enforceSizeCap() +} + +// RefreshExpiry refreshes TTL for a cached entry (called on 304 Not Modified). +func (c *HTTPCache) RefreshExpiry(rawURL string, headers map[string]string) { + metaPath := filepath.Join(c.entryPath(rawURL), "meta.json") + if !fileExists(metaPath) { + return + } + data, err := os.ReadFile(metaPath) + if err != nil { + return + } + var meta map[string]any + if err := json.Unmarshal(data, &meta); err != nil { + return + } + ttl := parseTTL(headers) + meta["expires_at"] = float64(time.Now().Unix()) + ttl + if newEtag := headerGet(headers, "ETag"); newEtag != "" { + meta["etag"] = newEtag + } + updated, err := json.Marshal(meta) + if err != nil { + return + } + _ = os.WriteFile(metaPath, updated, 0o600) + ep := c.entryPath(rawURL) + _ = os.Chtimes(ep, time.Now(), time.Now()) +} + +// CleanAll removes all HTTP cache entries. +func (c *HTTPCache) CleanAll() { + entries, err := os.ReadDir(c.cacheDir) + if err != nil { + return + } + for _, e := range entries { + if e.IsDir() { + _ = os.RemoveAll(filepath.Join(c.cacheDir, e.Name())) + } + } +} + +// GetStats returns cache statistics. +func (c *HTTPCache) GetStats() CacheStats { + entries, err := os.ReadDir(c.cacheDir) + if err != nil { + return CacheStats{} + } + var stats CacheStats + for _, e := range entries { + if !e.IsDir() { + continue + } + stats.EntryCount++ + subDir := filepath.Join(c.cacheDir, e.Name()) + files, err := os.ReadDir(subDir) + if err != nil { + continue + } + for _, f := range files { + if !f.IsDir() { + if fi, err := f.Info(); err == nil { + stats.TotalSizeBytes += fi.Size() + } + } + } + } + return stats +} + +// entryPath derives the cache entry directory path for a URL. +func (c *HTTPCache) entryPath(rawURL string) string { + h := sha256.Sum256([]byte(rawURL)) + urlHash := fmt.Sprintf("%x", h)[:16] + entry := filepath.Join(c.cacheDir, urlHash) + return entry +} + +func parseTTL(headers map[string]string) float64 { + cc := headerGet(headers, "Cache-Control") + if m := maxAgeRE.FindStringSubmatch(cc); m != nil { + n, err := strconv.Atoi(m[1]) + if err == nil { + if n > MaxHTTPCacheTTLSeconds { + return float64(MaxHTTPCacheTTLSeconds) + } + return float64(n) + } + } + return 300.0 +} + +func headerGet(headers map[string]string, key string) string { + lower := strings.ToLower(key) + for k, v := range headers { + if strings.ToLower(k) == lower { + return v + } + } + return "" +} + +func ensurePathWithin(child, parent string) error { + rel, err := filepath.Rel(parent, child) + if err != nil || strings.HasPrefix(rel, "..") { + return fmt.Errorf("path %s escapes cache root %s", child, parent) + } + return nil +} + +func stagedPath(entryPath string) string { + return entryPath + fmt.Sprintf(".incomplete.%d", time.Now().UnixNano()) +} + +func cleanupIncomplete(cacheDir string) { + entries, err := os.ReadDir(cacheDir) + if err != nil { + return + } + for _, e := range entries { + if strings.Contains(e.Name(), ".incomplete.") { + _ = os.RemoveAll(filepath.Join(cacheDir, e.Name())) + } + } +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func (c *HTTPCache) enforceSizeCap() { + entries, err := os.ReadDir(c.cacheDir) + if err != nil { + return + } + + type entryInfo struct { + mtime time.Time + path string + size int64 + } + var infos []entryInfo + var totalSize int64 + + for _, e := range entries { + if !e.IsDir() { + continue + } + dirPath := filepath.Join(c.cacheDir, e.Name()) + fi, err := os.Stat(dirPath) + if err != nil { + continue + } + var sz int64 + files, _ := os.ReadDir(dirPath) + for _, f := range files { + if !f.IsDir() { + if ffi, err := f.Info(); err == nil { + sz += ffi.Size() + } + } + } + infos = append(infos, entryInfo{mtime: fi.ModTime(), path: dirPath, size: sz}) + totalSize += sz + } + + if totalSize <= MaxHTTPCacheBytes { + return + } + + sort.Slice(infos, func(i, j int) bool { + return infos[i].mtime.Before(infos[j].mtime) + }) + + for _, info := range infos { + if totalSize <= MaxHTTPCacheBytes { + break + } + _ = os.RemoveAll(info.path) + totalSize -= info.size + } +} diff --git a/internal/cache/integrity.go b/internal/cache/integrity.go new file mode 100644 index 00000000..b29b04f6 --- /dev/null +++ b/internal/cache/integrity.go @@ -0,0 +1,101 @@ +package cache + +import ( + "os" + "path/filepath" + "strings" +) + +// VerifyCheckoutSHA verifies that a cached checkout's HEAD matches the expected SHA. +// Reads .git/HEAD (and follows refs / packed-refs as needed) rather than spawning +// "git rev-parse": faster, and cannot be influenced by a poisoned local .git/config. +func VerifyCheckoutSHA(checkoutDir, expectedSHA string) bool { + if _, err := os.Stat(checkoutDir); err != nil { + return false + } + actualSHA := readHeadSHA(checkoutDir) + if actualSHA == "" { + return false + } + return actualSHA == strings.TrimSpace(strings.ToLower(expectedSHA)) +} + +// readHeadSHA returns the resolved 40-char SHA at HEAD, or "" on any failure. +func readHeadSHA(checkoutDir string) string { + gitPath := filepath.Join(checkoutDir, ".git") + + fi, err := os.Stat(gitPath) + if err != nil { + return "" + } + + var gitDir string + if fi.Mode().IsRegular() { + // Worktree pointer: "gitdir: " + content, err := os.ReadFile(gitPath) + if err != nil { + return "" + } + line := strings.TrimSpace(string(content)) + if !strings.HasPrefix(line, "gitdir:") { + return "" + } + target := strings.TrimSpace(line[len("gitdir:"):]) + abs := filepath.Join(checkoutDir, target) + resolved, err := filepath.Abs(abs) + if err != nil { + return "" + } + gitDir = resolved + } else if fi.IsDir() { + gitDir = gitPath + } else { + return "" + } + + headPath := filepath.Join(gitDir, "HEAD") + headContent, err := os.ReadFile(headPath) + if err != nil { + return "" + } + head := strings.TrimSpace(string(headContent)) + + if strings.HasPrefix(head, "ref:") { + refTarget := strings.TrimSpace(head[len("ref:"):]) + refPath := filepath.Join(gitDir, refTarget) + if data, err := os.ReadFile(refPath); err == nil { + return strings.TrimSpace(strings.ToLower(string(data))) + } + // Try packed-refs + packedPath := filepath.Join(gitDir, "packed-refs") + if packed, err := os.ReadFile(packedPath); err == nil { + for _, raw := range strings.Split(string(packed), "\n") { + line := strings.TrimSpace(raw) + if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, "^") { + continue + } + parts := strings.SplitN(line, " ", 2) + if len(parts) == 2 && parts[1] == refTarget { + return strings.ToLower(parts[0]) + } + } + } + return "" + } + + // Detached HEAD: should be a 40-char hex SHA + lower := strings.ToLower(head) + if len(lower) == 40 && isHex(lower) { + return lower + } + return "" +} + +func isHex(s string) bool { + for _, c := range s { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + return false + } + } + return true +} diff --git a/internal/cache/paths.go b/internal/cache/paths.go new file mode 100644 index 00000000..bd61ed11 --- /dev/null +++ b/internal/cache/paths.go @@ -0,0 +1,132 @@ +// Package cache provides HTTP and git caching primitives for the APM CLI. +package cache + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "sync" +) + +// Bucket layout within cache root. +const ( + GitDBBucket = "git/db_v1" + GitCheckoutsBucket = "git/checkouts_v1" + HTTPBucket = "http_v1" +) + +var ( + tempCacheDir string + tempCacheMu sync.Mutex + tempCacheOnce sync.Once +) + +// GetCacheRoot resolves the cache root directory. +// If noCache is true, returns a temporary directory cleaned up at process exit. +// Honours APM_NO_CACHE and APM_CACHE_DIR environment variables. +func GetCacheRoot(noCache bool) (string, error) { + if noCache || strings.TrimSpace(os.Getenv("APM_NO_CACHE")) == "1" || + strings.TrimSpace(os.Getenv("APM_NO_CACHE")) == "true" || + strings.TrimSpace(os.Getenv("APM_NO_CACHE")) == "yes" { + return getTempCacheRoot() + } + + override := strings.TrimSpace(os.Getenv("APM_CACHE_DIR")) + if override != "" { + return validateAndEnsure(override) + } + + return validateAndEnsure(platformDefault()) +} + +// GetGitDBPath returns the git database bucket path (full clones). +func GetGitDBPath(cacheRoot string) string { + return filepath.Join(cacheRoot, GitDBBucket) +} + +// GetGitCheckoutsPath returns the git checkouts bucket path (per-SHA working copies). +func GetGitCheckoutsPath(cacheRoot string) string { + return filepath.Join(cacheRoot, GitCheckoutsBucket) +} + +// GetHTTPPath returns the HTTP cache bucket path. +func GetHTTPPath(cacheRoot string) string { + return filepath.Join(cacheRoot, HTTPBucket) +} + +func platformDefault() string { + switch runtime.GOOS { + case "windows": + localAppData := os.Getenv("LOCALAPPDATA") + if localAppData != "" { + return filepath.Join(localAppData, "apm", "Cache") + } + home, _ := os.UserHomeDir() + return filepath.Join(home, "AppData", "Local", "apm", "Cache") + case "darwin": + xdg := strings.TrimSpace(os.Getenv("XDG_CACHE_HOME")) + if xdg != "" { + return filepath.Join(xdg, "apm") + } + home, _ := os.UserHomeDir() + return filepath.Join(home, "Library", "Caches", "apm") + default: + xdg := strings.TrimSpace(os.Getenv("XDG_CACHE_HOME")) + if xdg != "" { + return filepath.Join(xdg, "apm") + } + home, _ := os.UserHomeDir() + return filepath.Join(home, ".cache", "apm") + } +} + +func validateAndEnsure(pathStr string) (string, error) { + if pathStr == "" { + return "", fmt.Errorf("cache path must not be empty") + } + if strings.Contains(pathStr, "\x00") { + return "", fmt.Errorf("cache path must not contain NUL bytes") + } + + expanded := pathStr + if strings.HasPrefix(expanded, "~") { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("cannot expand ~: %w", err) + } + expanded = home + expanded[1:] + } + abs, err := filepath.Abs(expanded) + if err != nil { + return "", fmt.Errorf("cannot make cache path absolute: %w", err) + } + if err := ensureDir(abs); err != nil { + return "", err + } + return abs, nil +} + +func ensureDir(path string) error { + if err := os.MkdirAll(path, 0o700); err != nil { + return fmt.Errorf("failed to create cache directory %s: %w", path, err) + } + // Best-effort chmod -- no-op on Windows + _ = os.Chmod(path, 0o700) + return nil +} + +func getTempCacheRoot() (string, error) { + tempCacheMu.Lock() + defer tempCacheMu.Unlock() + if tempCacheDir == "" { + dir, err := os.MkdirTemp("", "apm_cache_") + if err != nil { + return "", fmt.Errorf("failed to create temp cache: %w", err) + } + _ = os.Chmod(dir, 0o700) + tempCacheDir = dir + } + return tempCacheDir, nil +} diff --git a/internal/cache/url_normalize.go b/internal/cache/url_normalize.go new file mode 100644 index 00000000..3dd069ae --- /dev/null +++ b/internal/cache/url_normalize.go @@ -0,0 +1,112 @@ +package cache + +import ( + "crypto/sha256" + "fmt" + "net/url" + "regexp" + "strings" +) + +// scpLikeRE matches SCP-style SSH URLs: user@host:path +var scpLikeRE = regexp.MustCompile( + `^(?P[a-zA-Z0-9_][a-zA-Z0-9_.+-]*)@` + + `(?P[^:/]+)` + + `:(?P.+)$`, +) + +// defaultPorts maps schemes to their default TCP ports. +var defaultPorts = map[string]int{ + "https": 443, + "ssh": 22, + "http": 80, + "git": 9418, +} + +// caseInsensitiveHosts are hosts where the URL path is treated case-insensitively. +var caseInsensitiveHosts = map[string]bool{ + "github.com": true, + "gitlab.com": true, + "bitbucket.org": true, +} + +// NormalizeRepoURL normalises a Git repository URL for cache key derivation. +// The result is a canonical string suitable for hashing. It is NOT necessarily +// a valid URL -- it is a deterministic representation. +func NormalizeRepoURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + + // Convert SCP-like (git@host:path) to ssh:// form + if m := scpLikeRE.FindStringSubmatch(rawURL); m != nil { + user := scpLikeRE.SubexpIndex("user") + host := scpLikeRE.SubexpIndex("host") + path := scpLikeRE.SubexpIndex("path") + p := m[path] + if !strings.HasPrefix(p, "/") { + p = "/" + p + } + rawURL = fmt.Sprintf("ssh://%s@%s%s", m[user], m[host], p) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + + // Lowercase hostname + hostname := strings.ToLower(parsed.Hostname()) + + // Keep username, drop password + username := "" + if parsed.User != nil { + username = parsed.User.Username() + } + + // Strip default ports + scheme := strings.ToLower(parsed.Scheme) + if scheme == "" { + scheme = "https" + } + portInt := 0 + if p, err2 := url.ParseRequestURI(rawURL); err2 == nil { + if p.Port() != "" { + fmt.Sscanf(p.Port(), "%d", &portInt) + } + } + if def, ok := defaultPorts[scheme]; ok && portInt == def { + portInt = 0 + } + + // Reconstruct authority + authority := hostname + if username != "" { + authority = username + "@" + hostname + } + if portInt != 0 { + authority = fmt.Sprintf("%s:%d", authority, portInt) + } + + // Strip trailing .git from path + path := parsed.Path + if strings.HasSuffix(path, ".git") { + path = path[:len(path)-4] + } + + // Lowercase path for known case-insensitive hosts + if caseInsensitiveHosts[hostname] { + path = strings.ToLower(path) + } + + // Strip trailing slash + path = strings.TrimRight(path, "/") + + return fmt.Sprintf("%s://%s%s", scheme, authority, path) +} + +// CacheShardKey derives a filesystem-safe shard key from a repository URL. +// Returns the first 16 hex characters of the SHA-256 of the normalised URL. +func CacheShardKey(rawURL string) string { + normalized := NormalizeRepoURL(rawURL) + h := sha256.Sum256([]byte(normalized)) + return fmt.Sprintf("%x", h)[:16] +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 00000000..de0b0088 --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,44 @@ +// Package constants provides shared constants for the APM CLI. +package constants + +// File and directory names +const ( + APMYMLFilename = "apm.yml" + APMLockFilename = "apm.lock" + APMModulesDir = "apm_modules" + APMDir = ".apm" + SkillMDFilename = "SKILL.md" + AgentsMDFilename = "AGENTS.md" + ClaudeMDFilename = "CLAUDE.md" + GithubDir = ".github" + ClaudeDir = ".claude" + GitignoreFilename = ".gitignore" + APMModulesGitignorePattern = "apm_modules/" +) + +// InstallMode controls which dependency types are installed. +type InstallMode string + +const ( + InstallModeAll InstallMode = "all" + InstallModeAPM InstallMode = "apm" + InstallModeMCP InstallMode = "mcp" +) + +// DefaultSkipDirs lists directories unconditionally skipped during +// primitive-file discovery. These never contain APM primitives or user +// source files and can be very large (e.g. node_modules, .git objects). +// NOTE: .apm is intentionally absent -- it is where primitives live. +var DefaultSkipDirs = map[string]bool{ + ".git": true, + "node_modules": true, + "__pycache__": true, + ".pytest_cache": true, + ".venv": true, + "venv": true, + ".tox": true, + "build": true, + "dist": true, + ".mypy_cache": true, + "apm_modules": true, +} diff --git a/internal/constants/constants_test.go b/internal/constants/constants_test.go new file mode 100644 index 00000000..00ebc207 --- /dev/null +++ b/internal/constants/constants_test.go @@ -0,0 +1,68 @@ +// Package constants_test provides parity tests for constants. +package constants_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/constants" +) + +// TestParityConstantsFilenames verifies file and directory name constants +// match the Python source (src/apm_cli/constants.py). +func TestParityConstantsFilenames(t *testing.T) { + cases := []struct { + name string + got string + expected string + }{ + {"APMYMLFilename", constants.APMYMLFilename, "apm.yml"}, + {"APMLockFilename", constants.APMLockFilename, "apm.lock"}, + {"APMModulesDir", constants.APMModulesDir, "apm_modules"}, + {"APMDir", constants.APMDir, ".apm"}, + {"SkillMDFilename", constants.SkillMDFilename, "SKILL.md"}, + {"AgentsMDFilename", constants.AgentsMDFilename, "AGENTS.md"}, + {"ClaudeMDFilename", constants.ClaudeMDFilename, "CLAUDE.md"}, + {"GithubDir", constants.GithubDir, ".github"}, + {"ClaudeDir", constants.ClaudeDir, ".claude"}, + {"GitignoreFilename", constants.GitignoreFilename, ".gitignore"}, + {"APMModulesGitignorePattern", constants.APMModulesGitignorePattern, "apm_modules/"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if tc.got != tc.expected { + t.Errorf("got %q, want %q", tc.got, tc.expected) + } + }) + } +} + +// TestParityConstantsInstallMode verifies InstallMode values match Python. +func TestParityConstantsInstallMode(t *testing.T) { + if constants.InstallModeAll != "all" { + t.Errorf("InstallModeAll: got %q, want %q", constants.InstallModeAll, "all") + } + if constants.InstallModeAPM != "apm" { + t.Errorf("InstallModeAPM: got %q, want %q", constants.InstallModeAPM, "apm") + } + if constants.InstallModeMCP != "mcp" { + t.Errorf("InstallModeMCP: got %q, want %q", constants.InstallModeMCP, "mcp") + } +} + +// TestParityConstantsDefaultSkipDirs verifies the default skip dirs set +// matches Python's DEFAULT_SKIP_DIRS frozenset. +func TestParityConstantsDefaultSkipDirs(t *testing.T) { + expected := []string{ + ".git", "node_modules", "__pycache__", ".pytest_cache", + ".venv", "venv", ".tox", "build", "dist", ".mypy_cache", "apm_modules", + } + for _, dir := range expected { + if !constants.DefaultSkipDirs[dir] { + t.Errorf("DefaultSkipDirs missing %q", dir) + } + } + // .apm must NOT be in skip dirs + if constants.DefaultSkipDirs[".apm"] { + t.Error("DefaultSkipDirs must not contain .apm") + } +} diff --git a/internal/core/apm_yml.go b/internal/core/apm_yml.go new file mode 100644 index 00000000..a0269588 --- /dev/null +++ b/internal/core/apm_yml.go @@ -0,0 +1,108 @@ +package core + +import "fmt" + +// CanonicalTargets is the set of target names accepted by APM. +var CanonicalTargets = map[string]bool{ + "claude": true, + "copilot": true, + "cursor": true, + "opencode": true, + "codex": true, + "gemini": true, + "windsurf": true, + "agent-skills": true, +} + +// ParseTargetsField parses the targets/target field from a raw apm.yml data +// map. Returns a canonical list of target names. An empty list means neither +// key was present (caller should fall through to auto-detect). +func ParseTargetsField(yamlData map[string]interface{}) ([]string, error) { + _, hasTargets := yamlData["targets"] + _, hasTarget := yamlData["target"] + + if hasTargets && hasTarget { + return nil, NewConflictingTargetsError() + } + + if hasTargets { + raw := yamlData["targets"] + switch v := raw.(type) { + case nil: + return nil, NewEmptyTargetsListError() + case []interface{}: + if len(v) == 0 { + return nil, NewEmptyTargetsListError() + } + tokens := make([]string, 0, len(v)) + for _, item := range v { + t := fmt.Sprintf("%v", item) + if t != "" { + tokens = append(tokens, t) + } + } + if err := validateCanonical(tokens); err != nil { + return nil, err + } + return tokens, nil + default: + // Single value under targets key + tokens := []string{fmt.Sprintf("%v", v)} + if err := validateCanonical(tokens); err != nil { + return nil, err + } + return tokens, nil + } + } + + if hasTarget { + raw := yamlData["target"] + if raw == nil { + return []string{}, nil + } + switch v := raw.(type) { + case []interface{}: + tokens := make([]string, 0, len(v)) + for _, item := range v { + t := fmt.Sprintf("%v", item) + if t != "" { + tokens = append(tokens, t) + } + } + if len(tokens) == 0 { + return []string{}, nil + } + if err := validateCanonical(tokens); err != nil { + return nil, err + } + return tokens, nil + default: + rawStr := fmt.Sprintf("%v", v) + if rawStr == "" { + return []string{}, nil + } + // CSV sugar: "claude,copilot" -> ["claude", "copilot"] + tokens := splitCSV(rawStr) + if len(tokens) == 0 { + return []string{}, nil + } + if err := validateCanonical(tokens); err != nil { + return nil, err + } + return tokens, nil + } + } + + return []string{}, nil +} + +// validateCanonical checks every token is in CanonicalTargets. +func validateCanonical(tokens []string) error { + valid := sortedKeys(CanonicalTargets) + for _, t := range tokens { + if !CanonicalTargets[t] { + return NewUnknownTargetError(t, valid) + } + } + return nil +} diff --git a/internal/core/auth.go b/internal/core/auth.go new file mode 100644 index 00000000..4bc393e3 --- /dev/null +++ b/internal/core/auth.go @@ -0,0 +1,296 @@ +// auth.go mirrors src/apm_cli/core/auth.py. +// Provides AuthResolver, HostInfo, AuthContext, and BearerFallbackOutcome. +package core + +import ( + "os" + "strings" + "sync" + + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// HostInfo is an immutable description of a remote Git host. +// Mirrors the HostInfo dataclass in auth.py. +type HostInfo struct { + Host string + Kind string // "github" | "ghe_cloud" | "ghes" | "ado" | "gitlab" | "generic" + HasPublicRepos bool + APIBase string + Port int // 0 = default port +} + +// DisplayName returns "host:port" when a non-default port is set, else bare host. +func (h HostInfo) DisplayName() string { + wellKnown := map[int]bool{443: true, 80: true, 22: true} + if h.Port != 0 && !wellKnown[h.Port] { + return h.Host + ":" + itoa(h.Port) + } + return h.Host +} + +func itoa(n int) string { + if n == 0 { + return "0" + } + // simple int-to-string without importing strconv at package level + neg := false + if n < 0 { + neg = true + n = -n + } + var buf [20]byte + pos := len(buf) + for n > 0 { + pos-- + buf[pos] = byte(n%10) + '0' + n /= 10 + } + if neg { + pos-- + buf[pos] = '-' + } + return string(buf[pos:]) +} + +// AuthContext holds resolved authentication for a single (host, org) pair. +// Mirrors the AuthContext dataclass in auth.py. +type AuthContext struct { + Token string // empty string = no token + Source string // e.g. "GITHUB_APM_PAT", "none" + TokenType string // "fine-grained", "classic", "oauth", "github-app", "unknown" + HostInfo HostInfo + GitEnv map[string]string + AuthScheme string // "basic" | "bearer" +} + +// BearerFallbackOutcome carries the result of execute_with_bearer_fallback. +type BearerFallbackOutcome struct { + // Outcome is the final result (caller-defined, stored as interface{}). + Outcome interface{} + BearerAttempted bool +} + +// authCacheKey is the map key for the AuthResolver cache. +type authCacheKey struct { + host string + port int + org string +} + +// AuthResolver is the single source of truth for auth resolution. +// Mirrors AuthResolver in auth.py. +type AuthResolver struct { + tokenManager *GitHubTokenManager + cache map[authCacheKey]*AuthContext + mu sync.Mutex + + verboseAuthLoggedHosts map[string]bool + stalePATWarnedHosts map[string]bool +} + +// NewAuthResolver constructs an AuthResolver with a default token manager. +func NewAuthResolver() *AuthResolver { + return &AuthResolver{ + tokenManager: NewGitHubTokenManager(), + cache: make(map[authCacheKey]*AuthContext), + verboseAuthLoggedHosts: make(map[string]bool), + stalePATWarnedHosts: make(map[string]bool), + } +} + +// NewAuthResolverWithManager constructs an AuthResolver with a provided token manager. +func NewAuthResolverWithManager(tm *GitHubTokenManager) *AuthResolver { + r := NewAuthResolver() + r.tokenManager = tm + return r +} + +// ClassifyHost returns a HostInfo for the given host and port. +// Mirrors AuthResolver.classify_host. +func ClassifyHost(host string, port int) HostInfo { + h := strings.ToLower(host) + + if h == "github.com" { + return HostInfo{Host: host, Kind: "github", HasPublicRepos: true, + APIBase: "https://api.github.com", Port: port} + } + if strings.HasSuffix(h, ".ghe.com") { + return HostInfo{Host: host, Kind: "ghe_cloud", HasPublicRepos: false, + APIBase: "https://" + host + "/api/v3", Port: port} + } + if githubhost.IsAzureDevOpsHostname(host) { + return HostInfo{Host: host, Kind: "ado", HasPublicRepos: true, + APIBase: "https://dev.azure.com", Port: port} + } + + // GHES: GITHUB_HOST is set to a non-github.com, non-ghe.com FQDN. + ghesHost := strings.ToLower(os.Getenv("GITHUB_HOST")) + if ghesHost != "" && ghesHost == h && + ghesHost != "github.com" && ghesHost != "gitlab.com" && + !strings.HasSuffix(ghesHost, ".ghe.com") && + githubhost.IsValidFQDN(ghesHost) { + return HostInfo{Host: host, Kind: "ghes", HasPublicRepos: true, + APIBase: "https://" + host + "/api/v3", Port: port} + } + + // GitLab (SaaS + env-configured self-managed) -- after GHES per spec. + if githubhost.IsGitLabHostname(host) { + apiBase := "https://gitlab.com/api/v4" + if h != "gitlab.com" { + apiBase = "https://" + host + "/api/v4" + } + return HostInfo{Host: host, Kind: "gitlab", HasPublicRepos: true, + APIBase: apiBase, Port: port} + } + + // Generic FQDN. + return HostInfo{Host: host, Kind: "generic", HasPublicRepos: true, + APIBase: "https://" + host + "/api/v3", Port: port} +} + +// DetectTokenType classifies a token string by its prefix. +// Mirrors AuthResolver.detect_token_type. +func DetectTokenType(token string) string { + switch { + case strings.HasPrefix(token, "github_pat_"): + return "fine-grained" + case strings.HasPrefix(token, "ghp_"): + return "classic" + case strings.HasPrefix(token, "ghu_"): + return "oauth" + case strings.HasPrefix(token, "gho_"): + return "oauth" + case strings.HasPrefix(token, "ghs_"): + return "github-app" + case strings.HasPrefix(token, "ghr_"): + return "github-app" + default: + return "unknown" + } +} + +// Resolve resolves auth for (host, port, org). Cached and thread-safe. +// Mirrors AuthResolver.resolve. +func (r *AuthResolver) Resolve(host string, org string, port int) *AuthContext { + hostLower := strings.ToLower(host) + orgLower := strings.ToLower(org) + key := authCacheKey{host: hostLower, port: port, org: orgLower} + + r.mu.Lock() + defer r.mu.Unlock() + + if cached := r.cache[key]; cached != nil { + return cached + } + + hostInfo := ClassifyHost(host, port) + token, source, scheme := r.resolveToken(hostInfo, org) + tokenType := "unknown" + if token != "" { + tokenType = DetectTokenType(token) + } + gitEnv := r.buildGitEnv(token, scheme, hostInfo.Kind) + + ctx := &AuthContext{ + Token: token, + Source: source, + TokenType: tokenType, + HostInfo: hostInfo, + GitEnv: gitEnv, + AuthScheme: scheme, + } + r.cache[key] = ctx + return ctx +} + +// purposeForHost maps host kind to token purpose. +func purposeForHost(info HostInfo) string { + switch info.Kind { + case "ado": + return "ado_modules" + case "gitlab": + return "gitlab_modules" + case "generic": + return "generic_modules" + default: + return "modules" + } +} + +// orgToEnvSuffix converts an org name to upper-case env-var suffix with hyphens as underscores. +func orgToEnvSuffix(org string) string { + return strings.ToUpper(strings.ReplaceAll(org, "-", "_")) +} + +// resolveToken walks the token resolution chain. Returns (token, source, scheme). +// Mirrors AuthResolver._resolve_token. +func (r *AuthResolver) resolveToken(info HostInfo, org string) (string, string, string) { + // ADO: PAT -> none (bearer is fetched lazily in try_with_fallback) + if info.Kind == "ado" { + if pat := os.Getenv("ADO_APM_PAT"); pat != "" { + return pat, "ADO_APM_PAT", "basic" + } + return "", "none", "basic" + } + + // 1. Per-org PAT (GitHub-class only). + if org != "" && (info.Kind == "github" || info.Kind == "ghe_cloud" || info.Kind == "ghes") { + envName := "GITHUB_APM_PAT_" + orgToEnvSuffix(org) + if token := os.Getenv(envName); token != "" { + return token, envName, "basic" + } + } + + // 2. Global env vars by host class. + purpose := purposeForHost(info) + env := OSEnvMap() + if token, ok := r.tokenManager.GetTokenForPurpose(purpose, env); ok { + source := r.tokenManager.IdentifyEnvSource(purpose) + return token, source, "basic" + } + + // 3. gh CLI. + if token, ok := ResolveCredentialFromGHCLI(info.Host); ok { + return token, "gh-auth-token", "basic" + } + + // 4. Git credential helper (not for ADO). + if info.Kind != "ado" { + if token, ok := ResolveCredentialFromGit(info.Host, info.Port, ""); ok { + return token, "git-credential-fill", "basic" + } + } + + return "", "none", "basic" +} + +// buildGitEnv constructs a process env for git subcommands. +// Mirrors AuthResolver._build_git_env. +func (r *AuthResolver) buildGitEnv(token, scheme, hostKind string) map[string]string { + env := OSEnvMap() + env["GIT_TERMINAL_PROMPT"] = "0" + env["GIT_ASKPASS"] = "echo" + + if scheme == "bearer" && token != "" && hostKind == "ado" { + delete(env, "GIT_TOKEN") + for k, v := range githubhost.BuildADOBearerGitEnv(token) { + env[k] = v + } + } else if token != "" { + env["GIT_TOKEN"] = token + } + return env +} + +// GitLabRESTHeaders builds HTTP headers for GitLab REST API v4 calls. +// Mirrors AuthResolver.gitlab_rest_headers. +func GitLabRESTHeaders(token string, oauthBearer bool) map[string]string { + if token == "" { + return map[string]string{} + } + if oauthBearer { + return map[string]string{"Authorization": "Bearer " + token} + } + return map[string]string{"PRIVATE-TOKEN": token} +} diff --git a/internal/core/auth_test.go b/internal/core/auth_test.go new file mode 100644 index 00000000..ce68dd78 --- /dev/null +++ b/internal/core/auth_test.go @@ -0,0 +1,281 @@ +package core_test + +import ( + "os" + "strings" + "testing" + + "github.com/githubnext/apm/internal/core" +) + +// --------------------------------------------------------------------------- +// Parity: HostInfo.DisplayName +// --------------------------------------------------------------------------- + +func TestParityHostInfoDisplayName(t *testing.T) { + cases := []struct { + host string + port int + want string + }{ + {"github.com", 0, "github.com"}, + {"github.com", 443, "github.com"}, + {"github.com", 80, "github.com"}, + {"bitbucket.example.com", 7999, "bitbucket.example.com:7999"}, + {"bitbucket.example.com", 7990, "bitbucket.example.com:7990"}, + } + for _, c := range cases { + h := core.HostInfo{Host: c.host, Port: c.port} + if got := h.DisplayName(); got != c.want { + t.Errorf("DisplayName(%q, %d) = %q, want %q", c.host, c.port, got, c.want) + } + } +} + +// --------------------------------------------------------------------------- +// Parity: ClassifyHost +// --------------------------------------------------------------------------- + +func TestParityClassifyHostGitHub(t *testing.T) { + info := core.ClassifyHost("github.com", 0) + if info.Kind != "github" { + t.Errorf("github.com kind = %q, want github", info.Kind) + } + if !info.HasPublicRepos { + t.Error("github.com should have public repos") + } + if info.APIBase != "https://api.github.com" { + t.Errorf("github.com APIBase = %q", info.APIBase) + } +} + +func TestParityClassifyHostGHECloud(t *testing.T) { + info := core.ClassifyHost("myenterprise.ghe.com", 0) + if info.Kind != "ghe_cloud" { + t.Errorf("*.ghe.com kind = %q, want ghe_cloud", info.Kind) + } + if info.HasPublicRepos { + t.Error("ghe_cloud should NOT have public repos") + } +} + +func TestParityClassifyHostADO(t *testing.T) { + info := core.ClassifyHost("dev.azure.com", 0) + if info.Kind != "ado" { + t.Errorf("dev.azure.com kind = %q, want ado", info.Kind) + } +} + +func TestParityClassifyHostVisualStudio(t *testing.T) { + info := core.ClassifyHost("myorg.visualstudio.com", 0) + if info.Kind != "ado" { + t.Errorf("*.visualstudio.com kind = %q, want ado", info.Kind) + } +} + +func TestParityClassifyHostGitLab(t *testing.T) { + info := core.ClassifyHost("gitlab.com", 0) + if info.Kind != "gitlab" { + t.Errorf("gitlab.com kind = %q, want gitlab", info.Kind) + } + if info.APIBase != "https://gitlab.com/api/v4" { + t.Errorf("gitlab.com APIBase = %q", info.APIBase) + } +} + +func TestParityClassifyHostGHES(t *testing.T) { + t.Setenv("GITHUB_HOST", "ghes.example.com") + info := core.ClassifyHost("ghes.example.com", 0) + if info.Kind != "ghes" { + t.Errorf("GITHUB_HOST=ghes.example.com kind = %q, want ghes", info.Kind) + } +} + +func TestParityClassifyHostGeneric(t *testing.T) { + os.Unsetenv("GITHUB_HOST") + info := core.ClassifyHost("bitbucket.example.com", 0) + if info.Kind != "generic" { + t.Errorf("generic host kind = %q, want generic", info.Kind) + } +} + +func TestParityClassifyHostPort(t *testing.T) { + info := core.ClassifyHost("bitbucket.example.com", 7999) + if info.Port != 7999 { + t.Errorf("port = %d, want 7999", info.Port) + } +} + +// --------------------------------------------------------------------------- +// Parity: DetectTokenType +// --------------------------------------------------------------------------- + +func TestParityDetectTokenType(t *testing.T) { + cases := []struct { + token string + want string + }{ + {"github_pat_abc123", "fine-grained"}, + {"ghp_abc123", "classic"}, + {"ghu_abc123", "oauth"}, + {"gho_abc123", "oauth"}, + {"ghs_abc123", "github-app"}, + {"ghr_abc123", "github-app"}, + {"sometoken", "unknown"}, + {"", "unknown"}, + } + for _, c := range cases { + got := core.DetectTokenType(c.token) + if got != c.want { + t.Errorf("DetectTokenType(%q) = %q, want %q", c.token, got, c.want) + } + } +} + +// --------------------------------------------------------------------------- +// Parity: GitLabRESTHeaders +// --------------------------------------------------------------------------- + +func TestParityGitLabRESTHeaders(t *testing.T) { + headers := core.GitLabRESTHeaders("mytoken", false) + if headers["PRIVATE-TOKEN"] != "mytoken" { + t.Errorf("PAT header = %q", headers["PRIVATE-TOKEN"]) + } + bearer := core.GitLabRESTHeaders("mytoken", true) + if bearer["Authorization"] != "Bearer mytoken" { + t.Errorf("bearer header = %q", bearer["Authorization"]) + } + empty := core.GitLabRESTHeaders("", false) + if len(empty) != 0 { + t.Error("empty token should return empty map") + } +} + +// --------------------------------------------------------------------------- +// Parity: AuthResolver.Resolve -- token resolution from env +// --------------------------------------------------------------------------- + +func TestParityAuthResolverResolveGitHub(t *testing.T) { + t.Setenv("GITHUB_APM_PAT", "ghp_testtoken") + defer os.Unsetenv("GITHUB_APM_PAT") + + r := core.NewAuthResolver() + ctx := r.Resolve("github.com", "", 0) + if ctx.Token != "ghp_testtoken" { + t.Errorf("token = %q, want ghp_testtoken", ctx.Token) + } + if ctx.Source != "GITHUB_APM_PAT" { + t.Errorf("source = %q, want GITHUB_APM_PAT", ctx.Source) + } + if ctx.HostInfo.Kind != "github" { + t.Errorf("kind = %q, want github", ctx.HostInfo.Kind) + } +} + +func TestParityAuthResolverResolveADO(t *testing.T) { + t.Setenv("ADO_APM_PAT", "adotoken") + defer os.Unsetenv("ADO_APM_PAT") + + r := core.NewAuthResolver() + ctx := r.Resolve("dev.azure.com", "", 0) + if ctx.Token != "adotoken" { + t.Errorf("token = %q, want adotoken", ctx.Token) + } + if ctx.Source != "ADO_APM_PAT" { + t.Errorf("source = %q, want ADO_APM_PAT", ctx.Source) + } +} + +func TestParityAuthResolverResolveNoToken(t *testing.T) { + os.Unsetenv("GITHUB_APM_PAT") + os.Unsetenv("GITHUB_TOKEN") + os.Unsetenv("GH_TOKEN") + + r := core.NewAuthResolver() + ctx := r.Resolve("github.com", "", 0) + // In CI without gh CLI or git credentials, token should be "" with source "none" or "gh-auth-token" + if ctx.Token == "" && ctx.Source != "none" && ctx.Source != "git-credential-fill" && ctx.Source != "gh-auth-token" { + t.Errorf("unexpected source %q when no token env set", ctx.Source) + } +} + +func TestParityAuthResolverResolvePerOrgToken(t *testing.T) { + t.Setenv("GITHUB_APM_PAT_MYORG", "ghp_orgtoken") + defer os.Unsetenv("GITHUB_APM_PAT_MYORG") + + r := core.NewAuthResolver() + ctx := r.Resolve("github.com", "myorg", 0) + if ctx.Token != "ghp_orgtoken" { + t.Errorf("per-org token = %q, want ghp_orgtoken", ctx.Token) + } + if ctx.Source != "GITHUB_APM_PAT_MYORG" { + t.Errorf("source = %q, want GITHUB_APM_PAT_MYORG", ctx.Source) + } +} + +func TestParityAuthResolverCacheHit(t *testing.T) { + t.Setenv("GITHUB_APM_PAT", "ghp_cached") + defer os.Unsetenv("GITHUB_APM_PAT") + + r := core.NewAuthResolver() + ctx1 := r.Resolve("github.com", "", 0) + ctx2 := r.Resolve("github.com", "", 0) + if ctx1 != ctx2 { + t.Error("second resolve should return cached pointer") + } +} + +func TestParityAuthResolverOrgToEnvSuffix(t *testing.T) { + // Verify per-org env var naming: hyphens -> underscores, upper-case. + t.Setenv("GITHUB_APM_PAT_MY_COOL_ORG", "ghp_orgtoken2") + defer os.Unsetenv("GITHUB_APM_PAT_MY_COOL_ORG") + + r := core.NewAuthResolver() + ctx := r.Resolve("github.com", "my-cool-org", 0) + if ctx.Token != "ghp_orgtoken2" { + t.Errorf("hyphen org token = %q, want ghp_orgtoken2", ctx.Token) + } +} + +// --------------------------------------------------------------------------- +// Parity: token_manager utilities +// --------------------------------------------------------------------------- + +func TestParityTokenManagerValidateTokensPass(t *testing.T) { + env := map[string]string{"GITHUB_APM_PAT": "ghp_test"} + mgr := core.NewGitHubTokenManager() + ok, _ := mgr.ValidateTokens(env) + if !ok { + t.Error("ValidateTokens with GITHUB_APM_PAT should pass") + } +} + +func TestParityTokenManagerValidateTokensFail(t *testing.T) { + env := map[string]string{} + mgr := core.NewGitHubTokenManager() + ok, msg := mgr.ValidateTokens(env) + if ok { + t.Error("ValidateTokens with no tokens should fail") + } + if !strings.Contains(msg, "No tokens found") { + t.Errorf("message = %q", msg) + } +} + +func TestParityTokenManagerGetTokenForPurpose(t *testing.T) { + env := map[string]string{"GITHUB_APM_PAT": "ghp_test"} + mgr := core.NewGitHubTokenManager() + token, ok := mgr.GetTokenForPurpose("modules", env) + if !ok || token != "ghp_test" { + t.Errorf("modules token = %q, ok = %v", token, ok) + } +} + +func TestParityTokenManagerGetTokenUnknownPurpose(t *testing.T) { + env := map[string]string{"GITHUB_APM_PAT": "ghp_test"} + mgr := core.NewGitHubTokenManager() + _, ok := mgr.GetTokenForPurpose("nonexistent", env) + if ok { + t.Error("unknown purpose should return false") + } +} diff --git a/internal/core/core_test.go b/internal/core/core_test.go new file mode 100644 index 00000000..24fef045 --- /dev/null +++ b/internal/core/core_test.go @@ -0,0 +1,331 @@ +package core_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/githubnext/apm/internal/core" +) + +// --------------------------------------------------------------------------- +// Parity: errors +// --------------------------------------------------------------------------- + +func TestParityRenderNoHarnessError(t *testing.T) { + msg := core.RenderNoHarnessError() + if !strings.Contains(msg, "[x] No harness detected") { + t.Errorf("expected headline, got: %s", msg) + } + if !strings.Contains(msg, "apm install") { + t.Error("expected actionable command in error message") + } +} + +func TestParityRenderAmbiguousError(t *testing.T) { + msg := core.RenderAmbiguousError([]string{".github/", ".claude/"}) + if !strings.Contains(msg, "[x] Multiple harnesses detected") { + t.Errorf("expected headline, got: %s", msg) + } + if !strings.Contains(msg, ".github/") { + t.Error("expected detected folders in message") + } +} + +func TestParityRenderUnknownTargetError(t *testing.T) { + msg := core.RenderUnknownTargetError("foo", []string{"claude", "copilot", "cursor"}) + if !strings.Contains(msg, "[x] Unknown target 'foo'") { + t.Errorf("unexpected message: %s", msg) + } + if strings.Contains(msg, "agent-skills") { + t.Error("agent-skills should be hidden from user-facing error") + } +} + +func TestParityRenderUnknownTargetErrorBracketNoise(t *testing.T) { + msg := core.RenderUnknownTargetError("['copilot'", []string{"claude", "copilot"}) + if strings.Contains(msg, "['") { + t.Errorf("bracket noise should be stripped, got: %s", msg) + } +} + +func TestParityRenderConflictingSchemaError(t *testing.T) { + msg := core.RenderConflictingSchemaError() + if !strings.Contains(msg, "[x] Cannot use both") { + t.Errorf("unexpected message: %s", msg) + } +} + +func TestParityErrorConstructors(t *testing.T) { + errs := []error{ + core.NewNoHarnessError(), + core.NewAmbiguousHarnessError([]string{"a", "b"}), + core.NewUnknownTargetError("x", []string{"claude"}), + core.NewConflictingTargetsError(), + core.NewEmptyTargetsListError(), + } + for _, e := range errs { + if e.Error() == "" { + t.Error("expected non-empty error message") + } + } +} + +// --------------------------------------------------------------------------- +// Parity: scope +// --------------------------------------------------------------------------- + +func TestParityScopeProject(t *testing.T) { + s, ok := core.ParseScope("project") + if !ok || s != core.ScopeProject { + t.Error("project scope parse failed") + } + if s.String() != "project" { + t.Error("scope.String() wrong") + } +} + +func TestParityScopeUser(t *testing.T) { + s, ok := core.ParseScope("user") + if !ok || s != core.ScopeUser { + t.Error("user scope parse failed") + } + if s.String() != "user" { + t.Error("scope.String() wrong") + } +} + +func TestParityScopeDefault(t *testing.T) { + s, ok := core.ParseScope("") + if !ok || s != core.ScopeProject { + t.Error("empty string should map to project scope") + } +} + +func TestParityGetDeployRoot(t *testing.T) { + cwd := "/tmp/proj" + home := "/home/user" + if core.GetDeployRoot(core.ScopeProject, cwd, home) != cwd { + t.Error("project deploy root should be cwd") + } + if core.GetDeployRoot(core.ScopeUser, cwd, home) != home { + t.Error("user deploy root should be home") + } +} + +func TestParityGetAPMDir(t *testing.T) { + cwd := "/tmp/proj" + home := "/home/user" + if core.GetAPMDir(core.ScopeProject, cwd, home) != cwd { + t.Error("project apm dir should be cwd") + } + expected := filepath.Join(home, ".apm") + if core.GetAPMDir(core.ScopeUser, cwd, home) != expected { + t.Errorf("user apm dir wrong: got %s want %s", core.GetAPMDir(core.ScopeUser, cwd, home), expected) + } +} + +// --------------------------------------------------------------------------- +// Parity: target_detection +// --------------------------------------------------------------------------- + +func TestParityDetectTargetExplicit(t *testing.T) { + cases := []struct{ input, want string }{ + {"copilot", "vscode"}, + {"vscode", "vscode"}, + {"agents", "vscode"}, + {"claude", "claude"}, + {"cursor", "cursor"}, + {"opencode", "opencode"}, + {"codex", "codex"}, + {"gemini", "gemini"}, + {"windsurf", "windsurf"}, + {"all", "all"}, + } + for _, c := range cases { + got, reason := core.DetectTarget("/tmp/empty", c.input, "") + if got != c.want { + t.Errorf("DetectTarget explicit %q: want %q got %q", c.input, c.want, got) + } + if reason != "explicit --target flag" { + t.Errorf("expected reason 'explicit --target flag', got %q", reason) + } + } +} + +func TestParityDetectTargetConfig(t *testing.T) { + got, reason := core.DetectTarget("/tmp/empty", "", "claude") + if got != "claude" || reason != "apm.yml target" { + t.Errorf("config target: got %q/%q", got, reason) + } +} + +func TestParityDetectTargetAutoGitHub(t *testing.T) { + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, ".github"), 0755); err != nil { + t.Fatal(err) + } + got, reason := core.DetectTarget(dir, "", "") + if got != "vscode" { + t.Errorf("expected vscode, got %q", got) + } + if !strings.Contains(reason, ".github/") { + t.Errorf("unexpected reason: %q", reason) + } +} + +func TestParityDetectTargetAutoMultiple(t *testing.T) { + dir := t.TempDir() + os.Mkdir(filepath.Join(dir, ".github"), 0755) + os.Mkdir(filepath.Join(dir, ".claude"), 0755) + got, _ := core.DetectTarget(dir, "", "") + if got != "all" { + t.Errorf("expected all, got %q", got) + } +} + +func TestParityDetectTargetNoFolder(t *testing.T) { + dir := t.TempDir() + got, reason := core.DetectTarget(dir, "", "") + if got != "minimal" { + t.Errorf("expected minimal, got %q", got) + } + if reason != core.ReasonNoTargetFolder { + t.Errorf("unexpected reason: %q", reason) + } +} + +func TestParityShouldCompile(t *testing.T) { + agentsTargets := []string{"vscode", "opencode", "codex", "gemini", "windsurf", "all", "minimal"} + for _, t2 := range agentsTargets { + if !core.ShouldCompileAgentsMD(t2) { + t.Errorf("ShouldCompileAgentsMD(%q) should be true", t2) + } + } + if core.ShouldCompileAgentsMD("claude") { + t.Error("ShouldCompileAgentsMD(claude) should be false") + } + if !core.ShouldCompileClaudeMD("claude") || !core.ShouldCompileClaudeMD("all") { + t.Error("ShouldCompileClaudeMD wrong") + } + if core.ShouldCompileClaudeMD("vscode") { + t.Error("ShouldCompileClaudeMD(vscode) should be false") + } + if !core.ShouldCompileGeminiMD("gemini") || !core.ShouldCompileGeminiMD("all") { + t.Error("ShouldCompileGeminiMD wrong") + } + if !core.ShouldCompileCopilotInstructionsMD("vscode") || !core.ShouldCompileCopilotInstructionsMD("all") { + t.Error("ShouldCompileCopilotInstructionsMD wrong") + } + if core.ShouldCompileCopilotInstructionsMD("claude") { + t.Error("ShouldCompileCopilotInstructionsMD(claude) should be false") + } +} + +func TestParityGetTargetDescription(t *testing.T) { + if !strings.Contains(core.GetTargetDescription("vscode"), "AGENTS.md") { + t.Error("vscode description should mention AGENTS.md") + } + if !strings.Contains(core.GetTargetDescription("copilot"), "AGENTS.md") { + t.Error("copilot alias should resolve to vscode description") + } + if !strings.Contains(core.GetTargetDescription("claude"), "CLAUDE.md") { + t.Error("claude description should mention CLAUDE.md") + } +} + +func TestParityNormalizeTargetList(t *testing.T) { + if core.NormalizeTargetList(nil) != nil { + t.Error("nil input should return nil") + } + got := core.NormalizeTargetList([]string{"copilot"}) + if len(got) != 1 || got[0] != "vscode" { + t.Errorf("alias resolution failed: %v", got) + } + got = core.NormalizeTargetList([]string{"claude", "copilot", "claude"}) + if len(got) != 2 { + t.Errorf("dedup failed: %v", got) + } + got = core.NormalizeTargetList([]string{"all"}) + if len(got) == 0 { + t.Error("all should expand to all canonical targets") + } +} + +// --------------------------------------------------------------------------- +// Parity: apm_yml +// --------------------------------------------------------------------------- + +func TestParityParseTargetsFieldPlural(t *testing.T) { + data := map[string]interface{}{ + "targets": []interface{}{"claude", "copilot"}, + } + got, err := core.ParseTargetsField(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 2 || got[0] != "claude" || got[1] != "copilot" { + t.Errorf("unexpected result: %v", got) + } +} + +func TestParityParseTargetsFieldSingular(t *testing.T) { + data := map[string]interface{}{ + "target": "claude", + } + got, err := core.ParseTargetsField(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 1 || got[0] != "claude" { + t.Errorf("unexpected result: %v", got) + } +} + +func TestParityParseTargetsFieldCSV(t *testing.T) { + data := map[string]interface{}{ + "target": "claude,copilot", + } + got, err := core.ParseTargetsField(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 2 { + t.Errorf("CSV parse failed: %v", got) + } +} + +func TestParityParseTargetsFieldEmpty(t *testing.T) { + got, err := core.ParseTargetsField(map[string]interface{}{}) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Errorf("expected empty, got %v", got) + } +} + +func TestParityParseTargetsFieldConflict(t *testing.T) { + data := map[string]interface{}{"targets": []interface{}{"claude"}, "target": "copilot"} + _, err := core.ParseTargetsField(data) + if err == nil { + t.Error("expected conflict error") + } +} + +func TestParityParseTargetsFieldEmptyList(t *testing.T) { + data := map[string]interface{}{"targets": []interface{}{}} + _, err := core.ParseTargetsField(data) + if err == nil { + t.Error("expected empty list error") + } +} + +func TestParityParseTargetsFieldUnknownTarget(t *testing.T) { + data := map[string]interface{}{"targets": []interface{}{"claude", "unknown-target"}} + _, err := core.ParseTargetsField(data) + if err == nil { + t.Error("expected unknown target error") + } +} diff --git a/internal/core/errors.go b/internal/core/errors.go new file mode 100644 index 00000000..54603338 --- /dev/null +++ b/internal/core/errors.go @@ -0,0 +1,189 @@ +// Package core provides target resolution, auth, scope, and orchestration +// primitives for the APM CLI. +package core + +import "fmt" + +// TargetResolutionError is the base error for all target-resolution failures. +type TargetResolutionError struct { + msg string +} + +func (e *TargetResolutionError) Error() string { return e.msg } + +// NoHarnessError is returned when no harness signal is detected and no +// explicit target is set. +type NoHarnessError struct{ TargetResolutionError } + +// AmbiguousHarnessError is returned when multiple distinct harness signals +// are detected. +type AmbiguousHarnessError struct{ TargetResolutionError } + +// UnknownTargetError is returned when a target token is not in the canonical +// set. +type UnknownTargetError struct{ TargetResolutionError } + +// ConflictingTargetsError is returned when apm.yml contains both 'target:' +// and 'targets:' (mutex). +type ConflictingTargetsError struct{ TargetResolutionError } + +// EmptyTargetsListError is returned when apm.yml 'targets:' is present but +// empty. +type EmptyTargetsListError struct{ TargetResolutionError } + +// signal list used in error messages (mirrors _SIGNAL_LIST in errors.py) +const signalList = ".claude/, CLAUDE.md, .cursor/, .cursorrules, " + + ".github/copilot-instructions.md, .codex/, .gemini/, GEMINI.md, " + + ".opencode/, .windsurf/" + +// RenderNoHarnessError returns the three-section error string for "no signal +// detected". +func RenderNoHarnessError() string { + return "[x] No harness detected\n" + + "\n" + + "APM scanned for harness markers (" + signalList + ")" + + " but found none in this project.\n" + + "\n" + + "Previously APM defaulted to copilot; this is now explicit.\n" + + "\n" + + "Fix with one of:\n" + + "\n" + + " apm targets # see all supported harnesses\n" + + " apm install --target claude # deploy to a specific harness\n" + + " apm install --target copilot # or any supported target\n" + + "\n" + + "Or declare in apm.yml:\n" + + "\n" + + " targets:\n" + + " - claude" +} + +// RenderAmbiguousError returns the three-section error string for "multiple +// harnesses detected". +func RenderAmbiguousError(detected []string) string { + if len(detected) == 0 { + return "[x] Multiple harnesses detected" + } + detectedCSV := joinStrings(detected, ", ") + first := detected[0] + return fmt.Sprintf("[x] Multiple harnesses detected: %s\n", detectedCSV) + + "\n" + + fmt.Sprintf("APM found signals for %s but cannot decide which\n", detectedCSV) + + "to deploy to. Pin your target explicitly.\n" + + "\n" + + "Fix with one of:\n" + + "\n" + + fmt.Sprintf(" apm install --target %s\n", first) + + " apm install --dry-run # preview what each target does\n" + + " apm targets # see all detected harnesses\n" + + "\n" + + "Or declare in apm.yml:\n" + + "\n" + + " targets:\n" + + fmt.Sprintf(" - %s", first) +} + +// RenderUnknownTargetError returns the three-section error string for an +// unknown target token. +func RenderUnknownTargetError(value string, valid []string) string { + // hide agent-skills from user-facing list + var visible []string + for _, t := range valid { + if t != "agent-skills" { + visible = append(visible, t) + } + } + sortStrings(visible) + suggestion := "copilot" + for _, t := range visible { + if t == "copilot" { + suggestion = "copilot" + break + } + suggestion = t + } + validCSV := joinStrings(visible, ", ") + if validCSV == "" { + validCSV = suggestion + } + displayValue := stripBracketNoise(value) + if displayValue == "" { + displayValue = value + } + if displayValue == "" { + displayValue = "" + } + return fmt.Sprintf("[x] Unknown target '%s'\n", displayValue) + + "\n" + + fmt.Sprintf("Valid targets: %s\n", validCSV) + + "\n" + + "Fix with one of:\n" + + "\n" + + " apm targets # see all supported harnesses\n" + + fmt.Sprintf(" apm install --target %s\n", suggestion) + + " apm install --dry-run\n" + + "\n" + + "Or declare in apm.yml:\n" + + "\n" + + " targets:\n" + + fmt.Sprintf(" - %s", suggestion) +} + +// RenderConflictingSchemaError returns the error string for target/targets +// mutex. +func RenderConflictingSchemaError() string { + return "[x] Cannot use both 'target:' and 'targets:' in apm.yml\n" + + "\n" + + "Use the canonical plural form:\n" + + "\n" + + "Fix with one of:\n" + + "\n" + + " apm targets # see all supported harnesses\n" + + " apm install --target claude\n" + + " apm init # regenerate apm.yml\n" + + "\n" + + "Or update apm.yml to use the canonical form:\n" + + "\n" + + " targets:\n" + + " - claude\n" + + " - copilot" +} + +// NewNoHarnessError constructs a NoHarnessError. +func NewNoHarnessError() *NoHarnessError { + return &NoHarnessError{TargetResolutionError{RenderNoHarnessError()}} +} + +// NewAmbiguousHarnessError constructs an AmbiguousHarnessError. +func NewAmbiguousHarnessError(detected []string) *AmbiguousHarnessError { + return &AmbiguousHarnessError{TargetResolutionError{RenderAmbiguousError(detected)}} +} + +// NewUnknownTargetError constructs an UnknownTargetError. +func NewUnknownTargetError(value string, valid []string) *UnknownTargetError { + return &UnknownTargetError{TargetResolutionError{RenderUnknownTargetError(value, valid)}} +} + +// NewConflictingTargetsError constructs a ConflictingTargetsError. +func NewConflictingTargetsError() *ConflictingTargetsError { + return &ConflictingTargetsError{TargetResolutionError{RenderConflictingSchemaError()}} +} + +// NewEmptyTargetsListError constructs an EmptyTargetsListError. +func NewEmptyTargetsListError() *EmptyTargetsListError { + msg := "[x] 'targets:' in apm.yml is empty\n" + + "\n" + + "The targets list must contain at least one target.\n" + + "\n" + + "Fix with one of:\n" + + "\n" + + " apm targets # see all supported harnesses\n" + + " apm install --target claude\n" + + " apm init\n" + + "\n" + + "Or update apm.yml:\n" + + "\n" + + " targets:\n" + + " - claude" + return &EmptyTargetsListError{TargetResolutionError{msg}} +} diff --git a/internal/core/helpers.go b/internal/core/helpers.go new file mode 100644 index 00000000..a1587cb7 --- /dev/null +++ b/internal/core/helpers.go @@ -0,0 +1,44 @@ +package core + +import ( + "sort" + "strings" +) + +// joinStrings joins a slice of strings with sep. +func joinStrings(ss []string, sep string) string { + return strings.Join(ss, sep) +} + +// splitCSV splits a comma-separated string, trimming whitespace. +func splitCSV(s string) []string { + parts := strings.Split(s, ",") + result := make([]string, 0, len(parts)) + for _, p := range parts { + t := strings.TrimSpace(p) + if t != "" { + result = append(result, t) + } + } + return result +} + +// sortStrings sorts a slice of strings in place. +func sortStrings(ss []string) { + sort.Strings(ss) +} + +// sortedKeys returns the keys of a map[string]bool sorted. +func sortedKeys(m map[string]bool) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +// stripBracketNoise removes leading/trailing []'" and space characters. +func stripBracketNoise(s string) string { + return strings.Trim(s, "[]'\" ") +} diff --git a/internal/core/scope.go b/internal/core/scope.go new file mode 100644 index 00000000..d4d32f29 --- /dev/null +++ b/internal/core/scope.go @@ -0,0 +1,67 @@ +package core + +import "path/filepath" + +// InstallScope controls where packages are deployed. +type InstallScope int + +const ( + // ScopeProject deploys to the current working directory (default). + ScopeProject InstallScope = iota + // ScopeUser deploys to user-level directories (~/.apm/). + ScopeUser +) + +// userAPMDir is the directory under $HOME for user-scope metadata. +const userAPMDir = ".apm" + +// GetDeployRoot returns the root directory used to construct deployment paths. +func GetDeployRoot(scope InstallScope, cwd, home string) string { + if scope == ScopeUser { + return home + } + return cwd +} + +// GetAPMDir returns the directory that holds APM metadata. +func GetAPMDir(scope InstallScope, cwd, home string) string { + if scope == ScopeUser { + return filepath.Join(home, userAPMDir) + } + return cwd +} + +// GetModulesDir returns the apm_modules directory for scope. +func GetModulesDir(scope InstallScope, cwd, home, apmModulesDir string) string { + return filepath.Join(GetAPMDir(scope, cwd, home), apmModulesDir) +} + +// GetManifestPath returns the apm.yml path for scope. +func GetManifestPath(scope InstallScope, cwd, home, apmYMLFilename string) string { + return filepath.Join(GetAPMDir(scope, cwd, home), apmYMLFilename) +} + +// GetLockfileDir returns the directory containing the lockfile for scope. +func GetLockfileDir(scope InstallScope, cwd, home string) string { + return GetAPMDir(scope, cwd, home) +} + +// ParseScope parses a scope string ("project" or "user") into an InstallScope. +// Returns ScopeProject and false for unknown values. +func ParseScope(s string) (InstallScope, bool) { + switch s { + case "user": + return ScopeUser, true + case "project", "": + return ScopeProject, true + } + return ScopeProject, false +} + +// String returns the string representation of the scope. +func (s InstallScope) String() string { + if s == ScopeUser { + return "user" + } + return "project" +} diff --git a/internal/core/target_detection.go b/internal/core/target_detection.go new file mode 100644 index 00000000..c7fd737c --- /dev/null +++ b/internal/core/target_detection.go @@ -0,0 +1,197 @@ +package core + +import ( + "os" + "path/filepath" +) + +// TargetType is the canonical internal target name. +type TargetType = string + +// ReasonNoTargetFolder is returned by DetectTarget when no integration folder +// is present. +const ReasonNoTargetFolder = "no target folder found" + +// AllCanonicalTargets is the complete set of real (non-pseudo) canonical +// targets. "minimal" is intentionally excluded. +var AllCanonicalTargets = map[string]bool{ + "vscode": true, + "claude": true, + "cursor": true, + "opencode": true, + "codex": true, + "gemini": true, + "windsurf": true, +} + +// TargetAliases maps user-facing names to canonical internal names. +var TargetAliases = map[string]string{ + "copilot": "vscode", + "agents": "vscode", + "vscode": "vscode", +} + +// DetectTarget detects the appropriate target for compilation and integration. +// It returns (target, reason) following the priority rules: +// 1. Explicit --target flag +// 2. apm.yml target setting +// 3. Auto-detect from existing folders +func DetectTarget(projectRoot, explicitTarget, configTarget string) (string, string) { + // Priority 1: explicit --target flag + if explicitTarget != "" { + return resolveAlias(explicitTarget), "explicit --target flag" + } + // Priority 2: apm.yml target + if configTarget != "" { + return resolveAlias(configTarget), "apm.yml target" + } + // Priority 3: auto-detect from folders + githubExists := dirExists(filepath.Join(projectRoot, ".github")) + claudeExists := dirExists(filepath.Join(projectRoot, ".claude")) + cursorExists := dirExists(filepath.Join(projectRoot, ".cursor")) + opencodeExists := dirExists(filepath.Join(projectRoot, ".opencode")) + codexExists := dirExists(filepath.Join(projectRoot, ".codex")) + geminiExists := dirExists(filepath.Join(projectRoot, ".gemini")) + windsurfExists := dirExists(filepath.Join(projectRoot, ".windsurf")) + + var detected []string + if githubExists { + detected = append(detected, ".github/") + } + if claudeExists { + detected = append(detected, ".claude/") + } + if cursorExists { + detected = append(detected, ".cursor/") + } + if opencodeExists { + detected = append(detected, ".opencode/") + } + if codexExists { + detected = append(detected, ".codex/") + } + if geminiExists { + detected = append(detected, ".gemini/") + } + if windsurfExists { + detected = append(detected, ".windsurf/") + } + + if len(detected) >= 2 { + return "all", "detected " + joinStrings(detected, " and ") + " folders" + } + if githubExists { + return "vscode", "detected .github/ folder" + } + if claudeExists { + return "claude", "detected .claude/ folder" + } + if cursorExists { + return "cursor", "detected .cursor/ folder" + } + if opencodeExists { + return "opencode", "detected .opencode/ folder" + } + if codexExists { + return "codex", "detected .codex/ folder" + } + if geminiExists { + return "gemini", "detected .gemini/ folder" + } + if windsurfExists { + return "windsurf", "detected .windsurf/ folder" + } + return "minimal", ReasonNoTargetFolder +} + +// ShouldCompileAgentsMD reports whether AGENTS.md should be compiled for the +// given target. AGENTS.md is generated for vscode, opencode, codex, gemini, +// windsurf, all, and minimal. +func ShouldCompileAgentsMD(target string) bool { + switch target { + case "vscode", "opencode", "codex", "gemini", "windsurf", "all", "minimal": + return true + } + return false +} + +// ShouldCompileClaudeMD reports whether CLAUDE.md should be compiled. +func ShouldCompileClaudeMD(target string) bool { + return target == "claude" || target == "all" +} + +// ShouldCompileGeminiMD reports whether GEMINI.md should be compiled. +func ShouldCompileGeminiMD(target string) bool { + return target == "gemini" || target == "all" +} + +// ShouldCompileCopilotInstructionsMD reports whether +// .github/copilot-instructions.md should be compiled. +func ShouldCompileCopilotInstructionsMD(target string) bool { + return target == "vscode" || target == "all" +} + +// GetTargetDescription returns a human-readable description of what will be +// generated for a target (accepts both internal types and user-facing aliases). +func GetTargetDescription(target string) string { + normalized := target + if target == "copilot" || target == "agents" { + normalized = "vscode" + } + descriptions := map[string]string{ + "vscode": "AGENTS.md + .github/copilot-instructions.md + .github/prompts/ + .github/agents/", + "claude": "CLAUDE.md + .claude/commands/ + .claude/agents/ + .claude/skills/", + "cursor": ".cursor/agents/ + .cursor/skills/ + .cursor/rules/", + "opencode": "AGENTS.md + .opencode/agents/ + .opencode/commands/ + .opencode/skills/", + "codex": "AGENTS.md + .agents/skills/ + .codex/agents/ + .codex/hooks.json", + "gemini": "GEMINI.md + .gemini/commands/ + .gemini/skills/ + .gemini/settings.json (MCP/hooks)", + "windsurf": "AGENTS.md + .windsurf/rules/ + .windsurf/skills/ + .windsurf/workflows/ + .windsurf/hooks.json", + "agent-skills": ".agents/skills/ only (cross-client shared skills -- no agents, hooks, or commands)", + "all": "AGENTS.md + CLAUDE.md + GEMINI.md + .github/copilot-instructions.md + .github/ + .claude/ + .cursor/ + .opencode/ + .codex/ + .gemini/ + .windsurf/ + .agents/", + "minimal": "AGENTS.md only (create .github/, .claude/, or .gemini/ for full integration)", + } + if d, ok := descriptions[normalized]; ok { + return d + } + return "unknown target" +} + +// NormalizeTargetList normalizes a target value to a list of canonical names. +// Returns nil for nil input (meaning "auto-detect"). +func NormalizeTargetList(targets []string) []string { + if targets == nil { + return nil + } + for _, t := range targets { + if t == "all" { + return sortedKeys(AllCanonicalTargets) + } + } + seen := map[string]bool{} + var result []string + for _, item := range targets { + canonical := item + if a, ok := TargetAliases[item]; ok { + canonical = a + } + if !seen[canonical] { + seen[canonical] = true + result = append(result, canonical) + } + } + return result +} + +// resolveAlias converts a user-facing target name to its canonical form. +func resolveAlias(t string) string { + if a, ok := TargetAliases[t]; ok { + return a + } + return t +} + +// dirExists reports whether path is an existing directory. +func dirExists(path string) bool { + info, err := os.Stat(path) + return err == nil && info.IsDir() +} diff --git a/internal/core/token_manager.go b/internal/core/token_manager.go new file mode 100644 index 00000000..eb322bcc --- /dev/null +++ b/internal/core/token_manager.go @@ -0,0 +1,349 @@ +// Package core provides core APM CLI functionality. +// token_manager.go mirrors src/apm_cli/core/token_manager.py. +package core + +import ( + "net/url" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unicode" + + "github.com/githubnext/apm/internal/utils/githubhost" +) + +// tokenCacheKey identifies a host+port pair for credential caching. +type tokenCacheKey struct { + host string + port int // 0 means no port +} + +// TokenPrecedence maps purpose to ordered env var names. +// Mirrors GitHubTokenManager.TOKEN_PRECEDENCE in token_manager.py. +var TokenPrecedence = map[string][]string{ + "copilot": {"GITHUB_COPILOT_PAT", "GITHUB_TOKEN", "GITHUB_APM_PAT"}, + "models": {"GITHUB_TOKEN", "GITHUB_APM_PAT"}, + "modules": {"GITHUB_APM_PAT", "GITHUB_TOKEN", "GH_TOKEN"}, + "gitlab_modules": {"GITLAB_APM_PAT", "GITLAB_TOKEN"}, + "generic_modules": {}, + "ado_modules": {"ADO_APM_PAT"}, + "artifactory_modules": {"ARTIFACTORY_APM_TOKEN"}, +} + +// RuntimeEnvVars maps runtime to env var names to set. +var RuntimeEnvVars = map[string][]string{ + "copilot": {"GH_TOKEN", "GITHUB_PERSONAL_ACCESS_TOKEN"}, + "codex": {"GITHUB_TOKEN"}, + "llm": {"GITHUB_MODELS_KEY"}, +} + +const ( + adoBearerSource = "AAD_BEARER_AZ_CLI" + defaultCredentialTimeout = 60 + maxCredentialTimeout = 180 +) + +// GitHubTokenManager manages GitHub token environment setup for different AI runtimes. +// Mirrors GitHubTokenManager in token_manager.py. +type GitHubTokenManager struct { + PreserveExisting bool + mu sync.Mutex + credentialCache map[tokenCacheKey]*string // *string so nil = "not found", "" = cached none +} + +// NewGitHubTokenManager constructs a manager with preserve_existing=true (Python default). +func NewGitHubTokenManager() *GitHubTokenManager { + return &GitHubTokenManager{ + PreserveExisting: true, + credentialCache: make(map[tokenCacheKey]*string), + } +} + +// credentialTimeout returns the timeout in seconds for git credential fill. +// Configurable via APM_GIT_CREDENTIAL_TIMEOUT. +func credentialTimeout() time.Duration { + raw := strings.TrimSpace(os.Getenv("APM_GIT_CREDENTIAL_TIMEOUT")) + if raw == "" { + return time.Duration(defaultCredentialTimeout) * time.Second + } + v, err := strconv.Atoi(raw) + if err != nil { + return time.Duration(defaultCredentialTimeout) * time.Second + } + if v < 1 { + v = 1 + } + if v > maxCredentialTimeout { + v = maxCredentialTimeout + } + return time.Duration(v) * time.Second +} + +// isValidCredentialToken validates that a credential-fill token is not garbage. +func isValidCredentialToken(token string) bool { + if token == "" || len(token) > 1024 { + return false + } + for _, c := range token { + if c == ' ' || c == '\t' || c == '\n' || c == '\r' { + return false + } + } + promptFragments := []string{ + "Password for", "Username for", "password for", "username for", + } + for _, f := range promptFragments { + if strings.Contains(token, f) { + return false + } + } + return true +} + +// sanitizeCredentialPath strips leading '/', rejects control chars, allowlists URL schemes. +// Mirrors _sanitize_credential_path in token_manager.py. +func sanitizeCredentialPath(path string) string { + parsed, err := url.Parse(path) + if err != nil { + return "" + } + scheme := strings.ToLower(parsed.Scheme) + var cleaned string + if scheme != "" { + if scheme != "https" && scheme != "http" && scheme != "ssh" { + return "" + } + cleaned = strings.TrimLeft(parsed.Path, "/") + } else { + cleaned = strings.TrimLeft(path, "/") + } + if cleaned == "" { + return "" + } + for _, ch := range cleaned { + if ch < 0x20 || ch == 0x7F || unicode.IsSpace(ch) { + return "" + } + } + return cleaned +} + +// formatCredentialHost embeds a non-standard port into the host field per gitcredentials(7). +func formatCredentialHost(host string, port int) string { + if port != 0 { + return host + ":" + strconv.Itoa(port) + } + return host +} + +// ResolveCredentialFromGit queries git credential fill for a token. +// Mirrors GitHubTokenManager.resolve_credential_from_git. +func ResolveCredentialFromGit(host string, port int, path string) (string, bool) { + hostField := formatCredentialHost(host, port) + lines := []string{"protocol=https", "host=" + hostField} + if path != "" { + if sanitized := sanitizeCredentialPath(path); sanitized != "" { + lines = append(lines, "path="+sanitized) + } + } + stdin := strings.Join(lines, "\n") + "\n\n" + + env := os.Environ() + env = append(env, "GIT_TERMINAL_PROMPT=0") + if runtime.GOOS != "windows" { + env = append(env, "GIT_ASKPASS=") + } else { + env = append(env, "GIT_ASKPASS=echo") + } + + cmd := exec.Command("git", "credential", "fill") + cmd.Stdin = strings.NewReader(stdin) + cmd.Env = env + + done := make(chan struct{}) + var out []byte + var runErr error + go func() { + defer close(done) + out, runErr = cmd.Output() + }() + select { + case <-done: + case <-time.After(credentialTimeout()): + if cmd.Process != nil { + cmd.Process.Kill() + } + return "", false + } + if runErr != nil { + return "", false + } + for _, line := range strings.Split(string(out), "\n") { + if strings.HasPrefix(line, "password=") { + token := line[len("password="):] + if isValidCredentialToken(token) { + return token, true + } + return "", false + } + } + return "", false +} + +// ResolveCredentialFromGHCLI resolves a token from the active gh CLI account. +// Mirrors GitHubTokenManager.resolve_credential_from_gh_cli. +func ResolveCredentialFromGHCLI(host string) (string, bool) { + if !githubhost.SupportGHCLIHost(host) { + return "", false + } + env := os.Environ() + env = append(env, "GH_PROMPT_DISABLED=1", "GH_NO_UPDATE_NOTIFIER=1") + + cmd := exec.Command("gh", "auth", "token", "--hostname", host) + cmd.Env = env + cmd.Stdin = nil + + done := make(chan struct{}) + var out []byte + var runErr error + go func() { + defer close(done) + out, runErr = cmd.Output() + }() + select { + case <-done: + case <-time.After(credentialTimeout()): + if cmd.Process != nil { + cmd.Process.Kill() + } + return "", false + } + if runErr != nil { + return "", false + } + token := strings.TrimSpace(string(out)) + if isValidCredentialToken(token) { + return token, true + } + return "", false +} + +// GetTokenForPurpose returns the first available env var for the given purpose. +// Mirrors GitHubTokenManager.get_token_for_purpose. +func (m *GitHubTokenManager) GetTokenForPurpose(purpose string, env map[string]string) (string, bool) { + vars, ok := TokenPrecedence[purpose] + if !ok { + return "", false + } + for _, v := range vars { + if token := env[v]; token != "" { + return token, true + } + } + return "", false +} + +// GetTokenWithCredentialFallback tries env vars, then gh CLI, then git credential fill. +// Mirrors GitHubTokenManager.get_token_with_credential_fallback. +func (m *GitHubTokenManager) GetTokenWithCredentialFallback( + purpose, host string, env map[string]string, port int, +) (string, bool) { + if token, ok := m.GetTokenForPurpose(purpose, env); ok { + return token, true + } + + key := tokenCacheKey{host: host, port: port} + m.mu.Lock() + if cached, hit := m.credentialCache[key]; hit { + m.mu.Unlock() + if cached == nil || *cached == "" { + return "", false + } + return *cached, true + } + m.mu.Unlock() + + var result string + if githubhost.SupportGHCLIHost(host) { + if token, ok := ResolveCredentialFromGHCLI(host); ok { + result = token + } + } + if result == "" { + if token, ok := ResolveCredentialFromGit(host, port, ""); ok { + result = token + } + } + + m.mu.Lock() + if result != "" { + m.credentialCache[key] = &result + } else { + empty := "" + m.credentialCache[key] = &empty + } + m.mu.Unlock() + + if result != "" { + return result, true + } + return "", false +} + +// OSEnvMap returns os.Environ() as a map[string]string. +func OSEnvMap() map[string]string { + m := make(map[string]string) + for _, e := range os.Environ() { + idx := strings.Index(e, "=") + if idx < 0 { + continue + } + m[e[:idx]] = e[idx+1:] + } + return m +} + +// ValidateTokens checks that at least one useful token is available. +// Mirrors GitHubTokenManager.validate_tokens. +func (m *GitHubTokenManager) ValidateTokens(env map[string]string) (bool, string) { + if env == nil { + env = OSEnvMap() + } + hasAny := false + for _, purpose := range []string{"copilot", "models", "modules"} { + if _, ok := m.GetTokenForPurpose(purpose, env); ok { + hasAny = true + break + } + } + if !hasAny { + return false, "No tokens found. Set one of:\n" + + "- GITHUB_TOKEN (user-scoped PAT for GitHub Models)\n" + + "- GITHUB_APM_PAT (fine-grained PAT for APM modules on GitHub)\n" + + "- ADO_APM_PAT (PAT for APM modules on Azure DevOps)" + } + if _, ok := m.GetTokenForPurpose("models", env); !ok { + if env["GITHUB_APM_PAT"] != "" { + return true, "Warning: Only fine-grained PAT available. GitHub Models requires GITHUB_TOKEN (user-scoped PAT)" + } + } + return true, "Token validation passed" +} + +// IdentifyEnvSource returns the name of the first env var that matched for purpose. +// Mirrors AuthResolver._identify_env_source. +func (m *GitHubTokenManager) IdentifyEnvSource(purpose string) string { + for _, v := range TokenPrecedence[purpose] { + if os.Getenv(v) != "" { + return v + } + } + return "env" +} + +// ADOBearerSource is the diagnostic source label for ADO bearer tokens. +const ADOBearerSource = adoBearerSource diff --git a/internal/deps/deps_test.go b/internal/deps/deps_test.go new file mode 100644 index 00000000..990e9785 --- /dev/null +++ b/internal/deps/deps_test.go @@ -0,0 +1,283 @@ +package deps + +import ( + "strings" + "testing" +) + +// TestParityDependencyNodeGetID mirrors DependencyNode.get_id() parity. +func TestParityDependencyNodeGetID(t *testing.T) { + n := &DependencyNode{RepoURL: "https://github.com/owner/repo"} + if got := n.GetID(); got != "https://github.com/owner/repo" { + t.Errorf("GetID no-ref: got %q", got) + } + + n2 := &DependencyNode{RepoURL: "https://github.com/owner/repo", Reference: "main"} + if got := n2.GetID(); got != "https://github.com/owner/repo#main" { + t.Errorf("GetID with-ref: got %q", got) + } +} + +// TestParityDependencyNodeAncestorChain mirrors DependencyNode.get_ancestor_chain(). +func TestParityDependencyNodeAncestorChain(t *testing.T) { + root := &DependencyNode{RepoURL: "root"} + mid := &DependencyNode{RepoURL: "mid", Parent: root} + leaf := &DependencyNode{RepoURL: "leaf", Parent: mid} + + if got := leaf.GetAncestorChain(); got != "root > mid > leaf" { + t.Errorf("ancestor chain: got %q", got) + } +} + +// TestParityCircularRefString mirrors CircularRef.__str__(). +func TestParityCircularRefString(t *testing.T) { + cr := &CircularRef{CyclePath: []string{"a", "b", "c"}, DetectedAtDepth: 2} + got := cr.String() + if !strings.Contains(got, "a -> b -> c -> a") { + t.Errorf("circular ref string: got %q", got) + } + + // Single element -- no arrow appended + cr2 := &CircularRef{CyclePath: []string{"a"}} + got2 := cr2.String() + if !strings.Contains(got2, "a") { + t.Errorf("single circular ref: got %q", got2) + } + + // Empty + cr3 := &CircularRef{} + got3 := cr3.String() + if !strings.Contains(got3, "empty path") { + t.Errorf("empty circular ref: got %q", got3) + } +} + +// TestParityFlatDependencyMapFirstWins verifies first-wins conflict semantics. +func TestParityFlatDependencyMapFirstWins(t *testing.T) { + f := NewFlatDependencyMap() + f.AddDependency("https://github.com/owner/repo", "main") + f.AddDependency("https://github.com/owner/repo", "v1.0") // should not overwrite + + if got := f.Dependencies["https://github.com/owner/repo"]; got != "main" { + t.Errorf("first-wins: expected main, got %q", got) + } + if f.TotalDependencies() != 1 { + t.Errorf("total: expected 1, got %d", f.TotalDependencies()) + } +} + +// TestParityFlatDependencyMapInstallOrder mirrors install_order list. +func TestParityFlatDependencyMapInstallOrder(t *testing.T) { + f := NewFlatDependencyMap() + f.AddDependency("a", "main") + f.AddDependency("b", "main") + f.AddDependency("c", "main") + + order := f.GetInstallationList() + if len(order) != 3 || order[0] != "a" || order[1] != "b" || order[2] != "c" { + t.Errorf("install order: got %v", order) + } +} + +// TestParityDependencyTreeMaxDepth mirrors DependencyTree.max_depth. +func TestParityDependencyTreeMaxDepth(t *testing.T) { + tree := NewDependencyTree("root") + tree.AddNode(&DependencyNode{RepoURL: "a", Depth: 1}) + tree.AddNode(&DependencyNode{RepoURL: "b", Depth: 3}) + tree.AddNode(&DependencyNode{RepoURL: "c", Depth: 2}) + + if tree.MaxDepth != 3 { + t.Errorf("max_depth: expected 3, got %d", tree.MaxDepth) + } +} + +// TestParityDependencyTreeHasDependency mirrors DependencyTree.has_dependency(). +func TestParityDependencyTreeHasDependency(t *testing.T) { + tree := NewDependencyTree("root") + tree.AddNode(&DependencyNode{RepoURL: "https://github.com/owner/repo", Depth: 1}) + + if !tree.HasDependency("https://github.com/owner/repo") { + t.Error("should find existing dep") + } + if tree.HasDependency("https://github.com/other/repo") { + t.Error("should not find missing dep") + } +} + +// TestParityDependencyGraphIsValid mirrors DependencyGraph.is_valid(). +func TestParityDependencyGraphIsValid(t *testing.T) { + g := NewDependencyGraph("root") + if !g.IsValid() { + t.Error("empty graph should be valid") + } + + g.AddError("some error") + if g.IsValid() { + t.Error("graph with errors should be invalid") + } + + g2 := NewDependencyGraph("root2") + g2.AddCircularDependency(CircularRef{CyclePath: []string{"a", "b"}}) + if g2.IsValid() { + t.Error("graph with circular dep should be invalid") + } +} + +// TestParityDependencyGraphSummaryKeys mirrors DependencyGraph.get_summary() keys. +func TestParityDependencyGraphSummaryKeys(t *testing.T) { + g := NewDependencyGraph("my-pkg") + g.Flattened.AddDependency("dep1", "main") + summary := g.GetSummary() + + expected := []string{ + "root_package", "total_dependencies", "max_depth", + "has_circular_dependencies", "circular_count", + "has_conflicts", "conflict_count", + "has_errors", "error_count", "is_valid", + } + for _, k := range expected { + if _, ok := summary[k]; !ok { + t.Errorf("summary missing key %q", k) + } + } + if summary["root_package"] != "my-pkg" { + t.Errorf("root_package: got %v", summary["root_package"]) + } + if summary["total_dependencies"].(int) != 1 { + t.Errorf("total_dependencies: got %v", summary["total_dependencies"]) + } +} + +// TestParityLockedDependencyGetUniqueKey mirrors LockedDependency.get_unique_key(). +func TestParityLockedDependencyGetUniqueKey(t *testing.T) { + // Normal dep + ld := &LockedDependency{RepoURL: "https://github.com/owner/repo", Depth: 1} + if got := ld.GetUniqueKey(); got != "https://github.com/owner/repo" { + t.Errorf("normal dep key: got %q", got) + } + + // Local dep + ld2 := &LockedDependency{RepoURL: "https://github.com/owner/repo", Source: "local", LocalPath: "./local/pkg"} + if got := ld2.GetUniqueKey(); got != "./local/pkg" { + t.Errorf("local dep key: got %q", got) + } + + // Virtual dep + ld3 := &LockedDependency{ + RepoURL: "https://github.com/owner/mono", + IsVirtual: true, + VirtualPath: "packages/sub", + } + if got := ld3.GetUniqueKey(); got != "https://github.com/owner/mono/packages/sub" { + t.Errorf("virtual dep key: got %q", got) + } +} + +// TestParityLockedDependencyToMap mirrors LockedDependency.to_dict() behavior. +func TestParityLockedDependencyToMap(t *testing.T) { + ld := &LockedDependency{ + RepoURL: "https://github.com/owner/repo", + ResolvedCommit: "abc1234", + Depth: 2, + IsDev: true, + } + m := ld.ToMap() + if m["repo_url"] != "https://github.com/owner/repo" { + t.Errorf("repo_url: got %v", m["repo_url"]) + } + if m["resolved_commit"] != "abc1234" { + t.Errorf("resolved_commit: got %v", m["resolved_commit"]) + } + if m["depth"] != 2 { + t.Errorf("depth: got %v", m["depth"]) + } + if m["is_dev"] != true { + t.Errorf("is_dev: got %v", m["is_dev"]) + } + // depth==1 should be omitted (default) + ld2 := &LockedDependency{RepoURL: "x", Depth: 1} + m2 := ld2.ToMap() + if _, ok := m2["depth"]; ok { + t.Error("depth==1 should be omitted from map") + } +} + +// TestParityLockedDependencyDeployedFilesSorted mirrors sorted deployed_files. +func TestParityLockedDependencyDeployedFilesSorted(t *testing.T) { + ld := &LockedDependency{ + RepoURL: "x", + DeployedFiles: []string{"z.md", "a.md", "m.md"}, + } + m := ld.ToMap() + files := m["deployed_files"].([]string) + if files[0] != "a.md" || files[1] != "m.md" || files[2] != "z.md" { + t.Errorf("deployed_files not sorted: %v", files) + } +} + +// TestParityLockedDependencyFromMap mirrors LockedDependency.from_dict(). +func TestParityLockedDependencyFromMap(t *testing.T) { + data := map[string]interface{}{ + "repo_url": "https://github.com/owner/repo", + "resolved_commit": "deadbeef", + "depth": float64(2), + "is_dev": true, + "port": float64(7999), + } + ld := LockedDependencyFromMap(data) + if ld.RepoURL != "https://github.com/owner/repo" { + t.Errorf("repo_url: got %q", ld.RepoURL) + } + if ld.ResolvedCommit != "deadbeef" { + t.Errorf("resolved_commit: got %q", ld.ResolvedCommit) + } + if ld.Depth != 2 { + t.Errorf("depth: got %d", ld.Depth) + } + if !ld.IsDev { + t.Error("is_dev should be true") + } + if ld.Port != 7999 { + t.Errorf("port: got %d", ld.Port) + } +} + +// TestParityLockedDependencyPortValidation mirrors port range validation in from_dict(). +func TestParityLockedDependencyPortValidation(t *testing.T) { + // Invalid port (out of range) should be ignored + data := map[string]interface{}{ + "repo_url": "x", + "port": float64(99999), + } + ld := LockedDependencyFromMap(data) + if ld.Port != 0 { + t.Errorf("out-of-range port should be 0, got %d", ld.Port) + } + + // Valid port + data2 := map[string]interface{}{ + "repo_url": "x", + "port": float64(443), + } + ld2 := LockedDependencyFromMap(data2) + if ld2.Port != 443 { + t.Errorf("valid port: got %d", ld2.Port) + } +} + +// TestParityInstalledPackageFields checks InstalledPackage fields exist. +func TestParityInstalledPackageFields(t *testing.T) { + ip := InstalledPackage{ + RepoURL: "https://github.com/owner/repo", + Reference: "main", + ResolvedCommit: "abc1234", + Depth: 1, + IsDev: false, + } + if ip.RepoURL == "" { + t.Error("RepoURL should be set") + } + if ip.Depth != 1 { + t.Errorf("Depth: got %d", ip.Depth) + } +} diff --git a/internal/deps/graph.go b/internal/deps/graph.go new file mode 100644 index 00000000..bf56797a --- /dev/null +++ b/internal/deps/graph.go @@ -0,0 +1,229 @@ +// Package deps implements dependency graph data structures for APM. +// Mirrors src/apm_cli/deps/dependency_graph.py. +package deps + +// DependencyNode represents a single dependency node in the dependency graph. +// Mirrors src/apm_cli/deps/dependency_graph.py:DependencyNode. +type DependencyNode struct { + RepoURL string + Reference string // git ref (branch/tag/commit), empty means default + Depth int + Children []*DependencyNode + Parent *DependencyNode + IsDev bool +} + +// GetID returns a unique identifier for this node. +// Mirrors DependencyNode.get_id(). +func (n *DependencyNode) GetID() string { + if n.Reference != "" { + return n.RepoURL + "#" + n.Reference + } + return n.RepoURL +} + +// GetAncestorChain builds a breadcrumb from this node's ancestry. +// Mirrors DependencyNode.get_ancestor_chain(). +func (n *DependencyNode) GetAncestorChain() string { + var parts []string + current := n + for current != nil { + parts = append(parts, current.RepoURL) + current = current.Parent + } + // reverse + for i, j := 0, len(parts)-1; i < j; i, j = i+1, j-1 { + parts[i], parts[j] = parts[j], parts[i] + } + result := "" + for i, p := range parts { + if i > 0 { + result += " > " + } + result += p + } + return result +} + +// CircularRef represents a circular dependency reference. +// Mirrors src/apm_cli/deps/dependency_graph.py:CircularRef. +type CircularRef struct { + CyclePath []string + DetectedAtDepth int +} + +// String formats the circular dependency for display. +func (c *CircularRef) String() string { + if len(c.CyclePath) == 0 { + return "Circular dependency detected: (empty path)" + } + result := "Circular dependency detected: " + for i, p := range c.CyclePath { + if i > 0 { + result += " -> " + } + result += p + } + if len(c.CyclePath) > 1 && c.CyclePath[0] != c.CyclePath[len(c.CyclePath)-1] { + result += " -> " + c.CyclePath[0] + } + return result +} + +// ConflictInfo describes a dependency version conflict. +// Mirrors src/apm_cli/deps/dependency_graph.py:ConflictInfo. +type ConflictInfo struct { + RepoURL string + WinnerRef string // reference string of the winning dep + Conflicts []string + Reason string +} + +// FlatDependencyMap is the final flattened dependency mapping ready for install. +// Mirrors src/apm_cli/deps/dependency_graph.py:FlatDependencyMap. +type FlatDependencyMap struct { + Dependencies map[string]string // unique_key -> resolved ref + Conflicts []ConflictInfo + InstallOrder []string +} + +// NewFlatDependencyMap creates an empty FlatDependencyMap. +func NewFlatDependencyMap() *FlatDependencyMap { + return &FlatDependencyMap{ + Dependencies: make(map[string]string), + } +} + +// AddDependency adds a dependency to the flat map (first-wins on conflict). +func (f *FlatDependencyMap) AddDependency(uniqueKey, ref string) { + if _, exists := f.Dependencies[uniqueKey]; !exists { + f.Dependencies[uniqueKey] = ref + f.InstallOrder = append(f.InstallOrder, uniqueKey) + } +} + +// HasConflicts reports whether any conflicts were recorded. +func (f *FlatDependencyMap) HasConflicts() bool { + return len(f.Conflicts) > 0 +} + +// TotalDependencies returns the count of unique dependencies. +func (f *FlatDependencyMap) TotalDependencies() int { + return len(f.Dependencies) +} + +// GetInstallationList returns dependency keys in install order. +func (f *FlatDependencyMap) GetInstallationList() []string { + result := make([]string, 0, len(f.InstallOrder)) + for _, key := range f.InstallOrder { + if _, ok := f.Dependencies[key]; ok { + result = append(result, key) + } + } + return result +} + +// DependencyTree is the hierarchical representation before flattening. +// Mirrors src/apm_cli/deps/dependency_graph.py:DependencyTree. +type DependencyTree struct { + RootPackage string + Nodes map[string]*DependencyNode + MaxDepth int +} + +// NewDependencyTree creates an empty DependencyTree for the given root. +func NewDependencyTree(rootPackage string) *DependencyTree { + return &DependencyTree{ + RootPackage: rootPackage, + Nodes: make(map[string]*DependencyNode), + } +} + +// AddNode adds a node to the tree. +func (t *DependencyTree) AddNode(node *DependencyNode) { + id := node.GetID() + t.Nodes[id] = node + if node.Depth > t.MaxDepth { + t.MaxDepth = node.Depth + } +} + +// GetNode retrieves a node by its unique key. +func (t *DependencyTree) GetNode(id string) *DependencyNode { + return t.Nodes[id] +} + +// HasDependency checks if a repo URL is present in the tree. +func (t *DependencyTree) HasDependency(repoURL string) bool { + for _, n := range t.Nodes { + if n.RepoURL == repoURL { + return true + } + } + return false +} + +// DependencyGraph is the complete resolved dependency information. +// Mirrors src/apm_cli/deps/dependency_graph.py:DependencyGraph. +type DependencyGraph struct { + RootPackage string + Tree *DependencyTree + Flattened *FlatDependencyMap + CircularDependencies []CircularRef + ResolutionErrors []string +} + +// NewDependencyGraph creates a new empty DependencyGraph. +func NewDependencyGraph(rootPackage string) *DependencyGraph { + return &DependencyGraph{ + RootPackage: rootPackage, + Tree: NewDependencyTree(rootPackage), + Flattened: NewFlatDependencyMap(), + } +} + +// HasCircularDependencies reports whether any circular deps were detected. +func (g *DependencyGraph) HasCircularDependencies() bool { + return len(g.CircularDependencies) > 0 +} + +// HasConflicts reports whether any conflicts exist. +func (g *DependencyGraph) HasConflicts() bool { + return g.Flattened.HasConflicts() +} + +// HasErrors reports whether any resolution errors exist. +func (g *DependencyGraph) HasErrors() bool { + return len(g.ResolutionErrors) > 0 +} + +// IsValid reports whether the graph is free of circular deps and errors. +func (g *DependencyGraph) IsValid() bool { + return !g.HasCircularDependencies() && !g.HasErrors() +} + +// AddError appends a resolution error. +func (g *DependencyGraph) AddError(err string) { + g.ResolutionErrors = append(g.ResolutionErrors, err) +} + +// AddCircularDependency records a circular dependency detection. +func (g *DependencyGraph) AddCircularDependency(ref CircularRef) { + g.CircularDependencies = append(g.CircularDependencies, ref) +} + +// GetSummary returns a summary map of the dependency graph. +func (g *DependencyGraph) GetSummary() map[string]interface{} { + return map[string]interface{}{ + "root_package": g.RootPackage, + "total_dependencies": g.Flattened.TotalDependencies(), + "max_depth": g.Tree.MaxDepth, + "has_circular_dependencies": g.HasCircularDependencies(), + "circular_count": len(g.CircularDependencies), + "has_conflicts": g.HasConflicts(), + "conflict_count": len(g.Flattened.Conflicts), + "has_errors": g.HasErrors(), + "error_count": len(g.ResolutionErrors), + "is_valid": g.IsValid(), + } +} diff --git a/internal/deps/lockfile.go b/internal/deps/lockfile.go new file mode 100644 index 00000000..849cf4ff --- /dev/null +++ b/internal/deps/lockfile.go @@ -0,0 +1,246 @@ +// Package deps -- LockedDependency and LockFile data structures. +// Mirrors src/apm_cli/deps/lockfile.py (core types only). +package deps + +import ( + "sort" +) + +// LockedDependency is a resolved dependency with exact commit/version info. +// Mirrors src/apm_cli/deps/lockfile.py:LockedDependency. +type LockedDependency struct { + RepoURL string + Host string + Port int // 0 means not set + RegistryPrefix string + ResolvedCommit string + ResolvedRef string + Version string + VirtualPath string + IsVirtual bool + Depth int + ResolvedBy string + PackageType string + DeployedFiles []string + DeployedFileHashes map[string]string + Source string // "local" for local deps + LocalPath string + ContentHash string + IsDev bool + DiscoveredVia string + MarketplacePluginName string + IsInsecure bool + AllowInsecure bool + SkillSubset []string +} + +// GetUniqueKey returns a stable key for this locked dependency. +// Mirrors LockedDependency.get_unique_key(). +func (d *LockedDependency) GetUniqueKey() string { + if d.Source == "local" && d.LocalPath != "" { + return d.LocalPath + } + if d.IsVirtual && d.VirtualPath != "" { + return d.RepoURL + "/" + d.VirtualPath + } + return d.RepoURL +} + +// ToMap serializes the locked dependency to a map (for YAML output). +// Mirrors LockedDependency.to_dict(). +func (d *LockedDependency) ToMap() map[string]interface{} { + m := map[string]interface{}{"repo_url": d.RepoURL} + if d.Host != "" { + m["host"] = d.Host + } + if d.Port != 0 { + m["port"] = d.Port + } + if d.RegistryPrefix != "" { + m["registry_prefix"] = d.RegistryPrefix + } + if d.ResolvedCommit != "" { + m["resolved_commit"] = d.ResolvedCommit + } + if d.ResolvedRef != "" { + m["resolved_ref"] = d.ResolvedRef + } + if d.Version != "" { + m["version"] = d.Version + } + if d.VirtualPath != "" { + m["virtual_path"] = d.VirtualPath + } + if d.IsVirtual { + m["is_virtual"] = true + } + if d.Depth != 1 { + m["depth"] = d.Depth + } + if d.ResolvedBy != "" { + m["resolved_by"] = d.ResolvedBy + } + if d.PackageType != "" { + m["package_type"] = d.PackageType + } + if len(d.DeployedFiles) > 0 { + files := make([]string, len(d.DeployedFiles)) + copy(files, d.DeployedFiles) + sort.Strings(files) + m["deployed_files"] = files + } + if len(d.DeployedFileHashes) > 0 { + m["deployed_file_hashes"] = d.DeployedFileHashes + } + if d.Source != "" { + m["source"] = d.Source + } + if d.LocalPath != "" { + m["local_path"] = d.LocalPath + } + if d.ContentHash != "" { + m["content_hash"] = d.ContentHash + } + if d.IsDev { + m["is_dev"] = true + } + if d.DiscoveredVia != "" { + m["discovered_via"] = d.DiscoveredVia + } + if d.MarketplacePluginName != "" { + m["marketplace_plugin_name"] = d.MarketplacePluginName + } + if d.IsInsecure { + m["is_insecure"] = true + } + if d.AllowInsecure { + m["allow_insecure"] = true + } + if len(d.SkillSubset) > 0 { + ss := make([]string, len(d.SkillSubset)) + copy(ss, d.SkillSubset) + sort.Strings(ss) + m["skill_subset"] = ss + } + return m +} + +// LockedDependencyFromMap deserializes a LockedDependency from a map. +// Mirrors LockedDependency.from_dict(). +func LockedDependencyFromMap(data map[string]interface{}) *LockedDependency { + ld := &LockedDependency{ + RepoURL: stringField(data, "repo_url"), + Host: stringField(data, "host"), + RegistryPrefix: stringField(data, "registry_prefix"), + ResolvedCommit: stringField(data, "resolved_commit"), + ResolvedRef: stringField(data, "resolved_ref"), + Version: stringField(data, "version"), + VirtualPath: stringField(data, "virtual_path"), + IsVirtual: boolField(data, "is_virtual"), + Depth: intFieldDefault(data, "depth", 1), + ResolvedBy: stringField(data, "resolved_by"), + PackageType: stringField(data, "package_type"), + Source: stringField(data, "source"), + LocalPath: stringField(data, "local_path"), + ContentHash: stringField(data, "content_hash"), + IsDev: boolField(data, "is_dev"), + DiscoveredVia: stringField(data, "discovered_via"), + MarketplacePluginName: stringField(data, "marketplace_plugin_name"), + IsInsecure: boolField(data, "is_insecure"), + AllowInsecure: boolField(data, "allow_insecure"), + } + + // Port with validation (1-65535). + if pRaw, ok := data["port"]; ok && pRaw != nil { + switch v := pRaw.(type) { + case int: + if v >= 1 && v <= 65535 { + ld.Port = v + } + case float64: + iv := int(v) + if iv >= 1 && iv <= 65535 { + ld.Port = iv + } + } + } + + // deployed_files + if raw, ok := data["deployed_files"]; ok { + if sl, ok := raw.([]interface{}); ok { + for _, v := range sl { + if s, ok := v.(string); ok { + ld.DeployedFiles = append(ld.DeployedFiles, s) + } + } + } + } + // deployed_file_hashes + if raw, ok := data["deployed_file_hashes"]; ok { + if m, ok := raw.(map[string]interface{}); ok { + ld.DeployedFileHashes = make(map[string]string, len(m)) + for k, v := range m { + if s, ok := v.(string); ok { + ld.DeployedFileHashes[k] = s + } + } + } + } + // skill_subset + if raw, ok := data["skill_subset"]; ok { + if sl, ok := raw.([]interface{}); ok { + for _, v := range sl { + if s, ok := v.(string); ok { + ld.SkillSubset = append(ld.SkillSubset, s) + } + } + } + } + + return ld +} + +// InstalledPackage records a successfully-installed dependency. +// Mirrors src/apm_cli/deps/installed_package.py:InstalledPackage. +type InstalledPackage struct { + RepoURL string + Reference string + ResolvedCommit string + Depth int + ResolvedBy string + IsDev bool + RegistryHost string // from RegistryConfig.host if set + RegistryPrefix string // from RegistryConfig.prefix if set +} + +// helpers + +func stringField(m map[string]interface{}, key string) string { + if v, ok := m[key]; ok { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func boolField(m map[string]interface{}, key string) bool { + if v, ok := m[key]; ok { + if b, ok := v.(bool); ok { + return b + } + } + return false +} + +func intFieldDefault(m map[string]interface{}, key string, def int) int { + if v, ok := m[key]; ok { + switch n := v.(type) { + case int: + return n + case float64: + return int(n) + } + } + return def +} diff --git a/internal/models/dependency/types.go b/internal/models/dependency/types.go new file mode 100644 index 00000000..09de1083 --- /dev/null +++ b/internal/models/dependency/types.go @@ -0,0 +1,103 @@ +// Package dependency defines dependency reference types for APM packages. +// Mirrors src/apm_cli/models/dependency/types.py. +package dependency + +import "regexp" + +// GitReferenceType classifies a git reference as branch, tag, or commit. +type GitReferenceType int + +const ( + GitRefBranch GitReferenceType = iota + GitRefTag + GitRefCommit +) + +func (t GitReferenceType) String() string { + switch t { + case GitRefBranch: + return "branch" + case GitRefTag: + return "tag" + case GitRefCommit: + return "commit" + default: + return "unknown" + } +} + +// VirtualPackageType classifies a virtual package as a file or subdirectory. +type VirtualPackageType int + +const ( + VirtualPackageFile VirtualPackageType = iota + VirtualPackageSubdirectory +) + +func (t VirtualPackageType) String() string { + switch t { + case VirtualPackageFile: + return "file" + case VirtualPackageSubdirectory: + return "subdirectory" + default: + return "unknown" + } +} + +// RemoteRef represents a single remote git reference with its commit SHA. +type RemoteRef struct { + Name string + RefType GitReferenceType + CommitSHA string +} + +// ResolvedReference represents a resolved git reference. +type ResolvedReference struct { + OriginalRef string + RefType GitReferenceType + ResolvedCommit string + RefName string +} + +func (r ResolvedReference) String() string { + if r.ResolvedCommit == "" { + return r.RefName + } + if r.RefType == GitRefCommit { + if len(r.ResolvedCommit) > 8 { + return r.ResolvedCommit[:8] + } + return r.ResolvedCommit + } + commit := r.ResolvedCommit + if len(commit) > 8 { + commit = commit[:8] + } + return r.RefName + " (" + commit + ")" +} + +var ( + commitSHARE = regexp.MustCompile(`^[a-f0-9]{7,40}$`) + semverRE = regexp.MustCompile(`^v?\d+\.\d+\.\d+`) +) + +// ParseGitReference parses a git reference string to determine its type. +// Mirrors src/apm_cli/models/dependency/types.py:parse_git_reference. +func ParseGitReference(ref string) (GitReferenceType, string) { + if ref == "" { + return GitRefBranch, "main" + } + + // Check for commit SHA (7-40 hex chars) + if commitSHARE.MatchString(ref) { + return GitRefCommit, ref + } + + // Check for semantic version tag + if semverRE.MatchString(ref) { + return GitRefTag, ref + } + + return GitRefBranch, ref +} diff --git a/internal/models/dependency/types_test.go b/internal/models/dependency/types_test.go new file mode 100644 index 00000000..41964aca --- /dev/null +++ b/internal/models/dependency/types_test.go @@ -0,0 +1,130 @@ +package dependency_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/models/dependency" +) + +// TestParityGitRefBranch mirrors test: parse_git_reference("main") -> (BRANCH, "main") +func TestParityGitRefBranch(t *testing.T) { + refType, ref := dependency.ParseGitReference("main") + if refType != dependency.GitRefBranch { + t.Errorf("expected BRANCH, got %s", refType) + } + if ref != "main" { + t.Errorf("expected 'main', got %s", ref) + } +} + +// TestParityGitRefEmpty mirrors test: parse_git_reference("") -> (BRANCH, "main") +func TestParityGitRefEmpty(t *testing.T) { + refType, ref := dependency.ParseGitReference("") + if refType != dependency.GitRefBranch { + t.Errorf("expected BRANCH, got %s", refType) + } + if ref != "main" { + t.Errorf("expected 'main', got %s", ref) + } +} + +// TestParityGitRefCommitSHA mirrors: parse_git_reference("abc1234") -> (COMMIT, "abc1234") +func TestParityGitRefCommitSHA(t *testing.T) { + refType, ref := dependency.ParseGitReference("abc1234") + if refType != dependency.GitRefCommit { + t.Errorf("expected COMMIT, got %s", refType) + } + if ref != "abc1234" { + t.Errorf("expected 'abc1234', got %s", ref) + } +} + +// TestParityGitRefFullSHA mirrors: full 40-char SHA -> COMMIT +func TestParityGitRefFullSHA(t *testing.T) { + sha := "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + refType, ref := dependency.ParseGitReference(sha) + if refType != dependency.GitRefCommit { + t.Errorf("expected COMMIT, got %s", refType) + } + if ref != sha { + t.Errorf("expected full SHA, got %s", ref) + } +} + +// TestParityGitRefSemver mirrors: "v1.2.3" -> (TAG, "v1.2.3") +func TestParityGitRefSemver(t *testing.T) { + refType, ref := dependency.ParseGitReference("v1.2.3") + if refType != dependency.GitRefTag { + t.Errorf("expected TAG, got %s", refType) + } + if ref != "v1.2.3" { + t.Errorf("expected 'v1.2.3', got %s", ref) + } +} + +// TestParityGitRefSemverNoV mirrors: "1.2.3" -> (TAG, "1.2.3") +func TestParityGitRefSemverNoV(t *testing.T) { + refType, ref := dependency.ParseGitReference("1.2.3") + if refType != dependency.GitRefTag { + t.Errorf("expected TAG, got %s", refType) + } + if ref != "1.2.3" { + t.Errorf("expected '1.2.3', got %s", ref) + } +} + +// TestParityGitRefTypeString validates string representations +func TestParityGitRefTypeString(t *testing.T) { + cases := []struct { + refType dependency.GitReferenceType + want string + }{ + {dependency.GitRefBranch, "branch"}, + {dependency.GitRefTag, "tag"}, + {dependency.GitRefCommit, "commit"}, + } + for _, c := range cases { + if got := c.refType.String(); got != c.want { + t.Errorf("GitReferenceType.String() = %s, want %s", got, c.want) + } + } +} + +// TestParityVirtualPackageTypeString mirrors VirtualPackageType string values +func TestParityVirtualPackageTypeString(t *testing.T) { + if dependency.VirtualPackageFile.String() != "file" { + t.Errorf("expected 'file', got %s", dependency.VirtualPackageFile.String()) + } + if dependency.VirtualPackageSubdirectory.String() != "subdirectory" { + t.Errorf("expected 'subdirectory', got %s", dependency.VirtualPackageSubdirectory.String()) + } +} + +// TestParityResolvedReferenceString mirrors ResolvedReference.__str__ +func TestParityResolvedReferenceString(t *testing.T) { + // No resolved commit: just refname + r := dependency.ResolvedReference{RefName: "main", RefType: dependency.GitRefBranch} + if r.String() != "main" { + t.Errorf("expected 'main', got %s", r.String()) + } + + // Commit type: short SHA + r2 := dependency.ResolvedReference{ + RefType: dependency.GitRefCommit, + ResolvedCommit: "abc1234def567890", + RefName: "abc1234", + } + if r2.String() != "abc1234d" { + t.Errorf("expected 'abc1234d', got %s", r2.String()) + } + + // Branch with commit: "main (abc1234de)" + r3 := dependency.ResolvedReference{ + RefType: dependency.GitRefBranch, + RefName: "main", + ResolvedCommit: "abc1234def567890", + } + if r3.String() != "main (abc1234d)" { + t.Errorf("expected 'main (abc1234d)', got %s", r3.String()) + } +} diff --git a/internal/models/errors.go b/internal/models/errors.go new file mode 100644 index 00000000..bce1943a --- /dev/null +++ b/internal/models/errors.go @@ -0,0 +1,7 @@ +package models + +import "fmt" + +func errorf(format string, args ...interface{}) error { + return fmt.Errorf(format, args...) +} diff --git a/internal/models/models.go b/internal/models/models.go new file mode 100644 index 00000000..e12a91c0 --- /dev/null +++ b/internal/models/models.go @@ -0,0 +1,138 @@ +// Package models defines core data structures for APM packages. +// Mirrors src/apm_cli/models/results.py and src/apm_cli/models/validation.py. +package models + +// InstallResult holds the result of an APM install operation. +// Mirrors src/apm_cli/models/results.py:InstallResult. +type InstallResult struct { + InstalledCount int + PromptsIntegrated int + AgentsIntegrated int + Diagnostics interface{} + PackageTypes map[string]string +} + +// PrimitiveCounts holds counts of primitives in a package. +// Mirrors src/apm_cli/models/results.py:PrimitiveCounts. +type PrimitiveCounts struct { + Prompts int + Agents int + Instructions int + Skills int + Hooks int + Commands int +} + +// PackageType classifies a package by its content. +// Mirrors src/apm_cli/models/validation.py:PackageType. +type PackageType int + +const ( + PackageTypeAPMPackage PackageType = iota + PackageTypeClaudeSkill + PackageTypeHookPackage + PackageTypeHybrid + PackageTypeMarketplacePlugin + PackageTypeSkillBundle + PackageTypeInvalid +) + +func (p PackageType) String() string { + switch p { + case PackageTypeAPMPackage: + return "apm_package" + case PackageTypeClaudeSkill: + return "claude_skill" + case PackageTypeHookPackage: + return "hook_package" + case PackageTypeHybrid: + return "hybrid" + case PackageTypeMarketplacePlugin: + return "marketplace_plugin" + case PackageTypeSkillBundle: + return "skill_bundle" + case PackageTypeInvalid: + return "invalid" + default: + return "unknown" + } +} + +// PackageContentType is the explicit package content type declared in apm.yml. +// Mirrors src/apm_cli/models/validation.py:PackageContentType. +type PackageContentType int + +const ( + PackageContentTypeInstructions PackageContentType = iota + PackageContentTypeSkill + PackageContentTypeHybrid + PackageContentTypePrompts +) + +func (p PackageContentType) String() string { + switch p { + case PackageContentTypeInstructions: + return "instructions" + case PackageContentTypeSkill: + return "skill" + case PackageContentTypeHybrid: + return "hybrid" + case PackageContentTypePrompts: + return "prompts" + default: + return "unknown" + } +} + +// ParsePackageContentType parses a string into a PackageContentType. +// Mirrors src/apm_cli/models/validation.py:PackageContentType.from_string. +func ParsePackageContentType(value string) (PackageContentType, error) { + if value == "" { + return 0, errorf("Package type cannot be empty") + } + switch value { + case "instructions": + return PackageContentTypeInstructions, nil + case "skill": + return PackageContentTypeSkill, nil + case "hybrid": + return PackageContentTypeHybrid, nil + case "prompts": + return PackageContentTypePrompts, nil + default: + return 0, errorf("Invalid package type '%s'. Valid types are: 'instructions', 'skill', 'hybrid', 'prompts'", value) + } +} + +// ValidationError enumerates types of validation errors for APM packages. +// Mirrors src/apm_cli/models/validation.py:ValidationError. +type ValidationErrorCode string + +const ( + ValidationErrMissingAPMYml ValidationErrorCode = "missing_apm_yml" + ValidationErrMissingAPMDir ValidationErrorCode = "missing_apm_dir" + ValidationErrInvalidYMLFormat ValidationErrorCode = "invalid_yml_format" + ValidationErrMissingRequired ValidationErrorCode = "missing_required_field" +) + +// ValidationResult holds the result of a package validation. +type ValidationResult struct { + Valid bool + Errors []ValidationErrorCode + PackageType PackageType +} + +// PluginMetadata holds metadata for a plugin. +// Mirrors src/apm_cli/models/plugin.py:PluginMetadata. +type PluginMetadata struct { + ID string + Name string + Version string + Description string + Author string + Repository string + Homepage string + License string + Tags []string + Dependencies []string +} diff --git a/internal/models/models_test.go b/internal/models/models_test.go new file mode 100644 index 00000000..f5c4cb65 --- /dev/null +++ b/internal/models/models_test.go @@ -0,0 +1,123 @@ +package models_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/models" +) + +// TestParityPackageTypeString mirrors PackageType enum string values +func TestParityPackageTypeString(t *testing.T) { + cases := []struct { + pt models.PackageType + want string + }{ + {models.PackageTypeAPMPackage, "apm_package"}, + {models.PackageTypeClaudeSkill, "claude_skill"}, + {models.PackageTypeHookPackage, "hook_package"}, + {models.PackageTypeHybrid, "hybrid"}, + {models.PackageTypeMarketplacePlugin, "marketplace_plugin"}, + {models.PackageTypeSkillBundle, "skill_bundle"}, + {models.PackageTypeInvalid, "invalid"}, + } + for _, c := range cases { + if got := c.pt.String(); got != c.want { + t.Errorf("PackageType(%d).String() = %s, want %s", c.pt, got, c.want) + } + } +} + +// TestParityPackageContentTypeString mirrors PackageContentType enum string values +func TestParityPackageContentTypeString(t *testing.T) { + cases := []struct { + ct models.PackageContentType + want string + }{ + {models.PackageContentTypeInstructions, "instructions"}, + {models.PackageContentTypeSkill, "skill"}, + {models.PackageContentTypeHybrid, "hybrid"}, + {models.PackageContentTypePrompts, "prompts"}, + } + for _, c := range cases { + if got := c.ct.String(); got != c.want { + t.Errorf("PackageContentType.String() = %s, want %s", got, c.want) + } + } +} + +// TestParityParsePackageContentTypeValid mirrors PackageContentType.from_string for valid values +func TestParityParsePackageContentTypeValid(t *testing.T) { + cases := []struct { + input string + want models.PackageContentType + }{ + {"instructions", models.PackageContentTypeInstructions}, + {"skill", models.PackageContentTypeSkill}, + {"hybrid", models.PackageContentTypeHybrid}, + {"prompts", models.PackageContentTypePrompts}, + } + for _, c := range cases { + got, err := models.ParsePackageContentType(c.input) + if err != nil { + t.Errorf("ParsePackageContentType(%q) unexpected error: %v", c.input, err) + } + if got != c.want { + t.Errorf("ParsePackageContentType(%q) = %v, want %v", c.input, got, c.want) + } + } +} + +// TestParityParsePackageContentTypeEmpty mirrors PackageContentType.from_string("") raises ValueError +func TestParityParsePackageContentTypeEmpty(t *testing.T) { + _, err := models.ParsePackageContentType("") + if err == nil { + t.Error("expected error for empty string") + } +} + +// TestParityParsePackageContentTypeInvalid mirrors PackageContentType.from_string("invalid") +func TestParityParsePackageContentTypeInvalid(t *testing.T) { + _, err := models.ParsePackageContentType("bad_type") + if err == nil { + t.Error("expected error for invalid type") + } +} + +// TestParityInstallResultDefaults mirrors InstallResult default field values +func TestParityInstallResultDefaults(t *testing.T) { + r := models.InstallResult{PackageTypes: make(map[string]string)} + if r.InstalledCount != 0 { + t.Errorf("InstalledCount default should be 0, got %d", r.InstalledCount) + } + if r.PromptsIntegrated != 0 { + t.Errorf("PromptsIntegrated default should be 0") + } + if r.AgentsIntegrated != 0 { + t.Errorf("AgentsIntegrated default should be 0") + } +} + +// TestParityPrimitiveCounts mirrors PrimitiveCounts default zero values +func TestParityPrimitiveCounts(t *testing.T) { + pc := models.PrimitiveCounts{} + if pc.Prompts != 0 || pc.Agents != 0 || pc.Instructions != 0 || + pc.Skills != 0 || pc.Hooks != 0 || pc.Commands != 0 { + t.Error("PrimitiveCounts defaults should all be 0") + } +} + +// TestParityValidationErrorCodes mirrors ValidationError enum values +func TestParityValidationErrorCodes(t *testing.T) { + if models.ValidationErrMissingAPMYml != "missing_apm_yml" { + t.Error("ValidationErrMissingAPMYml mismatch") + } + if models.ValidationErrMissingAPMDir != "missing_apm_dir" { + t.Error("ValidationErrMissingAPMDir mismatch") + } + if models.ValidationErrInvalidYMLFormat != "invalid_yml_format" { + t.Error("ValidationErrInvalidYMLFormat mismatch") + } + if models.ValidationErrMissingRequired != "missing_required_field" { + t.Error("ValidationErrMissingRequired mismatch") + } +} diff --git a/internal/primitives/primitives.go b/internal/primitives/primitives.go new file mode 100644 index 00000000..197258b1 --- /dev/null +++ b/internal/primitives/primitives.go @@ -0,0 +1,107 @@ +// Package primitives defines data structures for APM primitive files +// (chatmodes, instructions, contexts, skills). +// Mirrors src/apm_cli/primitives/models.py. +package primitives + +import "path/filepath" + +// PrimitiveType classifies a primitive by its kind. +type PrimitiveType string + +const ( + PrimitiveTypeChatmode PrimitiveType = "chatmode" + PrimitiveTypeInstruction PrimitiveType = "instruction" + PrimitiveTypeContext PrimitiveType = "context" + PrimitiveTypeSkill PrimitiveType = "skill" +) + +// Chatmode represents a chatmode primitive. +// Mirrors src/apm_cli/primitives/models.py:Chatmode. +type Chatmode struct { + Name string + FilePath string + Description string + ApplyTo string // Glob pattern for file targeting (empty if not set) + Content string + Author string + Version string + Source string +} + +// Instruction represents an instruction primitive (.instructions.md). +// Mirrors src/apm_cli/primitives/models.py:Instruction. +type Instruction struct { + Name string + FilePath string + Description string + ApplyTo string + Content string + Author string + Version string + Source string +} + +// Context represents a context primitive (.context.md). +// Mirrors src/apm_cli/primitives/models.py:Context. +type Context struct { + Name string + FilePath string + Description string + Scope string + Content string + Author string + Version string + Source string +} + +// Skill represents a skill primitive (SKILL.md). +// Mirrors src/apm_cli/primitives/models.py:Skill. +type Skill struct { + Name string + FilePath string + Description string + Content string + Author string + Version string + Source string +} + +// PrimitiveConflict records a conflict between two primitives. +// Mirrors src/apm_cli/primitives/models.py:PrimitiveConflict. +type PrimitiveConflict struct { + Type PrimitiveType + Name string + Path1 string + Path2 string + Reason string +} + +// PrimitiveCollection holds all discovered primitives for a package. +// Mirrors src/apm_cli/primitives/models.py:PrimitiveCollection. +type PrimitiveCollection struct { + Chatmodes []Chatmode + Instructions []Instruction + Contexts []Context + Skills []Skill + Conflicts []PrimitiveConflict +} + +// FileNameWithoutExt returns the base filename without extension. +func FileNameWithoutExt(path string) string { + base := filepath.Base(path) + ext := filepath.Ext(base) + if ext == "" { + return base + } + return base[:len(base)-len(ext)] +} + +// TotalCount returns total number of primitives in the collection. +func (pc *PrimitiveCollection) TotalCount() int { + return len(pc.Chatmodes) + len(pc.Instructions) + len(pc.Contexts) + len(pc.Skills) +} + +// HasConflicts returns true if there are any conflicts. +func (pc *PrimitiveCollection) HasConflicts() bool { + return len(pc.Conflicts) > 0 +} diff --git a/internal/primitives/primitives_test.go b/internal/primitives/primitives_test.go new file mode 100644 index 00000000..d6bb374f --- /dev/null +++ b/internal/primitives/primitives_test.go @@ -0,0 +1,143 @@ +package primitives_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/primitives" +) + +// TestParityPrimitiveTypeValues mirrors PrimitiveType string constants +func TestParityPrimitiveTypeValues(t *testing.T) { + if primitives.PrimitiveTypeChatmode != "chatmode" { + t.Error("PrimitiveTypeChatmode should be 'chatmode'") + } + if primitives.PrimitiveTypeInstruction != "instruction" { + t.Error("PrimitiveTypeInstruction should be 'instruction'") + } + if primitives.PrimitiveTypeContext != "context" { + t.Error("PrimitiveTypeContext should be 'context'") + } + if primitives.PrimitiveTypeSkill != "skill" { + t.Error("PrimitiveTypeSkill should be 'skill'") + } +} + +// TestParityChatmodeStruct mirrors Chatmode dataclass fields +func TestParityChatmodeStruct(t *testing.T) { + cm := primitives.Chatmode{ + Name: "test-chatmode", + FilePath: "/some/path/test-chatmode.chatmode.md", + Description: "A test chatmode", + ApplyTo: "**/*.go", + Content: "# Test", + Author: "acme", + Version: "1.0", + Source: "owner/repo", + } + if cm.Name != "test-chatmode" { + t.Errorf("Chatmode.Name = %s, want 'test-chatmode'", cm.Name) + } + if cm.ApplyTo != "**/*.go" { + t.Errorf("Chatmode.ApplyTo = %s, want '**/*.go'", cm.ApplyTo) + } +} + +// TestParityInstructionStruct mirrors Instruction dataclass fields +func TestParityInstructionStruct(t *testing.T) { + inst := primitives.Instruction{ + Name: "testing", + FilePath: "testing.instructions.md", + ApplyTo: "**/*.go", + Content: "content", + } + if inst.Name != "testing" { + t.Errorf("Instruction.Name = %s, want 'testing'", inst.Name) + } +} + +// TestParityContextStruct mirrors Context dataclass fields +func TestParityContextStruct(t *testing.T) { + ctx := primitives.Context{ + Name: "my-context", + Scope: "workspace", + Content: "context content", + } + if ctx.Scope != "workspace" { + t.Errorf("Context.Scope = %s, want 'workspace'", ctx.Scope) + } +} + +// TestParitySkillStruct mirrors Skill dataclass fields +func TestParitySkillStruct(t *testing.T) { + skill := primitives.Skill{ + Name: "my-skill", + FilePath: "skills/my-skill/SKILL.md", + Description: "Does something useful", + Content: "# My Skill", + } + if skill.Name != "my-skill" { + t.Errorf("Skill.Name = %s, want 'my-skill'", skill.Name) + } +} + +// TestParityPrimitiveConflict mirrors PrimitiveConflict dataclass fields +func TestParityPrimitiveConflict(t *testing.T) { + conflict := primitives.PrimitiveConflict{ + Type: primitives.PrimitiveTypeSkill, + Name: "duplicate-skill", + Path1: "/a/SKILL.md", + Path2: "/b/SKILL.md", + Reason: "duplicate name", + } + if conflict.Type != primitives.PrimitiveTypeSkill { + t.Errorf("PrimitiveConflict.Type = %s, want 'skill'", conflict.Type) + } +} + +// TestParityPrimitiveCollectionTotalCount mirrors PrimitiveCollection total count +func TestParityPrimitiveCollectionTotalCount(t *testing.T) { + pc := primitives.PrimitiveCollection{ + Chatmodes: []primitives.Chatmode{{Name: "c1"}, {Name: "c2"}}, + Instructions: []primitives.Instruction{{Name: "i1"}}, + Contexts: []primitives.Context{}, + Skills: []primitives.Skill{{Name: "s1"}, {Name: "s2"}, {Name: "s3"}}, + } + if pc.TotalCount() != 6 { + t.Errorf("TotalCount() = %d, want 6", pc.TotalCount()) + } +} + +// TestParityPrimitiveCollectionHasConflicts mirrors PrimitiveCollection conflict detection +func TestParityPrimitiveCollectionHasConflicts(t *testing.T) { + pcNoConflicts := primitives.PrimitiveCollection{} + if pcNoConflicts.HasConflicts() { + t.Error("empty collection should have no conflicts") + } + + pcWithConflicts := primitives.PrimitiveCollection{ + Conflicts: []primitives.PrimitiveConflict{{Name: "dup"}}, + } + if !pcWithConflicts.HasConflicts() { + t.Error("collection with conflicts should return true") + } +} + +// TestParityFileNameWithoutExt mirrors file stem extraction +func TestParityFileNameWithoutExt(t *testing.T) { + cases := []struct { + input string + want string + }{ + {"SKILL.md", "SKILL"}, + {"test-chatmode.chatmode.md", "test-chatmode.chatmode"}, + {"testing.instructions.md", "testing.instructions"}, + {"/some/path/my-skill/SKILL.md", "SKILL"}, + {"noext", "noext"}, + } + for _, c := range cases { + got := primitives.FileNameWithoutExt(c.input) + if got != c.want { + t.Errorf("FileNameWithoutExt(%q) = %q, want %q", c.input, got, c.want) + } + } +} diff --git a/internal/utils/githubhost/githubhost.go b/internal/utils/githubhost/githubhost.go new file mode 100644 index 00000000..fd4a2914 --- /dev/null +++ b/internal/utils/githubhost/githubhost.go @@ -0,0 +1,149 @@ +// Package githubhost provides host classification and URL utilities for +// GitHub, GitHub Enterprise, Azure DevOps, and GitLab hostnames. +// Mirrors src/apm_cli/utils/github_host.py. +package githubhost + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +// fqdnPattern validates a Fully Qualified Domain Name per the Python implementation. +var fqdnPattern = regexp.MustCompile( + `^[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?)+$`, +) + +// IsValidFQDN reports whether hostname is a valid FQDN (at least two labels). +func IsValidFQDN(hostname string) bool { + if hostname == "" { + return false + } + // Strip path components. + if idx := strings.Index(hostname, "/"); idx >= 0 { + hostname = hostname[:idx] + } + return fqdnPattern.MatchString(hostname) +} + +// DefaultHost returns the configured default git host (GITHUB_HOST env or "github.com"). +func DefaultHost() string { + if h := os.Getenv("GITHUB_HOST"); h != "" { + return h + } + return "github.com" +} + +// IsAzureDevOpsHostname reports whether hostname is an ADO host +// (dev.azure.com or *.visualstudio.com). +func IsAzureDevOpsHostname(hostname string) bool { + if hostname == "" { + return false + } + h := strings.ToLower(hostname) + return h == "dev.azure.com" || strings.HasSuffix(h, ".visualstudio.com") +} + +// IsGitHubHostname reports whether hostname is GitHub SaaS or GHE Cloud (*.ghe.com). +func IsGitHubHostname(hostname string) bool { + if hostname == "" { + return false + } + h := strings.ToLower(hostname) + return h == "github.com" || strings.HasSuffix(h, ".ghe.com") +} + +// IsGitLabHostname reports whether hostname is GitLab SaaS or a configured +// self-managed GitLab host (GITLAB_HOST or APM_GITLAB_HOSTS env vars). +// GHES host takes precedence -- if GITHUB_HOST matches, this returns false. +func IsGitLabHostname(hostname string) bool { + if hostname == "" { + return false + } + h := strings.ToLower(strings.SplitN(hostname, "/", 2)[0]) + + // GHES precedence check. + ghesHost := strings.ToLower(strings.SplitN(os.Getenv("GITHUB_HOST"), "/", 2)[0]) + if ghesHost != "" && ghesHost == h && + ghesHost != "github.com" && ghesHost != "gitlab.com" && + !strings.HasSuffix(ghesHost, ".ghe.com") && + IsValidFQDN(ghesHost) { + return false + } + + if h == "gitlab.com" { + return true + } + if single := strings.ToLower(strings.SplitN(os.Getenv("GITLAB_HOST"), "/", 2)[0]); single != "" && single == h { + return IsValidFQDN(h) + } + for _, part := range strings.Split(os.Getenv("APM_GITLAB_HOSTS"), ",") { + entry := strings.ToLower(strings.SplitN(strings.TrimSpace(part), "/", 2)[0]) + if entry != "" && entry == h && IsValidFQDN(entry) { + return true + } + } + return false +} + +// SupportGHCLIHost reports whether host should use gh CLI token fallback. +func SupportGHCLIHost(host string) bool { + if host == "" { + return false + } + if IsGitHubHostname(host) { + return true + } + configured := strings.ToLower(DefaultHost()) + hostLower := strings.ToLower(host) + if hostLower != configured { + return false + } + if configured == "github.com" || strings.HasSuffix(configured, ".ghe.com") { + return false + } + if IsAzureDevOpsHostname(configured) { + return false + } + return IsValidFQDN(configured) +} + +// adoAuthFailureSignals are the case-insensitive signals for ADO auth failures. +var adoAuthFailureSignals = []string{ + "401", + "403", + "authentication failed", + "unauthorized", + "could not read username", +} + +// IsADOAuthFailureSignal reports whether text contains an ADO auth-failure signal. +func IsADOAuthFailureSignal(text string) bool { + if text == "" { + return false + } + lower := strings.ToLower(text) + for _, sig := range adoAuthFailureSignals { + if strings.Contains(lower, sig) { + return true + } + } + return false +} + +// BuildAuthorizationHeaderGitEnv builds env vars to inject an HTTP Authorization +// header into git operations via GIT_CONFIG_COUNT/KEY_N/VALUE_N. +func BuildAuthorizationHeaderGitEnv(scheme, credential string) map[string]string { + return map[string]string{ + "GIT_CONFIG_COUNT": "1", + "GIT_CONFIG_KEY_0": "http.extraheader", + "GIT_CONFIG_VALUE_0": fmt.Sprintf("Authorization: %s %s", scheme, credential), + } +} + +// BuildADOBearerGitEnv builds env vars to authenticate to Azure DevOps +// with an Entra ID bearer token. +func BuildADOBearerGitEnv(bearerToken string) map[string]string { + return BuildAuthorizationHeaderGitEnv("Bearer", bearerToken) +} diff --git a/internal/utils/githubhost/githubhost_test.go b/internal/utils/githubhost/githubhost_test.go new file mode 100644 index 00000000..84c06d83 --- /dev/null +++ b/internal/utils/githubhost/githubhost_test.go @@ -0,0 +1,131 @@ +// Package githubhost_test provides parity tests for githubhost utilities. +// Mirrors src/apm_cli/utils/github_host.py behaviour. +package githubhost_test + +import ( + "os" + "testing" + + "github.com/githubnext/apm/internal/utils/githubhost" +) + +func TestParityIsValidFQDN(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"github.com", true}, + {"dev.azure.com", true}, + {"my.custom.host.example.com", true}, + {"", false}, + {"localhost", false}, + {"invalid-", false}, + {"-invalid.com", false}, + {"has space.com", false}, + } + for _, c := range cases { + got := githubhost.IsValidFQDN(c.in) + if got != c.want { + t.Errorf("IsValidFQDN(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestParityDefaultHost(t *testing.T) { + os.Unsetenv("GITHUB_HOST") + if h := githubhost.DefaultHost(); h != "github.com" { + t.Errorf("DefaultHost() = %q, want github.com", h) + } + t.Setenv("GITHUB_HOST", "myghe.com") + if h := githubhost.DefaultHost(); h != "myghe.com" { + t.Errorf("DefaultHost() with GITHUB_HOST = %q, want myghe.com", h) + } +} + +func TestParityIsAzureDevOpsHostname(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"dev.azure.com", true}, + {"myorg.visualstudio.com", true}, + {"github.com", false}, + {"", false}, + } + for _, c := range cases { + got := githubhost.IsAzureDevOpsHostname(c.in) + if got != c.want { + t.Errorf("IsAzureDevOpsHostname(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestParityIsGitHubHostname(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"github.com", true}, + {"myenterprise.ghe.com", true}, + {"gitlab.com", false}, + {"dev.azure.com", false}, + {"", false}, + } + for _, c := range cases { + got := githubhost.IsGitHubHostname(c.in) + if got != c.want { + t.Errorf("IsGitHubHostname(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestParityIsGitLabHostname(t *testing.T) { + if got := githubhost.IsGitLabHostname("gitlab.com"); !got { + t.Error("IsGitLabHostname(gitlab.com) should be true") + } + if got := githubhost.IsGitLabHostname("github.com"); got { + t.Error("IsGitLabHostname(github.com) should be false") + } + if got := githubhost.IsGitLabHostname(""); got { + t.Error("IsGitLabHostname('') should be false") + } + t.Setenv("GITLAB_HOST", "mygitlab.example.com") + if got := githubhost.IsGitLabHostname("mygitlab.example.com"); !got { + t.Error("IsGitLabHostname with GITLAB_HOST should be true") + } +} + +func TestParityIsADOAuthFailureSignal(t *testing.T) { + cases := []struct { + in string + want bool + }{ + {"HTTP 401 Unauthorized", true}, + {"authentication failed", true}, + {"could not read username", true}, + {"403 Forbidden", true}, + {"Unauthorized access", true}, + {"", false}, + {"everything is fine", false}, + } + for _, c := range cases { + got := githubhost.IsADOAuthFailureSignal(c.in) + if got != c.want { + t.Errorf("IsADOAuthFailureSignal(%q) = %v, want %v", c.in, got, c.want) + } + } +} + +func TestParityBuildADOBearerGitEnv(t *testing.T) { + env := githubhost.BuildADOBearerGitEnv("mytoken") + if env["GIT_CONFIG_COUNT"] != "1" { + t.Errorf("GIT_CONFIG_COUNT = %q", env["GIT_CONFIG_COUNT"]) + } + if env["GIT_CONFIG_KEY_0"] != "http.extraheader" { + t.Errorf("GIT_CONFIG_KEY_0 = %q", env["GIT_CONFIG_KEY_0"]) + } + want := "Authorization: Bearer mytoken" + if env["GIT_CONFIG_VALUE_0"] != want { + t.Errorf("GIT_CONFIG_VALUE_0 = %q, want %q", env["GIT_CONFIG_VALUE_0"], want) + } +} diff --git a/internal/utils/normalization/normalization.go b/internal/utils/normalization/normalization.go new file mode 100644 index 00000000..71ceb7b3 --- /dev/null +++ b/internal/utils/normalization/normalization.go @@ -0,0 +1,43 @@ +// Package normalization provides bytes-in / bytes-out content normalization +// helpers. Mirrors src/apm_cli/utils/normalization.py. +// +// Used by drift-detection to compare deployed file bytes against the replay +// scratch tree without flagging legitimate, deterministic differences: +// - Line-ending differences (CRLF vs LF) +// - UTF-8 BOMs at the start of the file +// - APM headers re-stamped on every recompile +package normalization + +import ( + "bytes" + "regexp" +) + +// BOM is the UTF-8 byte order mark. +var BOM = []byte{0xef, 0xbb, 0xbf} + +// buildIDPattern matches APM headers. +var buildIDPattern = regexp.MustCompile(`(?i)\s*\n?`) + +// StripBuildID removes APM headers wherever they appear. +func StripBuildID(content []byte) []byte { + return buildIDPattern.ReplaceAll(content, nil) +} + +// NormalizeLineEndings converts CRLF to LF; leaves bare CR alone. +func NormalizeLineEndings(content []byte) []byte { + return bytes.ReplaceAll(content, []byte("\r\n"), []byte("\n")) +} + +// StripBOM drops a UTF-8 BOM at the start of the file (only at offset 0). +func StripBOM(content []byte) []byte { + if bytes.HasPrefix(content, BOM) { + return content[len(BOM):] + } + return content +} + +// Normalize applies all drift-tolerant normalizations to a file's bytes. +func Normalize(content []byte) []byte { + return StripBuildID(NormalizeLineEndings(StripBOM(content))) +} diff --git a/internal/utils/normalization/normalization_test.go b/internal/utils/normalization/normalization_test.go new file mode 100644 index 00000000..9c9b8480 --- /dev/null +++ b/internal/utils/normalization/normalization_test.go @@ -0,0 +1,90 @@ +// Package normalization_test provides parity tests for normalization helpers. +// Mirrors the behaviour of src/apm_cli/utils/normalization.py. +package normalization_test + +import ( + "bytes" + "testing" + + "github.com/githubnext/apm/internal/utils/normalization" +) + +// TestParityNormalizationStripBOM verifies BOM stripping matches Python. +func TestParityNormalizationStripBOM(t *testing.T) { + bom := normalization.BOM + cases := []struct { + name string + input []byte + expected []byte + }{ + {"no bom", []byte("hello"), []byte("hello")}, + {"bom prefix", append(append([]byte{}, bom...), []byte("hello")...), []byte("hello")}, + {"bom only", bom, []byte{}}, + {"bom in middle not stripped", []byte("hel\xef\xbb\xbflo"), []byte("hel\xef\xbb\xbflo")}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := normalization.StripBOM(tc.input) + if !bytes.Equal(got, tc.expected) { + t.Errorf("got %q, want %q", got, tc.expected) + } + }) + } +} + +// TestParityNormalizationCRLF verifies CRLF-to-LF conversion matches Python. +func TestParityNormalizationCRLF(t *testing.T) { + cases := []struct { + name string + input []byte + expected []byte + }{ + {"no crlf", []byte("hello\nworld\n"), []byte("hello\nworld\n")}, + {"crlf", []byte("hello\r\nworld\r\n"), []byte("hello\nworld\n")}, + {"bare cr preserved", []byte("hello\rworld"), []byte("hello\rworld")}, + {"mixed", []byte("a\r\nb\nc\r\n"), []byte("a\nb\nc\n")}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := normalization.NormalizeLineEndings(tc.input) + if !bytes.Equal(got, tc.expected) { + t.Errorf("got %q, want %q", got, tc.expected) + } + }) + } +} + +// TestParityNormalizationStripBuildID verifies Build ID header stripping. +func TestParityNormalizationStripBuildID(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + {"no build id", "hello world", "hello world"}, + {"build id", "\nhello", "hello"}, + {"build id uppercase", "\nhello", "hello"}, + {"build id no newline", "hello", "hello"}, + {"build id with spaces", "\nhello", "hello"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := string(normalization.StripBuildID([]byte(tc.input))) + if got != tc.expected { + t.Errorf("got %q, want %q", got, tc.expected) + } + }) + } +} + +// TestParityNormalizationNormalize verifies composite Normalize function. +func TestParityNormalizationNormalize(t *testing.T) { + bom := normalization.BOM + // Build a payload with BOM + Build ID + CRLF + input := append(append([]byte{}, bom...), []byte("\r\nhello\r\nworld\r\n")...) + expected := []byte("hello\nworld\n") + got := normalization.Normalize(input) + if !bytes.Equal(got, expected) { + t.Errorf("got %q, want %q", got, expected) + } +} diff --git a/internal/utils/paths/paths.go b/internal/utils/paths/paths.go new file mode 100644 index 00000000..45b3fd03 --- /dev/null +++ b/internal/utils/paths/paths.go @@ -0,0 +1,33 @@ +// Package paths provides cross-platform path utilities for APM CLI. +// Mirrors src/apm_cli/utils/paths.py (portable_relpath function). +package paths + +import ( + "path/filepath" + "strings" +) + +// PortableRelpath returns a forward-slash relative path from base to path, +// resolving both sides first. When path is not under base (or resolution +// fails), falls back to the resolved absolute path. +// +// Mirrors Python's portable_relpath() from utils/paths.py. +func PortableRelpath(path, base string) string { + absPath, err := filepath.Abs(path) + if err != nil { + return toSlash(path) + } + absBase, err := filepath.Abs(base) + if err != nil { + return toSlash(absPath) + } + rel, err := filepath.Rel(absBase, absPath) + if err != nil { + return toSlash(absPath) + } + return toSlash(rel) +} + +func toSlash(p string) string { + return strings.ReplaceAll(p, "\\", "/") +} diff --git a/internal/utils/paths/paths_test.go b/internal/utils/paths/paths_test.go new file mode 100644 index 00000000..c3b25885 --- /dev/null +++ b/internal/utils/paths/paths_test.go @@ -0,0 +1,37 @@ +// Package paths_test provides parity tests for path utilities. +package paths_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/githubnext/apm/internal/utils/paths" +) + +// TestParityPathsPortableRelpath verifies portable_relpath behavior. +func TestParityPathsPortableRelpath(t *testing.T) { + tmp := t.TempDir() + sub := filepath.Join(tmp, "a", "b") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatal(err) + } + cases := []struct { + name string + path string + base string + expected string + }{ + {"nested path", filepath.Join(tmp, "a", "b", "c.txt"), tmp, "a/b/c.txt"}, + {"direct child", filepath.Join(tmp, "file.txt"), tmp, "file.txt"}, + {"same dir", tmp, tmp, "."}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := paths.PortableRelpath(tc.path, tc.base) + if got != tc.expected { + t.Errorf("got %q, want %q", got, tc.expected) + } + }) + } +} diff --git a/internal/utils/sha/sha.go b/internal/utils/sha/sha.go new file mode 100644 index 00000000..1db341af --- /dev/null +++ b/internal/utils/sha/sha.go @@ -0,0 +1,43 @@ +// Package sha provides short SHA formatting helpers. +// Mirrors src/apm_cli/utils/short_sha.py. +package sha + +import "strings" + +// sentinels are values that collapse to empty string. +var sentinels = map[string]bool{ + "cached": true, + "unknown": true, +} + +// FormatShortSHA returns an 8-char short SHA or "" for invalid inputs. +// Rules: +// - nil / non-string -> "" +// - sentinel strings ("cached", "unknown") -> "" +// - shorter than 8 chars -> "" +// - contains non-hex characters -> "" +// - otherwise: first 8 chars +func FormatShortSHA(value string) string { + candidate := strings.TrimSpace(value) + if candidate == "" { + return "" + } + if sentinels[strings.ToLower(candidate)] { + return "" + } + if len(candidate) < 8 { + return "" + } + for _, ch := range candidate { + if !isHex(ch) { + return "" + } + } + return candidate[:8] +} + +func isHex(ch rune) bool { + return (ch >= '0' && ch <= '9') || + (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F') +} diff --git a/internal/utils/sha/sha_test.go b/internal/utils/sha/sha_test.go new file mode 100644 index 00000000..be483b32 --- /dev/null +++ b/internal/utils/sha/sha_test.go @@ -0,0 +1,41 @@ +// Package sha_test provides parity tests for short SHA formatting. +// Mirrors src/apm_cli/utils/short_sha.py. +package sha_test + +import ( + "testing" + + "github.com/githubnext/apm/internal/utils/sha" +) + +// TestParitySHAFormatShortSHA verifies FormatShortSHA matches the Python +// format_short_sha implementation. +func TestParitySHAFormatShortSHA(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + {"valid sha40", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", "a1b2c3d4"}, + {"valid sha8", "a1b2c3d4", "a1b2c3d4"}, + {"valid sha64", "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2a1b2c3d4e5f6a1b2c3d4e5f6", "a1b2c3d4"}, + {"empty", "", ""}, + {"whitespace only", " ", ""}, + {"sentinel cached", "cached", ""}, + {"sentinel unknown", "unknown", ""}, + {"sentinel cached upper", "CACHED", ""}, + {"too short", "abc123", ""}, + {"exactly 7", "abcdef1", ""}, + {"non hex chars", "xyz12345", ""}, + {"uppercase valid", "ABCDEF12", "ABCDEF12"}, + {"mixed case valid", "aAbBcCdD", "aAbBcCdD"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := sha.FormatShortSHA(tc.input) + if got != tc.expected { + t.Errorf("FormatShortSHA(%q): got %q, want %q", tc.input, got, tc.expected) + } + }) + } +}