diff --git a/cmd/template/template.go b/cmd/template/template.go index c5aa710..aaf7fd3 100644 --- a/cmd/template/template.go +++ b/cmd/template/template.go @@ -199,7 +199,7 @@ func execute(cmd *cobra.Command, args []string, def *Definition) error { } } - userMsg, err := buildExtendedUserMessage(rootFS, meta, ec, files) + userRequest, err := buildWorkspaceChangeRequest(rootFS, meta, ec, files) if err != nil { return err } @@ -217,7 +217,7 @@ func execute(cmd *cobra.Command, args []string, def *Definition) error { systemMessage := rendered - proposal, err := llm.GetWorkspaceChangeProposals(cfg, def.Model.Family, def.Model.Size, systemMessage, userMsg) + proposal, err := llm.GetWorkspaceChangeProposals(cfg, def.Model.Family, def.Model.Size, systemMessage, userRequest) if err != nil { return err } diff --git a/cmd/template/user_msg_builder.go b/cmd/template/user_msg_builder.go index ab5b292..fe17be1 100644 --- a/cmd/template/user_msg_builder.go +++ b/cmd/template/user_msg_builder.go @@ -11,19 +11,21 @@ import ( "github.com/vybdev/vyb/workspace/project" ) -// buildExtendedUserMessage composes the user-message payload that will be +// buildWorkspaceChangeRequest composes a payload.WorkspaceChangeRequest that will be // sent to the LLM. It prepends module context information — as dictated -// by the specification — before the raw file contents. When metadata is -// nil or when any contextual information is missing the function falls -// back gracefully, emitting only what is available. -func buildExtendedUserMessage(rootFS fs.FS, meta *project.Metadata, ec *context.ExecutionContext, filePaths []string) (string, error) { - // If metadata is missing we revert to the original behaviour – emit - // just the files. - if meta == nil || meta.Modules == nil { - return payload.BuildUserMessage(rootFS, filePaths) +// by the specification — before the raw file contents. Both meta and +// meta.Modules must be non-nil. +func buildWorkspaceChangeRequest(rootFS fs.FS, meta *project.Metadata, ec *context.ExecutionContext, filePaths []string) (*payload.WorkspaceChangeRequest, error) { + if meta == nil { + return nil, fmt.Errorf("metadata cannot be nil") } + if meta.Modules == nil { + return nil, fmt.Errorf("metadata.Modules cannot be nil") + } + + request := &payload.WorkspaceChangeRequest{} - // Helper to clean/normalise relative paths. + // Helper to clean/normalise relative paths rel := func(abs string) string { if abs == "" { return "" @@ -35,42 +37,43 @@ func buildExtendedUserMessage(rootFS fs.FS, meta *project.Metadata, ec *context. workingRel := rel(ec.WorkingDir) targetRel := rel(ec.TargetDir) + request.TargetDirectory = targetRel + + // Find modules (metadata is guaranteed to be valid) workingMod := project.FindModule(meta.Modules, workingRel) targetMod := project.FindModule(meta.Modules, targetRel) if workingMod == nil || targetMod == nil { - return "", fmt.Errorf("failed to locate working and target modules") + return nil, fmt.Errorf("failed to locate working and target modules") } - var sb strings.Builder + // Set target module information + request.TargetModule = targetMod.Name - // ------------------------------------------------------------ - // 1. External context of working module. - // ------------------------------------------------------------ - if ann := workingMod.Annotation; ann != nil && ann.ExternalContext != "" { - sb.WriteString(fmt.Sprintf("# Module: `%s`\n", workingMod.Name)) - sb.WriteString("## External Context\n") - sb.WriteString(ann.ExternalContext + "\n") + // Set target module context (combined internal and external context) + var targetContext strings.Builder + if ann := targetMod.Annotation; ann != nil { + if ann.ExternalContext != "" { + targetContext.WriteString("External Context: ") + targetContext.WriteString(ann.ExternalContext) + targetContext.WriteString("\n\n") + } + if ann.InternalContext != "" { + targetContext.WriteString("Internal Context: ") + targetContext.WriteString(ann.InternalContext) + } } - // ------------------------------------------------------------ - // 2. Internal context of modules between working and target. - // ------------------------------------------------------------ - for m := targetMod.Parent; m != nil && m != workingMod; m = m.Parent { - if ann := m.Annotation; ann != nil && ann.InternalContext != "" { - sb.WriteString(fmt.Sprintf("# Module: `%s`\n", m.Name)) - sb.WriteString("## Internal Context\n") - sb.WriteString(ann.InternalContext + "\n") - } + // Ensure TargetModuleContext is never empty + if targetContext.Len() == 0 { + targetContext.WriteString("No specific context available for this module.") } + request.TargetModuleContext = targetContext.String() - // ------------------------------------------------------------ - // 3. Public context of sibling modules along the path from the - // parent of the target module up to (and including) the working - // module. This replaces the previous logic that only considered - // direct children of the working module. - // ------------------------------------------------------------ + var parentModuleContexts []payload.ModuleContext + var subModuleContexts []payload.ModuleContext + // Collect parent and sibling module contexts isAncestor := func(a, b string) bool { return a == b || (a != "." && strings.HasPrefix(b, a+"/")) } @@ -82,9 +85,10 @@ func buildExtendedUserMessage(rootFS fs.FS, meta *project.Metadata, ec *context. continue } if ann := child.Annotation; ann != nil && ann.PublicContext != "" { - sb.WriteString(fmt.Sprintf("# Module: `%s`\n", child.Name)) - sb.WriteString("## Public Context\n") - sb.WriteString(ann.PublicContext + "\n") + parentModuleContexts = append(parentModuleContexts, payload.ModuleContext{ + Name: child.Name, + Content: ann.PublicContext, + }) } } if ancestor == workingMod { @@ -92,26 +96,32 @@ func buildExtendedUserMessage(rootFS fs.FS, meta *project.Metadata, ec *context. } } - // ------------------------------------------------------------ - // 4. Public context of immediate sub-modules of target module. - // ------------------------------------------------------------ + // Collect immediate sub-modules of target module for _, child := range targetMod.Modules { if ann := child.Annotation; ann != nil && ann.PublicContext != "" { - sb.WriteString(fmt.Sprintf("# Module: `%s`\n", child.Name)) - sb.WriteString("## Public Context\n") - sb.WriteString(ann.PublicContext + "\n") + subModuleContexts = append(subModuleContexts, payload.ModuleContext{ + Name: child.Name, + Content: ann.PublicContext, + }) } } - // ------------------------------------------------------------ - // 5. Append file contents (only files from target module were - // selected by selector.Select). - // ------------------------------------------------------------ - filesMsg, err := payload.BuildUserMessage(rootFS, filePaths) - if err != nil { - return "", err + request.ParentModuleContexts = parentModuleContexts + request.SubModuleContexts = subModuleContexts + + // Append file contents + var files []payload.FileContent + for _, path := range filePaths { + content, err := fs.ReadFile(rootFS, path) + if err != nil { + return nil, fmt.Errorf("failed to read file %s: %w", path, err) + } + files = append(files, payload.FileContent{ + Path: path, + Content: string(content), + }) } - sb.WriteString(filesMsg) + request.Files = files - return sb.String(), nil + return request, nil } diff --git a/cmd/template/user_msg_builder_test.go b/cmd/template/user_msg_builder_test.go index 0817a44..1f4a12d 100644 --- a/cmd/template/user_msg_builder_test.go +++ b/cmd/template/user_msg_builder_test.go @@ -2,10 +2,11 @@ package template import ( "fmt" - "strings" + "reflect" "testing" "testing/fstest" + "github.com/vybdev/vyb/llm/payload" "github.com/vybdev/vyb/workspace/context" "github.com/vybdev/vyb/workspace/project" ) @@ -18,7 +19,7 @@ func Test_buildExtendedUserMessage(t *testing.T) { InternalContext: fmt.Sprintf("%s internal", s), } } - // Build minimal module tree: root -> work (w) -> tgt (w/child) + // Build minimal module tree: root -> work (w) -> mid -> tgt (w/child) root := &project.Module{Name: "."} work := &project.Module{Name: "w", Parent: root, Annotation: ann("W")} mid := &project.Module{Name: "w/mid", Parent: work, Annotation: ann("Mid")} @@ -49,31 +50,64 @@ func Test_buildExtendedUserMessage(t *testing.T) { TargetDir: "w/mid/child", } - msg, err := buildExtendedUserMessage(mfs, meta, ec, []string{"w/mid/child/file.txt"}) + req, err := buildWorkspaceChangeRequest(mfs, meta, ec, []string{"w/mid/child/file.txt"}) if err != nil { t.Fatalf("unexpected error: %v", err) } // Basic assertions – ensure expected contexts are present. - mustContain := []string{"W external", "Mid internal", "Sibling public", "Cousin public", "hello"} - for _, s := range mustContain { - if !strings.Contains(msg, s) { - t.Fatalf("expected message to contain %q", s) - } + expectedFiles := []payload.FileContent{ + {Path: "w/mid/child/file.txt", Content: "hello"}, + } + if !reflect.DeepEqual(req.Files, expectedFiles) { + t.Errorf("Files mismatch: got %+v, want %+v", req.Files, expectedFiles) } - mustNotContain := []string{ - "W public", "W internal", - "Mid public", "Mid external", - "Sibling internal", "Sibling external", - "Cousin internal", "Cousin external", - "Out public", "Out internal", "Out external", - "mid content", "sibling content", "w content", "cousin content", "out content", + // Verify target module information + if req.TargetModule != "w/mid/child" { + t.Errorf("TargetModule mismatch: got %q, want %q", req.TargetModule, "w/mid/child") } - for _, s := range mustNotContain { - if strings.Contains(msg, s) { - t.Fatalf("should not include contexts for target module itself, got message:\n%s", msg) - } + if req.TargetDirectory != "w/mid/child" { + t.Errorf("TargetDirectory mismatch: got %q, want %q", req.TargetDirectory, "w/mid/child") + } + + expectedParentContexts := []payload.ModuleContext{ + {Name: "w/mid/sibling", Content: "Sibling public"}, + {Name: "w/cousin", Content: "Cousin public"}, + } + + if !reflect.DeepEqual(req.ParentModuleContexts, expectedParentContexts) { + t.Errorf("ParentModuleContexts mismatch:\ngot: %+v\nwant: %+v", req.ParentModuleContexts, expectedParentContexts) + } + + // Should be empty since target module has no sub-modules + if len(req.SubModuleContexts) != 0 { + t.Errorf("SubModuleContexts should be empty, got: %+v", req.SubModuleContexts) + } +} + +func Test_buildExtendedUserMessage_nilValidation(t *testing.T) { + mfs := fstest.MapFS{ + "file.txt": &fstest.MapFile{Data: []byte("content")}, + } + + ec := &context.ExecutionContext{ + ProjectRoot: ".", + WorkingDir: ".", + TargetDir: ".", + } + + // Test nil metadata + _, err := buildWorkspaceChangeRequest(mfs, nil, ec, []string{"file.txt"}) + if err == nil || err.Error() != "metadata cannot be nil" { + t.Errorf("Expected 'metadata cannot be nil' error, got: %v", err) + } + + // Test nil modules + meta := &project.Metadata{Modules: nil} + _, err = buildWorkspaceChangeRequest(mfs, meta, ec, []string{"file.txt"}) + if err == nil || err.Error() != "metadata.Modules cannot be nil" { + t.Errorf("Expected 'metadata.Modules cannot be nil' error, got: %v", err) } } diff --git a/llm/README.md b/llm/README.md index 203ff1e..935b538 100644 --- a/llm/README.md +++ b/llm/README.md @@ -40,13 +40,11 @@ debugging. ### `llm/payload` -Pure data & helper utilities: +Pure data structures for LLM communication: -* `BuildUserMessage` – turns a list of files into a Markdown payload. -* `BuildModuleContextUserMessage` – embeds annotations into the payload - according to precise inclusion rules. -* Go structs mirroring every JSON schema (WorkspaceChangeProposal, - ModuleSelfContainedContext, …). +* Go structs for request payloads (WorkspaceChangeRequest, ModuleContextRequest, ExternalContextsRequest) +* Go structs for response payloads (WorkspaceChangeProposal, ModuleSelfContainedContext, ModuleExternalContextResponse) +* All structs support JSON marshalling/unmarshalling for LLM interactions ## JSON Schema enforcement diff --git a/llm/dispatcher.go b/llm/dispatcher.go index f99d99a..ae2882e 100644 --- a/llm/dispatcher.go +++ b/llm/dispatcher.go @@ -1,13 +1,13 @@ package llm import ( - "fmt" - "strings" + "fmt" + "strings" - "github.com/vybdev/vyb/config" - "github.com/vybdev/vyb/llm/internal/gemini" - "github.com/vybdev/vyb/llm/internal/openai" - "github.com/vybdev/vyb/llm/payload" + "github.com/vybdev/vyb/config" + "github.com/vybdev/vyb/llm/internal/gemini" + "github.com/vybdev/vyb/llm/internal/openai" + "github.com/vybdev/vyb/llm/payload" ) // provider captures the common operations expected from any LLM backend. @@ -18,88 +18,97 @@ import ( // Additional methods should be appended here whenever new high-level // helpers are added to the llm façade. type provider interface { - GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage, userMessage string) (*payload.WorkspaceChangeProposal, error) - GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfContainedContext, error) - GetModuleExternalContexts(systemMessage, userMessage string) (*payload.ModuleExternalContextResponse, error) + GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) + GetModuleContext(systemMessage string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) + GetModuleExternalContexts(systemMessage string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) } type openAIProvider struct{} type geminiProvider struct{} -func (*openAIProvider) GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sysMsg, userMsg string) (*payload.WorkspaceChangeProposal, error) { - return openai.GetWorkspaceChangeProposals(fam, sz, sysMsg, userMsg) +type unknownProvider struct{} + +func (*openAIProvider) GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sysMsg string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + return openai.GetWorkspaceChangeProposals(fam, sz, sysMsg, request) } -func (*openAIProvider) GetModuleContext(sysMsg, userMsg string) (*payload.ModuleSelfContainedContext, error) { - return openai.GetModuleContext(sysMsg, userMsg) +func (*openAIProvider) GetModuleContext(sysMsg string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + return openai.GetModuleContext(sysMsg, request) } -func (*openAIProvider) GetModuleExternalContexts(sysMsg, userMsg string) (*payload.ModuleExternalContextResponse, error) { - return openai.GetModuleExternalContexts(sysMsg, userMsg) +func (*openAIProvider) GetModuleExternalContexts(sysMsg string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + return openai.GetModuleExternalContexts(sysMsg, request) } // ----------------------------------------------------------------------------- -// Gemini provider implementation – WorkspaceChangeProposals hooked up +// Gemini provider implementation // ----------------------------------------------------------------------------- func mapGeminiModel(fam config.ModelFamily, sz config.ModelSize) (string, error) { - switch sz { - case config.ModelSizeSmall: - return "gemini-2.5-flash-preview-05-20", nil - case config.ModelSizeLarge: - return "gemini-2.5-pro-preview-06-05", nil - default: - return "", fmt.Errorf("gemini: unsupported model size %s", sz) - } + switch sz { + case config.ModelSizeSmall: + return "gemini-2.5-flash-preview-05-20", nil + case config.ModelSizeLarge: + return "gemini-2.5-pro-preview-06-05", nil + default: + return "", fmt.Errorf("gemini: unsupported model size %s", sz) + } } -func (*geminiProvider) GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sysMsg, userMsg string) (*payload.WorkspaceChangeProposal, error) { - return gemini.GetWorkspaceChangeProposals(fam, sz, sysMsg, userMsg) +func (*geminiProvider) GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sysMsg string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + return gemini.GetWorkspaceChangeProposals(fam, sz, sysMsg, request) } -func (*geminiProvider) GetModuleContext(sysMsg, userMsg string) (*payload.ModuleSelfContainedContext, error) { - return gemini.GetModuleContext(sysMsg, userMsg) +func (*geminiProvider) GetModuleContext(sysMsg string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + return gemini.GetModuleContext(sysMsg, request) } -func (*geminiProvider) GetModuleExternalContexts(sysMsg, userMsg string) (*payload.ModuleExternalContextResponse, error) { - return gemini.GetModuleExternalContexts(sysMsg, userMsg) +func (*geminiProvider) GetModuleExternalContexts(sysMsg string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + return gemini.GetModuleExternalContexts(sysMsg, request) +} + +// ----------------------------------------------------------------------------- +// Unknown Provider is a throwing stub +// ----------------------------------------------------------------------------- + +func (*unknownProvider) GetWorkspaceChangeProposals(_ config.ModelFamily, _ config.ModelSize, _ string, _ *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + return nil, fmt.Errorf("unknown provider") +} + +func (*unknownProvider) GetModuleContext(_ string, _ *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + return nil, fmt.Errorf("unknown provider") +} + +func (*unknownProvider) GetModuleExternalContexts(_ string, _ *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + return nil, fmt.Errorf("unknown provider") } // ----------------------------------------------------------------------------- // Public façade helpers remain unchanged (dispatcher section). // ----------------------------------------------------------------------------- -func GetModuleExternalContexts(cfg *config.Config, sysMsg, userMsg string) (*payload.ModuleExternalContextResponse, error) { - if provider, err := resolveProvider(cfg); err != nil { - return nil, err - } else { - return provider.GetModuleExternalContexts(sysMsg, userMsg) - } +func GetModuleExternalContexts(cfg *config.Config, sysMsg string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + return resolveProvider(cfg).GetModuleExternalContexts(sysMsg, request) } -func GetModuleContext(cfg *config.Config, sysMsg, userMsg string) (*payload.ModuleSelfContainedContext, error) { - if provider, err := resolveProvider(cfg); err != nil { - return nil, err - } else { - return provider.GetModuleContext(sysMsg, userMsg) - } +func GetModuleContext(cfg *config.Config, sysMsg string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + return resolveProvider(cfg).GetModuleContext(sysMsg, request) + } -func GetWorkspaceChangeProposals(cfg *config.Config, fam config.ModelFamily, sz config.ModelSize, sysMsg, userMsg string) (*payload.WorkspaceChangeProposal, error) { - if provider, err := resolveProvider(cfg); err != nil { - return nil, err - } else { - return provider.GetWorkspaceChangeProposals(fam, sz, sysMsg, userMsg) - } +func GetWorkspaceChangeProposals(cfg *config.Config, fam config.ModelFamily, sz config.ModelSize, sysMsg string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + return resolveProvider(cfg).GetWorkspaceChangeProposals(fam, sz, sysMsg, request) } -func resolveProvider(cfg *config.Config) (provider, error) { - switch strings.ToLower(cfg.Provider) { - case "openai": - return &openAIProvider{}, nil - case "gemini": - return &geminiProvider{}, nil - default: - return nil, fmt.Errorf("unknown provider: %s", cfg.Provider) - } +// resolveProvider resolves the value of cfg.Provider to one of the known providers. +// Returns a throwing stub if it can't map the value to any known provider. +func resolveProvider(cfg *config.Config) provider { + switch strings.ToLower(cfg.Provider) { + case "openai": + return &openAIProvider{} + case "gemini": + return &geminiProvider{} + default: + return &unknownProvider{} + } } diff --git a/llm/dispatcher_test.go b/llm/dispatcher_test.go index 92027c9..7077ae0 100644 --- a/llm/dispatcher_test.go +++ b/llm/dispatcher_test.go @@ -6,6 +6,11 @@ import ( "github.com/vybdev/vyb/config" ) +// The following checks ensure that the provider implementations adhere to the +// provider interface. +var _ provider = (*openAIProvider)(nil) +var _ provider = (*geminiProvider)(nil) + // TestMapGeminiModel ensures that the (family,size) tuple is translated to // the correct concrete model identifier and that unsupported sizes are // properly rejected. diff --git a/llm/internal/gemini/gemini.go b/llm/internal/gemini/gemini.go index 5431623..a520c20 100644 --- a/llm/internal/gemini/gemini.go +++ b/llm/internal/gemini/gemini.go @@ -6,11 +6,12 @@ import ( "errors" "fmt" "github.com/vybdev/vyb/config" - gemschema "github.com/vybdev/vyb/llm/internal/gemini/internal/schema" + "github.com/vybdev/vyb/llm/internal/gemini/internal/schema" "github.com/vybdev/vyb/llm/payload" "io" "net/http" "os" + "strings" ) // mapModel converts the (family,size) tuple into the concrete Gemini @@ -34,7 +35,11 @@ func mapModel(fam config.ModelFamily, sz config.ModelSize) (string, error) { // // The function mirrors the public surface exposed by the OpenAI provider so // callers can remain provider-agnostic. -func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage, userMessage string) (*payload.WorkspaceChangeProposal, error) { +func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + userMessage, err := serializeWorkspaceChangeRequest(request) + if err != nil { + return nil, fmt.Errorf("gemini: failed to serialize workspace change request: %w", err) + } model, err := mapModel(fam, sz) if err != nil { return nil, err @@ -44,9 +49,7 @@ func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sy return nil, errors.New("GEMINI_API_KEY is not set") } - schema := gemschema.GetWorkspaceChangeProposalSchema() - - resp, err := callGemini(systemMessage, userMessage, schema, model) + resp, err := callGemini([]string{systemMessage, userMessage}, schema.GetWorkspaceChangeProposalSchema(), model) if err != nil { return nil, err } @@ -64,15 +67,17 @@ func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sy return &proposal, nil } -func GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfContainedContext, error) { +func GetModuleContext(systemMessage string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + userMessage, err := serializeModuleContextRequest(request) + if err != nil { + return nil, fmt.Errorf("gemini: failed to serialize module context request: %w", err) + } model, err := mapModel(config.ModelFamilyReasoning, config.ModelSizeSmall) if err != nil { return nil, err } - schema := gemschema.GetModuleContextSchema() - - resp, err := callGemini(systemMessage, userMessage, schema, model) + resp, err := callGemini([]string{systemMessage, userMessage}, schema.GetModuleContextSchema(), model) if err != nil { return nil, err } @@ -90,15 +95,17 @@ func GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfCon return &ctx, nil } -func GetModuleExternalContexts(systemMessage, userMessage string) (*payload.ModuleExternalContextResponse, error) { +func GetModuleExternalContexts(systemMessage string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + userMessage, err := serializeExternalContextsRequest(request) + if err != nil { + return nil, fmt.Errorf("gemini: failed to serialize external contexts request: %w", err) + } model, err := mapModel(config.ModelFamilyReasoning, config.ModelSizeSmall) if err != nil { return nil, err } - schema := gemschema.GetModuleExternalContextSchema() - - resp, err := callGemini(systemMessage, userMessage, schema, model) + resp, err := callGemini([]string{systemMessage, userMessage}, schema.GetModuleExternalContextSchema(), model) if err != nil { return nil, err } @@ -116,6 +123,192 @@ func GetModuleExternalContexts(systemMessage, userMessage string) (*payload.Modu return &ext, nil } +// ----------------------------------------------------------------------------- +// +// Request Serializers +// +// ----------------------------------------------------------------------------- + +func serializeWorkspaceChangeRequest(request *payload.WorkspaceChangeRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("WorkspaceChangeRequest must not be nil") + } + if request.TargetModule == "" { + return "", fmt.Errorf("TargetModule is required") + } + if request.TargetDirectory == "" { + return "", fmt.Errorf("TargetDirectory is required") + } + + var sb strings.Builder + + // Write target module information (these are now required) + sb.WriteString(fmt.Sprintf("# Target Module: `%s`\n", request.TargetModule)) + sb.WriteString("## Target Module Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", request.TargetModuleContext)) + sb.WriteString(fmt.Sprintf("## Target Directory: `%s`\n\n", request.TargetDirectory)) + + // Write parent module contexts + if len(request.ParentModuleContexts) > 0 { + sb.WriteString("# Parent Module Contexts\n") + for _, mc := range request.ParentModuleContexts { + ctx := &payload.ModuleSelfContainedContext{ + Name: mc.Name, + PublicContext: mc.Content, + } + writeModule(&sb, mc.Name, ctx) + } + sb.WriteString("\n") + } + + // Write sub-module contexts + if len(request.SubModuleContexts) > 0 { + sb.WriteString("# Sub-Module Contexts\n") + for _, mc := range request.SubModuleContexts { + ctx := &payload.ModuleSelfContainedContext{ + Name: mc.Name, + PublicContext: mc.Content, + } + writeModule(&sb, mc.Name, ctx) + } + sb.WriteString("\n") + } + + // Write files + if len(request.Files) > 0 { + sb.WriteString("# Files\n") + for _, f := range request.Files { + writeFile(&sb, f.Path, f.Content) + } + } + + return sb.String(), nil +} + +func serializeModuleContextRequest(request *payload.ModuleContextRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("ModuleContextRequest must not be nil") + } + + var sb strings.Builder + rootPrefix := request.TargetModuleName + + // Only spend these tokens if we need to teach the LLM that a directory != module. + if len(request.TargetModuleDirectories) > 1 { + sb.WriteString(fmt.Sprintf("## Directories in module `%s`\n", rootPrefix)) + sb.WriteString(fmt.Sprintf("The following is a list of directories that are part of the module `%s`\n.", rootPrefix)) + sb.WriteString(fmt.Sprintf("These ARE NOT MODULES, they are directories within the module. When summarizing their file contents, include them in the summary of `%s`, do not make up modules for them.\n", rootPrefix)) + for _, dir := range request.TargetModuleDirectories { + sb.WriteString(fmt.Sprintf("- %s\n", dir)) + } + } + + sb.WriteString(fmt.Sprintf("## Files in module `%s`\n", rootPrefix)) + // Emit root-module files. + for _, file := range request.TargetModuleFiles { + writeFile(&sb, file.Path, file.Content) + } + + // Emit public context of immediate sub-modules. + for _, sub := range request.SubModulesPublicContexts { + // We only expose the public context of immediate sub-modules. + if sub.Content == "" && sub.Name == "" { + continue + } + + trimmedCtx := &payload.ModuleSelfContainedContext{ + Name: sub.Name, + PublicContext: sub.Content, + } + writeModule(&sb, trimmedCtx.Name, trimmedCtx) + } + + return sb.String(), nil +} + +func serializeExternalContextsRequest(request *payload.ExternalContextsRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("ExternalContextsRequest must not be nil") + } + + var sb strings.Builder + + // Write each module with H1 headers + for _, module := range request.Modules { + if module.Name == "" { + continue + } + sb.WriteString(fmt.Sprintf("# Module: `%s`\n", module.Name)) + if module.ParentName != "" { + sb.WriteString(fmt.Sprintf("Parent Module: `%s`\n\n", module.ParentName)) + } + if module.InternalContext != "" { + sb.WriteString("## Internal Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", module.InternalContext)) + } + if module.PublicContext != "" { + sb.WriteString("## Public Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", module.PublicContext)) + } + } + + return sb.String(), nil +} + +func writeModule(sb *strings.Builder, path string, context *payload.ModuleSelfContainedContext) { + if sb == nil { + return + } + if path == "" && (context == nil || (context.ExternalContext == "" && context.InternalContext == "" && context.PublicContext == "")) { + return + } + sb.WriteString(fmt.Sprintf("# Module: `%s`\n", path)) + if context != nil { + if context.ExternalContext != "" { + sb.WriteString("## External Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.ExternalContext)) + } + if context.InternalContext != "" { + sb.WriteString("## Internal Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.InternalContext)) + } + if context.PublicContext != "" { + sb.WriteString("## Public Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.PublicContext)) + } + } +} + +func writeFile(sb *strings.Builder, filepath, content string) { + if sb == nil { + return + } + lang := getLanguageFromFilename(filepath) + sb.WriteString(fmt.Sprintf("### %s\n", filepath)) + sb.WriteString(fmt.Sprintf("```%s\n", lang)) + sb.WriteString(content) + // Ensure a trailing newline before closing the code block. + if !strings.HasSuffix(content, "\n") { + sb.WriteString("\n") + } + sb.WriteString("```\n\n") +} + +// getLanguageFromFilename returns a language identifier based on file extension. +func getLanguageFromFilename(filename string) string { + if strings.HasSuffix(filename, ".go") { + return "go" + } else if strings.HasSuffix(filename, ".md") { + return "markdown" + } else if strings.HasSuffix(filename, ".json") { + return "json" + } else if strings.HasSuffix(filename, ".txt") { + return "text" + } + // Default: no language specified. + return "" +} + // ----------------------------------------------------------------------------- // Provider-specific data structures & helpers (non-exported) // ----------------------------------------------------------------------------- @@ -175,16 +368,28 @@ func (e geminiErrorResponse) Error() string { return fmt.Sprintf("Gemini API error (%d %s): %s", e.Err.Code, e.Err.Status, e.Err.Message) } -func buildRequest(systemMessage, userMessage string, schema interface{}) ([]byte, error) { - if userMessage == "" { - return nil, errors.New("gemini: user message must not be empty") +func buildRequest(messages []string, schema interface{}) ([]byte, error) { + if len(messages) == 0 { + return nil, errors.New("gemini: messages cannot be empty") + } + + // Create a part for each message + var parts []part + for _, msg := range messages { + if msg != "" { + parts = append(parts, part{Text: msg}) + } + } + + if len(parts) == 0 { + return nil, errors.New("gemini: all messages are empty") } r := requestPayload{ Contents: []content{ { Role: "user", - Parts: []part{{Text: systemMessage + "\n\n" + userMessage}}, + Parts: parts, }, }, GenerationConfig: generationConfig{ @@ -196,7 +401,7 @@ func buildRequest(systemMessage, userMessage string, schema interface{}) ([]byte return json.Marshal(r) } -func callGemini(systemMessage, userMessage string, schema interface{}, model string) (*geminiResponse, error) { +func callGemini(messages []string, schema interface{}, model string) (*geminiResponse, error) { apiKey := os.Getenv("GEMINI_API_KEY") if apiKey == "" { return nil, errors.New("GEMINI_API_KEY is not set") @@ -207,7 +412,7 @@ func callGemini(systemMessage, userMessage string, schema interface{}, model str } // Build request body. - bodyBytes, err := buildRequest(systemMessage, userMessage, schema) + bodyBytes, err := buildRequest(messages, schema) if err != nil { return nil, err } diff --git a/llm/internal/gemini/gemini_test.go b/llm/internal/gemini/gemini_test.go index f8a7ed2..3b74404 100644 --- a/llm/internal/gemini/gemini_test.go +++ b/llm/internal/gemini/gemini_test.go @@ -8,9 +8,54 @@ import ( "reflect" "testing" + "github.com/vybdev/vyb/config" "github.com/vybdev/vyb/llm/payload" ) +func TestGetWorkspaceChangeProposals(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]any{ + "candidates": []any{ + map[string]any{ + "content": map[string]any{ + "parts": []any{ + map[string]any{ + "text": `{"summary":"s","description":"d","proposals":[]}`, + }, + }, + }, + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer srv.Close() + + oldBase := baseEndpoint + baseEndpoint = srv.URL + defer func() { baseEndpoint = oldBase }() + + os.Setenv("GEMINI_API_KEY", "x") + defer os.Unsetenv("GEMINI_API_KEY") + + req := &payload.WorkspaceChangeRequest{ + TargetModule: "test-module", + TargetModuleContext: "Test module context", + TargetDirectory: "src/", + Files: []payload.FileContent{ + {Path: "test.go", Content: "package main"}, + }, + } + got, err := GetWorkspaceChangeProposals(config.ModelFamilyGPT, config.ModelSizeSmall, "sys", req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := &payload.WorkspaceChangeProposal{Summary: "s", Description: "d", Proposals: []payload.FileChangeProposal{}} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected proposal: got %+v, want %+v", got, want) + } +} + func TestGetModuleContext(t *testing.T) { // Dummy server returning minimal module context JSON. srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -38,7 +83,11 @@ func TestGetModuleContext(t *testing.T) { os.Setenv("GEMINI_API_KEY", "x") defer os.Unsetenv("GEMINI_API_KEY") - got, err := GetModuleContext("sys", "usr") + req := &payload.ModuleContextRequest{ + TargetModuleName: "test-module", + } + + got, err := GetModuleContext("sys", req) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -74,7 +123,13 @@ func TestGetModuleExternalContexts(t *testing.T) { os.Setenv("GEMINI_API_KEY", "x") defer os.Unsetenv("GEMINI_API_KEY") - got, err := GetModuleExternalContexts("sys", "usr") + req := &payload.ExternalContextsRequest{ + Modules: []payload.ModuleInfoForExternalContext{ + {Name: "foo"}, + }, + } + + got, err := GetModuleExternalContexts("sys", req) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/llm/internal/openai/openai.go b/llm/internal/openai/openai.go index a9d5219..4b32f35 100644 --- a/llm/internal/openai/openai.go +++ b/llm/internal/openai/openai.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "os" + "strings" "github.com/vybdev/vyb/llm/payload" "time" @@ -83,7 +84,11 @@ func mapModel(fam config.ModelFamily, sz config.ModelSize) (string, error) { // GetModuleContext calls the LLM and returns a parsed ModuleSelfContainedContext // value using the model derived from family/size. -func GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfContainedContext, error) { +func GetModuleContext(systemMessage string, request *payload.ModuleContextRequest) (*payload.ModuleSelfContainedContext, error) { + userMessage, err := serializeModuleContextRequest(request) + if err != nil { + return nil, fmt.Errorf("openai: failed to serialize module context request: %w", err) + } model := "o4-mini" openaiResp, err := callOpenAI(systemMessage, userMessage, schema.GetModuleContextSchema(), model) if err != nil { @@ -92,7 +97,7 @@ func GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfCon if openAIErrResp.OpenAIError.Code == "rate_limit_exceeded" { fmt.Printf("Rate limit exceeded, retrying after 30s\n") <-time.After(30 * time.Second) - return GetModuleContext(systemMessage, userMessage) + return GetModuleContext(systemMessage, request) } } return nil, err @@ -106,7 +111,11 @@ func GetModuleContext(systemMessage, userMessage string) (*payload.ModuleSelfCon // GetWorkspaceChangeProposals sends the given messages to the OpenAI API and // returns the structured workspace change proposal. -func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage, userMessage string) (*payload.WorkspaceChangeProposal, error) { +func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, systemMessage string, request *payload.WorkspaceChangeRequest) (*payload.WorkspaceChangeProposal, error) { + userMessage, err := serializeWorkspaceChangeRequest(request) + if err != nil { + return nil, fmt.Errorf("openai: failed to serialize workspace change request: %w", err) + } model, err := mapModel(fam, sz) if err != nil { return nil, err @@ -124,6 +133,9 @@ func GetWorkspaceChangeProposals(fam config.ModelFamily, sz config.ModelSize, sy return &proposal, nil } +// NOTE: baseEndpoint is a var (not const) to allow test overrides. +var baseEndpoint = "https://api.openai.com/v1/chat/completions" + // callOpenAI sends a request to OpenAI, returns the parsed response, and logs // the request/response pair to a uniquely-named JSON file in the OS temp dir. func callOpenAI(systemMessage, userMessage string, structuredOutput schema.StructuredOutputSchema, model string) (*openaiResponse, error) { @@ -156,7 +168,7 @@ func callOpenAI(systemMessage, userMessage string, structuredOutput schema.Struc return nil, err } - req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(reqBytes)) + req, err := http.NewRequest("POST", baseEndpoint, bytes.NewBuffer(reqBytes)) if err != nil { return nil, err } @@ -232,7 +244,11 @@ func callOpenAI(systemMessage, userMessage string, structuredOutput schema.Struc // GetModuleExternalContexts calls the LLM and returns a list of external // context strings – one per module. -func GetModuleExternalContexts(systemMessage, userMessage string) (*payload.ModuleExternalContextResponse, error) { +func GetModuleExternalContexts(systemMessage string, request *payload.ExternalContextsRequest) (*payload.ModuleExternalContextResponse, error) { + userMessage, err := serializeExternalContextsRequest(request) + if err != nil { + return nil, fmt.Errorf("openai: failed to serialize external contexts request: %w", err) + } model := "o4-mini" openaiResp, err := callOpenAI(systemMessage, userMessage, schema.GetModuleExternalContextSchema(), model) if err != nil { @@ -245,3 +261,189 @@ func GetModuleExternalContexts(systemMessage, userMessage string) (*payload.Modu } return &resp, nil } + +// ----------------------------------------------------------------------------- +// +// Request Serializers +// +// ----------------------------------------------------------------------------- + +func serializeWorkspaceChangeRequest(request *payload.WorkspaceChangeRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("WorkspaceChangeRequest must not be nil") + } + if request.TargetModule == "" { + return "", fmt.Errorf("TargetModule is required") + } + if request.TargetDirectory == "" { + return "", fmt.Errorf("TargetDirectory is required") + } + + var sb strings.Builder + + // Write target module information (these are now required) + sb.WriteString(fmt.Sprintf("# Target Module: `%s`\n", request.TargetModule)) + sb.WriteString("## Target Module Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", request.TargetModuleContext)) + sb.WriteString(fmt.Sprintf("## Target Directory: `%s`\n\n", request.TargetDirectory)) + + // Write parent module contexts + if len(request.ParentModuleContexts) > 0 { + sb.WriteString("# Parent Module Contexts\n") + for _, mc := range request.ParentModuleContexts { + ctx := &payload.ModuleSelfContainedContext{ + Name: mc.Name, + PublicContext: mc.Content, + } + writeModule(&sb, mc.Name, ctx) + } + sb.WriteString("\n") + } + + // Write sub-module contexts + if len(request.SubModuleContexts) > 0 { + sb.WriteString("# Sub-Module Contexts\n") + for _, mc := range request.SubModuleContexts { + ctx := &payload.ModuleSelfContainedContext{ + Name: mc.Name, + PublicContext: mc.Content, + } + writeModule(&sb, mc.Name, ctx) + } + sb.WriteString("\n") + } + + // Write files + if len(request.Files) > 0 { + sb.WriteString("# Files\n") + for _, f := range request.Files { + writeFile(&sb, f.Path, f.Content) + } + } + + return sb.String(), nil +} + +func serializeModuleContextRequest(request *payload.ModuleContextRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("ModuleContextRequest must not be nil") + } + + var sb strings.Builder + rootPrefix := request.TargetModuleName + + // Only spend these tokens if we need to teach the LLM that a directory != module. + if len(request.TargetModuleDirectories) > 1 { + sb.WriteString(fmt.Sprintf("## Directories in module `%s`\n", rootPrefix)) + sb.WriteString(fmt.Sprintf("The following is a list of directories that are part of the module `%s`\n.", rootPrefix)) + sb.WriteString(fmt.Sprintf("These ARE NOT MODULES, they are directories within the module. When summarizing their file contents, include them in the summary of `%s`, do not make up modules for them.\n", rootPrefix)) + for _, dir := range request.TargetModuleDirectories { + sb.WriteString(fmt.Sprintf("- %s\n", dir)) + } + } + + sb.WriteString(fmt.Sprintf("## Files in module `%s`\n", rootPrefix)) + // Emit root-module files. + for _, file := range request.TargetModuleFiles { + writeFile(&sb, file.Path, file.Content) + } + + // Emit public context of immediate sub-modules. + for _, sub := range request.SubModulesPublicContexts { + // We only expose the public context of immediate sub-modules. + if sub.Content == "" && sub.Name == "" { + continue + } + + trimmedCtx := &payload.ModuleSelfContainedContext{ + Name: sub.Name, + PublicContext: sub.Content, + } + writeModule(&sb, trimmedCtx.Name, trimmedCtx) + } + + return sb.String(), nil +} + +func serializeExternalContextsRequest(request *payload.ExternalContextsRequest) (string, error) { + if request == nil { + return "", fmt.Errorf("ExternalContextsRequest must not be nil") + } + + var sb strings.Builder + + // Write each module with H1 headers + for _, module := range request.Modules { + if module.Name == "" { + continue + } + sb.WriteString(fmt.Sprintf("# Module: `%s`\n", module.Name)) + if module.ParentName != "" { + sb.WriteString(fmt.Sprintf("Parent Module: `%s`\n\n", module.ParentName)) + } + if module.InternalContext != "" { + sb.WriteString("## Internal Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", module.InternalContext)) + } + if module.PublicContext != "" { + sb.WriteString("## Public Context\n") + sb.WriteString(fmt.Sprintf("%s\n\n", module.PublicContext)) + } + } + + return sb.String(), nil +} + +func writeModule(sb *strings.Builder, path string, context *payload.ModuleSelfContainedContext) { + if sb == nil { + return + } + if path == "" && (context == nil || (context.ExternalContext == "" && context.InternalContext == "" && context.PublicContext == "")) { + return + } + sb.WriteString(fmt.Sprintf("# Module: `%s`\n", path)) + if context != nil { + if context.ExternalContext != "" { + sb.WriteString("## External Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.ExternalContext)) + } + if context.InternalContext != "" { + sb.WriteString("## Internal Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.InternalContext)) + } + if context.PublicContext != "" { + sb.WriteString("## Public Context\n") + sb.WriteString(fmt.Sprintf("%s\n", context.PublicContext)) + } + } +} + +func writeFile(sb *strings.Builder, filepath, content string) { + if sb == nil { + return + } + lang := getLanguageFromFilename(filepath) + sb.WriteString(fmt.Sprintf("### %s\n", filepath)) + sb.WriteString(fmt.Sprintf("```%s\n", lang)) + sb.WriteString(content) + // Ensure a trailing newline before closing the code block. + if !strings.HasSuffix(content, "\n") { + sb.WriteString("\n") + } + sb.WriteString("```\n\n") +} + +// getLanguageFromFilename returns a language identifier based on file extension. +func getLanguageFromFilename(filename string) string { + if strings.HasSuffix(filename, ".go") { + return "go" + } else if strings.HasSuffix(filename, ".md") { + return "markdown" + } else if strings.HasSuffix(filename, ".json") { + return "json" + } else if strings.HasSuffix(filename, ".txt") { + return "text" + } + // Default: no language specified. + return "" +} diff --git a/llm/payload/payload.go b/llm/payload/payload.go index db15bb7..8d29ea3 100644 --- a/llm/payload/payload.go +++ b/llm/payload/payload.go @@ -1,39 +1,79 @@ +// Package payload contains data structures for LLM requests and responses. package payload -import ( - "fmt" - "io/fs" - "path/filepath" - "strings" -) - -// fileEntry represents a file with its relative path and content. -type fileEntry struct { - Path string - Content string +// --- Request Payloads --- + +// FileContent holds the path and content of a file. +type FileContent struct { + Path string `json:"path"` + Content string `json:"content"` +} + +// WorkspaceChangeRequest contains all the necessary context and files for +// proposing workspace changes. +type WorkspaceChangeRequest struct { + + // WorkingModule represents the topmost module whose context is included in the workspace request + WorkingModule string `json:"working_module"` + // WorkingModuleContext is the context of the working module + WorkingModuleContext string `json:"working_module_context"` + + // TargetModule represents the module that contains the target directory + TargetModule string `json:"target_module"` + // TargetModuleContext contains the context of the target module + TargetModuleContext string `json:"target_module_context"` + // TargetDirectory is the root directory from which the change request + // should be applied (no change is expected outside of this directory or its subdirectories) + TargetDirectory string `json:"target_directory"` + + // ParentModuleContexts contains the context of the parent and sibling modules + // of the TargetModule contained within the working module, if any + ParentModuleContexts []ModuleContext `json:"parent_module_contexts"` + + // SubModuleContexts contains the context of all the direct submodules of the TargetModule, if any. + SubModuleContexts []ModuleContext `json:"submodule_contexts"` + + // Files contains the content of files relevant to the task. + Files []FileContent `json:"files"` +} + +// ModuleContext represents a piece of named context from a module. +type ModuleContext struct { + Name string `json:"name"` + Content string `json:"content"` +} + +// ModuleContextRequest provides the necessary information to generate +// the internal and public contexts for a single module. +type ModuleContextRequest struct { + // TargetModuleName is the name of the module being processed. + TargetModuleName string `json:"target_module_name"` + + // TargetModuleFiles are the files within the module to be summarized. + TargetModuleFiles []FileContent `json:"target_module_files"` + + // TargetModuleDirectories are the directories within the module. + TargetModuleDirectories []string `json:"target_module_directories"` + + // SubModulesPublicContexts are the public contexts of immediate sub-modules. + SubModulesPublicContexts []ModuleContext `json:"sub_modules_public_contexts"` } -// BuildUserMessage constructs a Markdown-formatted string that includes the content of all files in scope. -// projectRoot represents the base directory for this project, and all file paths in the given filePaths parameter are relative to projectRoot. -func BuildUserMessage(projectRoot fs.FS, filePaths []string) (string, error) { - var files []fileEntry - for _, path := range filePaths { - data, err := fs.ReadFile(projectRoot, path) - if err != nil { - return "", err - } - files = append(files, fileEntry{ - Path: path, - Content: string(data), - }) - } - markdown := buildPayload(files) - return markdown, nil +// ExternalContextsRequest contains information about a module hierarchy +// needed to generate external contexts for each module. +type ExternalContextsRequest struct { + Modules []ModuleInfoForExternalContext `json:"modules"` } -// --------------------- -// Data abstractions -// --------------------- +// ModuleInfoForExternalContext holds the data for a single module. +type ModuleInfoForExternalContext struct { + Name string `json:"name"` + ParentName string `json:"parent_name,omitempty"` + InternalContext string `json:"internal_context,omitempty"` + PublicContext string `json:"public_context,omitempty"` +} + +// --- Response Payloads --- // WorkspaceChangeProposal is a concrete description of proposed workspace // changes coming from the LLM. @@ -63,164 +103,6 @@ type ModuleExternalContext struct { Name string `json:"name,omitempty"` ExternalContext string `json:"external_context,omitempty"` } -type ModuleSelfContainedContextRequest struct { - FilePaths []string - Directories []string - ModuleCtx *ModuleSelfContainedContext - SubModules []*ModuleSelfContainedContextRequest -} - -// BuildModuleContextUserMessage constructs a Markdown-formatted string that -// includes the content of all files referenced by the provided -// ModuleSelfContainedContextRequest *root* and the public context of its immediate -// sub-modules. -// -// Behaviour rules: -// 1. The files listed in the root request are included verbatim. -// 2. For each *immediate* sub-module of the root request, only its -// PublicContext (if any) is emitted – no files are rendered for those -// sub-modules and no information is emitted for modules that are -// grandchildren or deeper. -// -// If any referenced file cannot be read this function returns an error. -func BuildModuleContextUserMessage(projectRoot fs.FS, request *ModuleSelfContainedContextRequest) (string, error) { - if projectRoot == nil { - return "", fmt.Errorf("projectRoot fs.FS must not be nil") - } - if request == nil { - return "", fmt.Errorf("ModuleSelfContainedContextRequest must not be nil") - } - - var sb strings.Builder - - // Helper that resolves the absolute workspace path for a file declared in - // the *root* module. - resolvePath := func(rootPrefix, rel string) string { - if rootPrefix == "" || rootPrefix == "." || strings.HasPrefix(rel, rootPrefix+string(filepath.Separator)) { - return rel - } - return filepath.Join(rootPrefix, rel) - } - - // ----------------------------- - // 1. Emit root-module information. - // ----------------------------- - rootPrefix := "" - if request.ModuleCtx != nil && request.ModuleCtx.Name != "" { - rootPrefix = request.ModuleCtx.Name - } - - // Header for the root module if we have a name or any context data. - if rootPrefix != "" || (request.ModuleCtx != nil && (request.ModuleCtx.ExternalContext != "" || request.ModuleCtx.InternalContext != "" || request.ModuleCtx.PublicContext != "")) { - writeModule(&sb, rootPrefix, request.ModuleCtx) - } - // Only spend these tokens if we need to teach the LLM that a directory != module. - if len(request.Directories) > 1 { - sb.WriteString(fmt.Sprintf("## Directories in module `%s`\n", rootPrefix)) - sb.WriteString(fmt.Sprintf("The following is a list of directories that are part of the module `%s`\n.", rootPrefix)) - sb.WriteString(fmt.Sprintf("These ARE NOT MODULES, they are directories within the module. When summarizing their file contents, include them in the summary of `%s`, do not make up modules for them.\n", rootPrefix)) - for _, dir := range request.Directories { - sb.WriteString(fmt.Sprintf("- %s\n", dir)) - } - } - - sb.WriteString(fmt.Sprintf("## Files in module `%s`\n", rootPrefix)) - // Emit root-module files. - for _, relFile := range request.FilePaths { - fullPath := resolvePath(rootPrefix, relFile) - data, err := fs.ReadFile(projectRoot, fullPath) - if err != nil { - return "", fmt.Errorf("failed to read file %s: %w", fullPath, err) - } - writeFile(&sb, fullPath, string(data)) - } - - // ----------------------------- - // 2. Emit public context of immediate sub-modules. - // ----------------------------- - for _, sub := range request.SubModules { - if sub == nil || sub.ModuleCtx == nil { - continue // nothing useful to emit - } - - // We only expose the public context of immediate sub-modules. - if sub.ModuleCtx.PublicContext == "" && sub.ModuleCtx.Name == "" { - continue - } - - trimmedCtx := &ModuleSelfContainedContext{ - Name: sub.ModuleCtx.Name, - PublicContext: sub.ModuleCtx.PublicContext, - } - writeModule(&sb, trimmedCtx.Name, trimmedCtx) - } - - return sb.String(), nil -} - -// buildPayload constructs a Markdown payload from a slice of fileEntry. -// Each file is represented with an H1 header for its relative path, followed by a code block. -func buildPayload(files []fileEntry) string { - var sb strings.Builder - for _, f := range files { - writeFile(&sb, f.Path, f.Content) - } - return sb.String() -} - -func writeModule(sb *strings.Builder, path string, context *ModuleSelfContainedContext) { - if sb == nil { - return - } - if path == "" && (context == nil || (context.ExternalContext == "" && context.InternalContext == "" && context.PublicContext == "")) { - return - } - sb.WriteString(fmt.Sprintf("# Module: `%s`\n", path)) - if context != nil { - if context.ExternalContext != "" { - sb.WriteString("## External Context\n") - sb.WriteString(fmt.Sprintf("%s\n", context.ExternalContext)) - } - if context.InternalContext != "" { - sb.WriteString("## Internal Context\n") - sb.WriteString(fmt.Sprintf("%s\n", context.InternalContext)) - } - if context.PublicContext != "" { - sb.WriteString("## Public Context\n") - sb.WriteString(fmt.Sprintf("%s\n", context.PublicContext)) - } - } -} - -func writeFile(sb *strings.Builder, filepath, content string) { - if sb == nil { - return - } - lang := getLanguageFromFilename(filepath) - sb.WriteString(fmt.Sprintf("### %s\n", filepath)) - sb.WriteString(fmt.Sprintf("```%s\n", lang)) - sb.WriteString(content) - // Ensure a trailing newline before closing the code block. - if !strings.HasSuffix(content, "\n") { - sb.WriteString("\n") - } - sb.WriteString("```\n\n") -} - -// getLanguageFromFilename returns a language identifier based on file extension. -func getLanguageFromFilename(filename string) string { - if strings.HasSuffix(filename, ".go") { - return "go" - } else if strings.HasSuffix(filename, ".md") { - return "markdown" - } else if strings.HasSuffix(filename, ".json") { - return "json" - } else if strings.HasSuffix(filename, ".txt") { - return "text" - } - // Default: no language specified. - return "" -} // ModuleExternalContextResponse captures the LLM response when generating // external contexts for a set of modules. diff --git a/llm/payload/payload_test.go b/llm/payload/payload_test.go index 8eccb4f..697bcef 100644 --- a/llm/payload/payload_test.go +++ b/llm/payload/payload_test.go @@ -1,134 +1,80 @@ package payload import ( + "encoding/json" + "reflect" "testing" - "testing/fstest" ) -func context(name string) *ModuleSelfContainedContext { return &ModuleSelfContainedContext{Name: name} } -func pcontext(name string) *ModuleSelfContainedContext { return context(name) } - -func TestBuildModuleContextUserMessage(t *testing.T) { - // Files arranged in a nested module hierarchy: - // - root.txt (root module / no module name) - // - moduleA/a.go - // - moduleA/subB/b.md - mfs := fstest.MapFS{ - "root.txt": &fstest.MapFile{Data: []byte("root")}, - "moduleA/a.go": &fstest.MapFile{Data: []byte("package foo\n")}, - "moduleA/subB/b.md": &fstest.MapFile{Data: []byte("Markdown content")}, - } - - // Construct the ModuleSelfContainedContextRequest tree that mirrors the hierarchy. - req := &ModuleSelfContainedContextRequest{ - ModuleCtx: &ModuleSelfContainedContext{Name: "."}, - FilePaths: []string{"root.txt"}, - SubModules: []*ModuleSelfContainedContextRequest{ - { - ModuleCtx: &ModuleSelfContainedContext{Name: "moduleA", PublicContext: "moduleA public"}, - FilePaths: []string{"moduleA/a.go"}, - SubModules: []*ModuleSelfContainedContextRequest{ - { - ModuleCtx: &ModuleSelfContainedContext{Name: "moduleA/subB", PublicContext: "subB public"}, - FilePaths: []string{"moduleA/subB/b.md"}, - }, +func TestRequestPayloads_JSONMarshalling(t *testing.T) { + testcases := []struct { + name string + payload interface{} + newInst func() interface{} + }{ + { + name: "WorkspaceChangeRequest", + payload: &WorkspaceChangeRequest{ + TargetModule: "my-module", + TargetModuleContext: "context info", + TargetDirectory: "src/", + ParentModuleContexts: []ModuleContext{ + {Name: "parent1", Content: "parent context"}, + }, + SubModuleContexts: []ModuleContext{ + {Name: "sub1", Content: "sub context"}, + }, + Files: []FileContent{ + {Path: "file1.go", Content: "package main"}, }, }, + newInst: func() interface{} { return &WorkspaceChangeRequest{} }, }, - } - - got, err := BuildModuleContextUserMessage(mfs, req) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - expected := "# Module: `.`\n## Files in module `.`\n" + - "### root.txt\n```text\nroot\n```\n\n" + - "# Module: `moduleA`\n## Public Context\nmoduleA public\n" - - if got != expected { - t.Errorf("payload mismatch.\nGot:\n%s\nExpected:\n%s", got, expected) - } -} - -func TestBuildModuleContextUserMessage_FileNotFound(t *testing.T) { - // Empty filesystem – any file access should fail. - mfs := fstest.MapFS{} - - req := &ModuleSelfContainedContextRequest{ - FilePaths: []string{"does_not_exist.txt"}, - } - - if _, err := BuildModuleContextUserMessage(mfs, req); err == nil { - t.Fatalf("expected error for missing file, got nil") - } -} - -// New test validating selective context inclusion semantics. -func TestBuildModuleContextUserMessage_Selectivity(t *testing.T) { - /* - Module hierarchy used in this test: - A (root) - ├── B - └── C - └── D - */ - mfs := fstest.MapFS{ - "A/a.go": &fstest.MapFile{Data: []byte("package a\n")}, - "A/C/c.go": &fstest.MapFile{Data: []byte("package c\n")}, - } - - // Full tree rooted at A. - treeA := &ModuleSelfContainedContextRequest{ - ModuleCtx: &ModuleSelfContainedContext{Name: "A"}, - FilePaths: []string{"A/a.go"}, - SubModules: []*ModuleSelfContainedContextRequest{ - { - ModuleCtx: &ModuleSelfContainedContext{Name: "A/B", PublicContext: "This is B"}, + { + name: "ModuleContextRequest", + payload: &ModuleContextRequest{ + TargetModuleName: "my-module", + TargetModuleFiles: []FileContent{ + {Path: "file1.go", Content: "package main"}, + }, + TargetModuleDirectories: []string{"dir1"}, + SubModulesPublicContexts: []ModuleContext{ + {Name: "sub1", Content: "pub_ctx"}, + }, }, - { - ModuleCtx: &ModuleSelfContainedContext{Name: "A/C", PublicContext: "This is C", InternalContext: "This is C's internal context."}, - FilePaths: []string{"A/C/c.go"}, - SubModules: []*ModuleSelfContainedContextRequest{{ - ModuleCtx: &ModuleSelfContainedContext{Name: "A/C/D", PublicContext: "This is D. It won't be included."}, - }}, + newInst: func() interface{} { return &ModuleContextRequest{} }, + }, + { + name: "ExternalContextsRequest", + payload: &ExternalContextsRequest{ + Modules: []ModuleInfoForExternalContext{ + { + Name: "mod1", + ParentName: "", + InternalContext: "int_ctx", + PublicContext: "pub_ctx", + }, + }, }, + newInst: func() interface{} { return &ExternalContextsRequest{} }, }, } - gotA, err := BuildModuleContextUserMessage(mfs, treeA) - if err != nil { - t.Fatalf("unexpected error building payload for A: %v", err) - } - - expectedA := "# Module: `A`\n## Files in module `A`\n" + - "### A/a.go\n```go\npackage a\n```\n\n" + - "# Module: `A/B`\n## Public Context\nThis is B\n" + - "# Module: `A/C`\n## Public Context\nThis is C\n" - - if gotA != expectedA { - t.Errorf("payload for A mismatch.\nGot:\n%s\nExpected:\n%s", gotA, expectedA) - } - - // Sub-tree rooted at C. - treeC := &ModuleSelfContainedContextRequest{ - ModuleCtx: &ModuleSelfContainedContext{Name: "A/C"}, - FilePaths: []string{"A/C/c.go"}, - SubModules: []*ModuleSelfContainedContextRequest{{ - ModuleCtx: &ModuleSelfContainedContext{Name: "A/C/D", PublicContext: "This is D"}, - }}, - } - - gotC, err := BuildModuleContextUserMessage(mfs, treeC) - if err != nil { - t.Fatalf("unexpected error building payload for C: %v", err) - } - - expectedC := "# Module: `A/C`\n## Files in module `A/C`\n" + - "### A/C/c.go\n```go\npackage c\n```\n\n" + - "# Module: `A/C/D`\n## Public Context\nThis is D\n" - - if gotC != expectedC { - t.Errorf("payload for C mismatch.\nGot:\n%s\nExpected:\n%s", gotC, expectedC) + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + data, err := json.Marshal(tc.payload) + if err != nil { + t.Fatalf("json.Marshal() failed: %v", err) + } + + unmarshaled := tc.newInst() + if err := json.Unmarshal(data, unmarshaled); err != nil { + t.Fatalf("json.Unmarshal() failed: %v", err) + } + + if !reflect.DeepEqual(tc.payload, unmarshaled) { + t.Errorf("round-trip mismatch.\nGot: %#v\nWant: %#v", unmarshaled, tc.payload) + } + }) } -} +} \ No newline at end of file diff --git a/workspace/project/annotation.go b/workspace/project/annotation.go index 3ec86da..0dae880 100644 --- a/workspace/project/annotation.go +++ b/workspace/project/annotation.go @@ -102,51 +102,42 @@ func collectModulesInPostOrder(root *Module) []*Module { return result } -// buildModuleContextRequest converts a *Module hierarchy to a *payload.ModuleSelfContainedContextRequest tree. -func buildModuleContextRequest(m *Module) *payload.ModuleSelfContainedContextRequest { - if m == nil { - return nil - } - - // Collect file paths relative to this module (just the file names). - var paths []string - for _, f := range m.Files { - paths = append(paths, f.Name) +// addOrUpdateSelfContainedContext calls the LLM to construct the internal and public context of a given module. +func addOrUpdateSelfContainedContext(cfg *config.Config, m *Module, sysfs fs.FS) error { + // Build the ModuleContextRequest for this module. + var targetFiles []payload.FileContent + for _, fileRef := range m.Files { + content, err := fs.ReadFile(sysfs, fileRef.Name) + if err != nil { + return fmt.Errorf("failed to read file %s: %w", fileRef.Name, err) + } + targetFiles = append(targetFiles, payload.FileContent{ + Path: fileRef.Name, + Content: string(content), + }) } - // Recursively process sub-modules. - var subs []*payload.ModuleSelfContainedContextRequest - for _, sm := range m.Modules { - subs = append(subs, buildModuleContextRequest(sm)) + var subContexts []payload.ModuleContext + for _, subMod := range m.Modules { + var publicContext string + if subMod.Annotation != nil && subMod.Annotation.PublicContext != "" { + publicContext = subMod.Annotation.PublicContext + } + subContexts = append(subContexts, payload.ModuleContext{ + Name: subMod.Name, + Content: publicContext, + }) } - // For the root module (name == ".") we omit the ModuleSelfContainedContext so we don’t get a "# ." header. - var ctxPtr *payload.ModuleSelfContainedContext - //if m.Name != "." { - // ctxPtr = &payload.ModuleSelfContainedContext{Name: m.Name} - //} - - return &payload.ModuleSelfContainedContextRequest{ - FilePaths: paths, - Directories: m.Directories, - ModuleCtx: ctxPtr, - SubModules: subs, + req := &payload.ModuleContextRequest{ + TargetModuleName: m.Name, + TargetModuleFiles: targetFiles, + TargetModuleDirectories: m.Directories, + SubModulesPublicContexts: subContexts, } -} - -// addOrUpdateSelfContainedContext calls OpenAI to construct the internal and public context of a given module. -func addOrUpdateSelfContainedContext(cfg *config.Config, m *Module, sysfs fs.FS) error { - // Build the ModuleSelfContainedContextRequest tree starting from this module. - req := buildModuleContextRequest(m) fmt.Printf("annotating module %q\n", m.Name) - // Construct user message including the files for this module. - userMsg, err := payload.BuildModuleContextUserMessage(sysfs, req) - if err != nil { - return fmt.Errorf("failed to build user message: %w", err) - } - // System prompt instructing the LLM to summarize code into JSON schema. systemMessage := `You are a prompt engineer, structuring information about an application's code base so context can be provided to an LLM in the most efficient way. @@ -172,12 +163,12 @@ you included in the Internal Context, but also all the Public Context informatio Each type of context should be as descriptive as possible, using around one thousand LLM tokens, each.` - context, err := llm.GetModuleContext(cfg, systemMessage, userMsg) + context, err := llm.GetModuleContext(cfg, systemMessage, req) fmt.Printf(" Got response for module %q\n", m.Name) if err != nil { - return fmt.Errorf("failed to call openAI: %w", err) + return fmt.Errorf("failed to call llm provider: %w", err) } if m.Annotation == nil { @@ -247,27 +238,32 @@ func addOrUpdateExternalContext(cfg *config.Config, m *Module) error { } // ------------------------------------------------------------ - // 2. Build user-message containing internal & public context that the + // 2. Build request containing internal & public context that the // LLM will use to infer external context. // ------------------------------------------------------------ - var sb strings.Builder + var modulesForRequest []payload.ModuleInfoForExternalContext for _, mod := range modules { - sb.WriteString(fmt.Sprintf("## Module: %s\n", mod.Name)) + var parentName string if mod.Parent != nil { - sb.WriteString(fmt.Sprintf("### Parent: %s\n", mod.Parent.Name)) + parentName = mod.Parent.Name } + + var internalCtx, publicCtx string if mod.Annotation != nil { - if mod.Annotation.InternalContext != "" { - sb.WriteString("### Internal Context\n") - sb.WriteString(mod.Annotation.InternalContext + "\n") - } - if mod.Annotation.PublicContext != "" { - sb.WriteString("### Public Context\n") - sb.WriteString(mod.Annotation.PublicContext + "\n") - } + internalCtx = mod.Annotation.InternalContext + publicCtx = mod.Annotation.PublicContext } + + modulesForRequest = append(modulesForRequest, payload.ModuleInfoForExternalContext{ + Name: mod.Name, + ParentName: parentName, + InternalContext: internalCtx, + PublicContext: publicCtx, + }) + } + request := &payload.ExternalContextsRequest{ + Modules: modulesForRequest, } - userMsg := sb.String() // ------------------------------------------------------------ // 3. Call LLM. @@ -286,7 +282,7 @@ concise explanation of where the module lives in the hierarchy and what lives Return your answer as JSON following the schema you have been provided.` - resp, err := llm.GetModuleExternalContexts(cfg, sysPrompt, userMsg) + resp, err := llm.GetModuleExternalContexts(cfg, sysPrompt, request) if err != nil { return err }