diff --git a/pkg/workflow/engine.go b/pkg/workflow/engine.go index e6a2db40dc..5e067b5507 100644 --- a/pkg/workflow/engine.go +++ b/pkg/workflow/engine.go @@ -8,6 +8,7 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/stringutil" "github.com/github/gh-aw/pkg/types" + "github.com/github/gh-aw/pkg/typeutil" ) var engineLog = logger.New("workflow:engine") @@ -186,10 +187,8 @@ func (c *Compiler) ExtractEngineConfig(frontmatter map[string]any) (string, *Eng // Extract optional 'max-turns' field if maxTurns, hasMaxTurns := engineObj["max-turns"]; hasMaxTurns { - if maxTurnsInt, ok := maxTurns.(int); ok { - config.MaxTurns = strconv.Itoa(maxTurnsInt) - } else if maxTurnsUint64, ok := maxTurns.(uint64); ok { - config.MaxTurns = strconv.FormatUint(maxTurnsUint64, 10) + if val, ok := typeutil.ParseIntValue(maxTurns); ok { + config.MaxTurns = strconv.Itoa(val) } else if maxTurnsStr, ok := maxTurns.(string); ok { config.MaxTurns = maxTurnsStr } @@ -197,10 +196,8 @@ func (c *Compiler) ExtractEngineConfig(frontmatter map[string]any) (string, *Eng // Extract optional 'max-continuations' field if maxCont, hasMaxCont := engineObj["max-continuations"]; hasMaxCont { - if maxContInt, ok := maxCont.(int); ok { - config.MaxContinuations = maxContInt - } else if maxContUint64, ok := maxCont.(uint64); ok { - config.MaxContinuations = int(maxContUint64) + if val, ok := typeutil.ParseIntValue(maxCont); ok { + config.MaxContinuations = val } else if maxContStr, ok := maxCont.(string); ok { if parsed, err := strconv.Atoi(maxContStr); err == nil { config.MaxContinuations = parsed diff --git a/pkg/workflow/engine_config_test.go b/pkg/workflow/engine_config_test.go index a588fe9c6c..df25764834 100644 --- a/pkg/workflow/engine_config_test.go +++ b/pkg/workflow/engine_config_test.go @@ -141,6 +141,19 @@ func TestExtractEngineConfig(t *testing.T) { expectedEngineSetting: "claude", expectedConfig: &EngineConfig{ID: "claude", Version: "beta", Model: "claude-3-5-sonnet-20241022", MaxTurns: "10"}, }, + { + // float64 is what json.Unmarshal produces for numbers when deserializing engine + // config JSON from shared imports (JSON roundtrip: YAML int -> JSON -> Go float64) + name: "object format - with max-turns as float64 (JSON roundtrip from shared import)", + frontmatter: map[string]any{ + "engine": map[string]any{ + "id": "claude", + "max-turns": float64(100), + }, + }, + expectedEngineSetting: "claude", + expectedConfig: &EngineConfig{ID: "claude", MaxTurns: "100"}, + }, { name: "object format - with env vars", frontmatter: map[string]any{ diff --git a/pkg/workflow/max_turns_test.go b/pkg/workflow/max_turns_test.go index 981c2571ba..0139fdc75b 100644 --- a/pkg/workflow/max_turns_test.go +++ b/pkg/workflow/max_turns_test.go @@ -254,3 +254,82 @@ engine: }) } } + +func TestMaxTurnsFromSharedImport(t *testing.T) { + // This test verifies that engine.max-turns is correctly propagated when + // the engine config is sourced from a shared import rather than defined inline. + // The bug was that max-turns was silently dropped because it was serialized as + // JSON (int -> float64) but only int/uint64/string types were handled. + + // Create a temporary directory for the test + tmpDir := testutil.TempDir(t, "max-turns-import-test") + + // Create the shared import file with engine config including max-turns + sharedContent := `--- +engine: + id: claude + max-turns: 100 +permissions: + contents: read + issues: read + pull-requests: read +--- +` + sharedDir := filepath.Join(tmpDir, "shared") + if err := os.MkdirAll(sharedDir, 0755); err != nil { + t.Fatal(err) + } + sharedFile := filepath.Join(sharedDir, "common.md") + if err := os.WriteFile(sharedFile, []byte(sharedContent), 0644); err != nil { + t.Fatal(err) + } + + // Create the main workflow that imports the shared config + mainContent := `--- +on: + workflow_dispatch: +permissions: + contents: read + issues: read + pull-requests: read +imports: + - shared/common.md +tools: + github: + allowed: [issue_read] +--- + +# Test Max Turns From Shared Import + +This workflow imports max-turns from a shared import. +` + mainFile := filepath.Join(tmpDir, "test-workflow.md") + if err := os.WriteFile(mainFile, []byte(mainContent), 0644); err != nil { + t.Fatal(err) + } + + // Compile the workflow + compiler := NewCompiler() + if err := compiler.CompileWorkflow(mainFile); err != nil { + t.Fatalf("Failed to compile workflow: %v", err) + } + + // Read the generated lock file + lockFile := stringutil.MarkdownToLockFile(mainFile) + lockContent, err := os.ReadFile(lockFile) + if err != nil { + t.Fatalf("Failed to read lock file: %v", err) + } + + lockContentStr := string(lockContent) + + // Verify --max-turns 100 is present in the compiled output + if !strings.Contains(lockContentStr, "--max-turns 100") { + t.Errorf("Expected --max-turns 100 in compiled output when max-turns is set in shared import.\nLock file content:\n%s", lockContentStr) + } + + // Verify GH_AW_MAX_TURNS env var is set + if !strings.Contains(lockContentStr, "GH_AW_MAX_TURNS: 100") { + t.Errorf("Expected GH_AW_MAX_TURNS: 100 in compiled output.\nLock file content:\n%s", lockContentStr) + } +}