diff --git a/examples/rag/custom_provider.yaml b/examples/rag/custom_provider.yaml new file mode 100644 index 000000000..78364d92e --- /dev/null +++ b/examples/rag/custom_provider.yaml @@ -0,0 +1,40 @@ +# This example demonstrates using a custom provider for RAG embedding models. +# For instance, you can use a local Ollama instance or any OpenAI-compatible +# API endpoint for generating embeddings. + +providers: + local-ollama: + base_url: http://localhost:11434/v1 + +models: + local-embed: + provider: local-ollama + model: nomic-embed-text + +agents: + root: + model: openai/gpt-5-mini + description: assistant with RAG using custom embedding provider + instruction: | + You are a helpful assistant with access to a knowledge base. + Use the search tool to find relevant information before answering. + toolsets: + - type: rag + ref: knowledge_base + +rag: + knowledge_base: + tool: + description: search the knowledge base for relevant information + docs: + - ./docs + strategies: + - type: chunked-embeddings + embedding_model: local-embed # References the model defined above using the custom provider + database: ./custom_provider.db + vector_dimensions: 768 + chunking: + size: 1000 + overlap: 100 + results: + limit: 5 diff --git a/pkg/config/runtime.go b/pkg/config/runtime.go index cc1e8fb80..2c1dccfba 100644 --- a/pkg/config/runtime.go +++ b/pkg/config/runtime.go @@ -25,6 +25,7 @@ type Config struct { GlobalCodeMode bool WorkingDir string Models map[string]latest.ModelConfig + Providers map[string]latest.ProviderConfig // Hook overrides from CLI flags HookPreToolUse []string @@ -40,6 +41,7 @@ func (runConfig *RuntimeConfig) Clone() *RuntimeConfig { } clone.EnvFiles = slices.Clone(runConfig.EnvFiles) clone.Models = maps.Clone(runConfig.Models) + clone.Providers = maps.Clone(runConfig.Providers) clone.DefaultModel = runConfig.DefaultModel.Clone() clone.HookPreToolUse = slices.Clone(runConfig.HookPreToolUse) clone.HookPostToolUse = slices.Clone(runConfig.HookPostToolUse) diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index a243513ec..72ffb5e5d 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -5,12 +5,11 @@ import ( "errors" "fmt" "log/slog" - "maps" - "slices" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/rag/rerank" "github.com/docker/docker-agent/pkg/rag/strategy" "github.com/docker/docker-agent/pkg/rag/types" @@ -21,7 +20,16 @@ type ManagersBuildConfig struct { ParentDir string ModelsGateway string Env environment.Provider - Models map[string]latest.ModelConfig // Model configurations from config + Models map[string]latest.ModelConfig // Model configurations from config + Providers map[string]latest.ProviderConfig // Custom provider configurations from config +} + +// NewProvider creates a model provider using the build config's environment, +// gateway, and custom provider settings. +func (c ManagersBuildConfig) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) { + return provider.New(ctx, cfg, c.Env, + options.WithGateway(c.ModelsGateway), + options.WithProviders(c.Providers)) } // NewManager constructs a single RAG manager from a RAGConfig. @@ -46,6 +54,7 @@ func NewManager( ParentDir: buildCfg.ParentDir, SharedDocs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs), Models: buildCfg.Models, + Providers: buildCfg.Providers, Env: buildCfg.Env, ModelsGateway: buildCfg.ModelsGateway, RespectVCS: ragCfg.GetRespectVCS(), @@ -146,20 +155,21 @@ func buildRerankingConfig( "model_ref", rerankCfg.Model) // Resolve model config - check if it's a reference to a defined model or inline - modelCfg, err := resolveModelConfig(rerankCfg.Model, buildCfg) + modelCfgVal, err := strategy.ResolveModelConfig(rerankCfg.Model, buildCfg.Models) if err != nil { slog.Error("Failed to resolve reranking model", "model_ref", rerankCfg.Model, "error", err) return nil, fmt.Errorf("failed to resolve reranking model %q: %w", rerankCfg.Model, err) } + modelCfg := &modelCfgVal slog.Debug("Resolved reranking model config", "provider", modelCfg.Provider, "model", modelCfg.Model) // Create provider for reranking model - rerankProvider, err := provider.New(ctx, modelCfg, buildCfg.Env) + rerankProvider, err := buildCfg.NewProvider(ctx, modelCfg) if err != nil { slog.Error("Failed to create reranking provider", "provider", modelCfg.Provider, @@ -206,55 +216,6 @@ func buildRerankingConfig( }, nil } -// resolveModelConfig resolves a model name to a ModelConfig -// Handles both inline model references (e.g., "dmr/model-name") and defined model names -func resolveModelConfig(modelName string, buildCfg ManagersBuildConfig) (*latest.ModelConfig, error) { - // Check if it's an inline model reference (contains a '/') - if modelName != "" { - parts := splitModelRef(modelName) - if len(parts) == 2 { - // Inline model reference like "dmr/hf.co/model" or "openai/gpt-5" - slog.Debug("Using inline model reference", - "provider", parts[0], - "model", parts[1]) - return &latest.ModelConfig{ - Provider: parts[0], - Model: parts[1], - }, nil - } - } - - // Try to find model in defined models - if modelCfg, exists := buildCfg.Models[modelName]; exists { - slog.Debug("Using defined model from config", - "model_name", modelName, - "provider", modelCfg.Provider, - "model", modelCfg.Model) - return &modelCfg, nil - } - - slog.Error("Model not found in configuration", - "model_name", modelName, - "available_models", getModelNames(buildCfg.Models)) - return nil, fmt.Errorf("model %q not found in configuration", modelName) -} - -// getModelNames extracts model names from the models map for logging -func getModelNames(models map[string]latest.ModelConfig) []string { - return slices.Collect(maps.Keys(models)) -} - -// splitModelRef splits a model reference into provider and model parts -func splitModelRef(ref string) []string { - // Handle common patterns: "provider/model" - for i := range len(ref) { - if ref[i] == '/' { - return []string{ref[:i], ref[i+1:]} - } - } - return []string{ref} -} - // buildStrategyConfigs builds the strategy configs for the RAG. // Returns a slice of strategy configs and a channel for receiving strategy events. func buildStrategyConfigs( diff --git a/pkg/rag/strategy/embedding.go b/pkg/rag/strategy/embedding.go index cc8239995..0848bf0c5 100644 --- a/pkg/rag/strategy/embedding.go +++ b/pkg/rag/strategy/embedding.go @@ -9,7 +9,6 @@ import ( "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/model/provider" - "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/modelsdev" "github.com/docker/docker-agent/pkg/rag/embed" ) @@ -41,8 +40,7 @@ func CreateEmbeddingProvider(ctx context.Context, modelName string, buildCtx Bui return nil, fmt.Errorf("model '%s' not found: %w", modelName, err) } - embedModel, err = provider.New(ctx, &modelCfg, buildCtx.Env, - options.WithGateway(buildCtx.ModelsGateway)) + embedModel, err = buildCtx.NewProvider(ctx, &modelCfg) if err != nil { return nil, fmt.Errorf("failed to create embedding model: %w", err) } @@ -80,8 +78,7 @@ func createAutoEmbeddingModel(ctx context.Context, buildCtx BuildContext) (provi Model: autoModelCfg.Model, } - model, err := provider.New(ctx, &modelCfg, buildCtx.Env, - options.WithGateway(buildCtx.ModelsGateway)) + model, err := buildCtx.NewProvider(ctx, &modelCfg) if err != nil { lastErr = err continue diff --git a/pkg/rag/strategy/semantic_embeddings.go b/pkg/rag/strategy/semantic_embeddings.go index 615cf62d3..e42257cb3 100644 --- a/pkg/rag/strategy/semantic_embeddings.go +++ b/pkg/rag/strategy/semantic_embeddings.go @@ -16,7 +16,6 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/js" "github.com/docker/docker-agent/pkg/model/provider" - "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/rag/chunk" "github.com/docker/docker-agent/pkg/rag/types" "github.com/docker/docker-agent/pkg/tools" @@ -89,8 +88,7 @@ func NewSemanticEmbeddingsFromConfig(ctx context.Context, cfg latest.RAGStrategy return nil, fmt.Errorf("invalid chat_model %q: %w", chatModelName, err) } - chatProvider, err := provider.New(ctx, &chatModelCfg, buildCtx.Env, - options.WithGateway(buildCtx.ModelsGateway)) + chatProvider, err := buildCtx.NewProvider(ctx, &chatModelCfg) if err != nil { return nil, fmt.Errorf("failed to create chat model provider: %w", err) } diff --git a/pkg/rag/strategy/strategy.go b/pkg/rag/strategy/strategy.go index e7801e843..eb1b6e74d 100644 --- a/pkg/rag/strategy/strategy.go +++ b/pkg/rag/strategy/strategy.go @@ -6,20 +6,31 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider" + "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/rag/types" ) -// BuildContext contains everything needed to build a strategy +// BuildContext contains everything needed to build a strategy. type BuildContext struct { RAGName string ParentDir string SharedDocs []string Models map[string]latest.ModelConfig + Providers map[string]latest.ProviderConfig Env environment.Provider ModelsGateway string RespectVCS bool // Whether to respect VCS ignore files (e.g., .gitignore) when collecting files } +// NewProvider creates a model provider using the build context's environment, +// gateway, and custom provider settings. +func (c BuildContext) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) { + return provider.New(ctx, cfg, c.Env, + options.WithGateway(c.ModelsGateway), + options.WithProviders(c.Providers)) +} + // BuildStrategy builds a strategy from config // Explicitly dispatches to the appropriate constructor based on type func BuildStrategy(ctx context.Context, cfg latest.RAGStrategyConfig, buildCtx BuildContext, events chan<- types.Event) (*Config, error) { diff --git a/pkg/teamloader/registry.go b/pkg/teamloader/registry.go index df128b7a1..8f39572a7 100644 --- a/pkg/teamloader/registry.go +++ b/pkg/teamloader/registry.go @@ -369,6 +369,7 @@ func createRAGTool(ctx context.Context, toolset latest.Toolset, parentDir string ModelsGateway: runConfig.ModelsGateway, Env: runConfig.EnvProvider(), Models: runConfig.Models, + Providers: runConfig.Providers, }) if err != nil { return nil, fmt.Errorf("failed to create RAG manager: %w", err) diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 04b023724..028c43513 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -123,6 +123,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c // Make model definitions available to toolset creators (e.g., RAG reranking) runConfig.Models = cfg.Models + runConfig.Providers = cfg.Providers // Load agents parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir)