Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/rag/custom_provider.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This example demonstrates using a custom provider for RAG embedding models.
# For instance, you can use a local Ollama instance or any OpenAI-compatible
# API endpoint for generating embeddings.

providers:
local-ollama:
base_url: http://localhost:11434/v1

models:
local-embed:
provider: local-ollama
model: nomic-embed-text

agents:
root:
model: openai/gpt-5-mini
description: assistant with RAG using custom embedding provider
instruction: |
You are a helpful assistant with access to a knowledge base.
Use the search tool to find relevant information before answering.
toolsets:
- type: rag
ref: knowledge_base

rag:
knowledge_base:
tool:
description: search the knowledge base for relevant information
docs:
- ./docs
strategies:
- type: chunked-embeddings
embedding_model: local-embed # References the model defined above using the custom provider
database: ./custom_provider.db
vector_dimensions: 768
chunking:
size: 1000
overlap: 100
results:
limit: 5
2 changes: 2 additions & 0 deletions pkg/config/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Config struct {
GlobalCodeMode bool
WorkingDir string
Models map[string]latest.ModelConfig
Providers map[string]latest.ProviderConfig

// Hook overrides from CLI flags
HookPreToolUse []string
Expand All @@ -40,6 +41,7 @@ func (runConfig *RuntimeConfig) Clone() *RuntimeConfig {
}
clone.EnvFiles = slices.Clone(runConfig.EnvFiles)
clone.Models = maps.Clone(runConfig.Models)
clone.Providers = maps.Clone(runConfig.Providers)
clone.DefaultModel = runConfig.DefaultModel.Clone()
clone.HookPreToolUse = slices.Clone(runConfig.HookPreToolUse)
clone.HookPostToolUse = slices.Clone(runConfig.HookPostToolUse)
Expand Down
69 changes: 15 additions & 54 deletions pkg/rag/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ import (
"errors"
"fmt"
"log/slog"
"maps"
"slices"

"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/rag/rerank"
"github.com/docker/docker-agent/pkg/rag/strategy"
"github.com/docker/docker-agent/pkg/rag/types"
Expand All @@ -21,7 +20,16 @@ type ManagersBuildConfig struct {
ParentDir string
ModelsGateway string
Env environment.Provider
Models map[string]latest.ModelConfig // Model configurations from config
Models map[string]latest.ModelConfig // Model configurations from config
Providers map[string]latest.ProviderConfig // Custom provider configurations from config
}

// NewProvider builds a model provider for cfg by delegating to provider.New.
// The build config's Env is passed through, and its ModelsGateway and
// Providers are supplied as the gateway and custom-provider options.
// The provider and error returned by provider.New are returned unchanged.
func (c ManagersBuildConfig) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) {
	gatewayOpt := options.WithGateway(c.ModelsGateway)
	providersOpt := options.WithProviders(c.Providers)

	return provider.New(ctx, cfg, c.Env, gatewayOpt, providersOpt)
}

// NewManager constructs a single RAG manager from a RAGConfig.
Expand All @@ -46,6 +54,7 @@ func NewManager(
ParentDir: buildCfg.ParentDir,
SharedDocs: GetAbsolutePaths(buildCfg.ParentDir, ragCfg.Docs),
Models: buildCfg.Models,
Providers: buildCfg.Providers,
Env: buildCfg.Env,
ModelsGateway: buildCfg.ModelsGateway,
RespectVCS: ragCfg.GetRespectVCS(),
Expand Down Expand Up @@ -146,20 +155,21 @@ func buildRerankingConfig(
"model_ref", rerankCfg.Model)

// Resolve model config - check if it's a reference to a defined model or inline
modelCfg, err := resolveModelConfig(rerankCfg.Model, buildCfg)
modelCfgVal, err := strategy.ResolveModelConfig(rerankCfg.Model, buildCfg.Models)
if err != nil {
slog.Error("Failed to resolve reranking model",
"model_ref", rerankCfg.Model,
"error", err)
return nil, fmt.Errorf("failed to resolve reranking model %q: %w", rerankCfg.Model, err)
}
modelCfg := &modelCfgVal

slog.Debug("Resolved reranking model config",
"provider", modelCfg.Provider,
"model", modelCfg.Model)

// Create provider for reranking model
rerankProvider, err := provider.New(ctx, modelCfg, buildCfg.Env)
rerankProvider, err := buildCfg.NewProvider(ctx, modelCfg)
if err != nil {
slog.Error("Failed to create reranking provider",
"provider", modelCfg.Provider,
Expand Down Expand Up @@ -206,55 +216,6 @@ func buildRerankingConfig(
}, nil
}

// resolveModelConfig turns a model name into a ModelConfig.
//
// A non-empty name containing a '/' is treated as an inline reference of the
// form "provider/model" (split at the first '/') and yields a fresh config
// with only Provider and Model set. Any other name is looked up in
// buildCfg.Models; a missing entry is logged and reported as an error.
func resolveModelConfig(modelName string, buildCfg ManagersBuildConfig) (*latest.ModelConfig, error) {
	// Inline references take precedence over the defined-models lookup.
	if modelName != "" {
		if segments := splitModelRef(modelName); len(segments) == 2 {
			providerName, inlineModel := segments[0], segments[1]
			// e.g. "dmr/hf.co/model" or "openai/gpt-5"
			slog.Debug("Using inline model reference",
				"provider", providerName,
				"model", inlineModel)
			return &latest.ModelConfig{
				Provider: providerName,
				Model:    inlineModel,
			}, nil
		}
	}

	// Otherwise the name must match one of the models defined in the config.
	defined, found := buildCfg.Models[modelName]
	if !found {
		slog.Error("Model not found in configuration",
			"model_name", modelName,
			"available_models", getModelNames(buildCfg.Models))
		return nil, fmt.Errorf("model %q not found in configuration", modelName)
	}

	slog.Debug("Using defined model from config",
		"model_name", modelName,
		"provider", defined.Provider,
		"model", defined.Model)
	return &defined, nil
}

// getModelNames returns the keys of the models map in ascending order.
//
// The result is used for logging only. Go randomizes map iteration order, so
// the keys are sorted to keep the logged list of available models stable
// from one run to the next.
func getModelNames(models map[string]latest.ModelConfig) []string {
	return slices.Sorted(maps.Keys(models))
}

// splitModelRef breaks a model reference at its first '/'.
//
// It returns a two-element slice {provider, model} when a '/' is present;
// everything after the first '/' (including further slashes) stays in the
// model part. Without a '/', it returns a one-element slice holding ref.
func splitModelRef(ref string) []string {
	for pos := 0; pos < len(ref); pos++ {
		if ref[pos] != '/' {
			continue
		}
		return []string{ref[:pos], ref[pos+1:]}
	}
	return []string{ref}
}

// buildStrategyConfigs builds the strategy configs for the RAG.
// Returns a slice of strategy configs and a channel for receiving strategy events.
func buildStrategyConfigs(
Expand Down
7 changes: 2 additions & 5 deletions pkg/rag/strategy/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/rag/embed"
)
Expand Down Expand Up @@ -41,8 +40,7 @@ func CreateEmbeddingProvider(ctx context.Context, modelName string, buildCtx Bui
return nil, fmt.Errorf("model '%s' not found: %w", modelName, err)
}

embedModel, err = provider.New(ctx, &modelCfg, buildCtx.Env,
options.WithGateway(buildCtx.ModelsGateway))
embedModel, err = buildCtx.NewProvider(ctx, &modelCfg)
if err != nil {
return nil, fmt.Errorf("failed to create embedding model: %w", err)
}
Expand Down Expand Up @@ -80,8 +78,7 @@ func createAutoEmbeddingModel(ctx context.Context, buildCtx BuildContext) (provi
Model: autoModelCfg.Model,
}

model, err := provider.New(ctx, &modelCfg, buildCtx.Env,
options.WithGateway(buildCtx.ModelsGateway))
model, err := buildCtx.NewProvider(ctx, &modelCfg)
if err != nil {
lastErr = err
continue
Expand Down
4 changes: 1 addition & 3 deletions pkg/rag/strategy/semantic_embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/js"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/rag/chunk"
"github.com/docker/docker-agent/pkg/rag/types"
"github.com/docker/docker-agent/pkg/tools"
Expand Down Expand Up @@ -89,8 +88,7 @@ func NewSemanticEmbeddingsFromConfig(ctx context.Context, cfg latest.RAGStrategy
return nil, fmt.Errorf("invalid chat_model %q: %w", chatModelName, err)
}

chatProvider, err := provider.New(ctx, &chatModelCfg, buildCtx.Env,
options.WithGateway(buildCtx.ModelsGateway))
chatProvider, err := buildCtx.NewProvider(ctx, &chatModelCfg)
if err != nil {
return nil, fmt.Errorf("failed to create chat model provider: %w", err)
}
Expand Down
13 changes: 12 additions & 1 deletion pkg/rag/strategy/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,31 @@ import (

"github.com/docker/docker-agent/pkg/config/latest"
"github.com/docker/docker-agent/pkg/environment"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
"github.com/docker/docker-agent/pkg/rag/types"
)

// BuildContext contains everything needed to build a strategy
// BuildContext contains everything needed to build a strategy.
type BuildContext struct {
// RAGName presumably identifies the RAG being built — TODO confirm against the caller.
RAGName string
// ParentDir is the directory the RAG's doc paths are resolved against.
ParentDir string
// SharedDocs holds the RAG-level doc paths after they have been made absolute.
SharedDocs []string
// Models holds the model configurations defined in the config, keyed by name.
Models map[string]latest.ModelConfig
// Providers holds custom provider configurations; NewProvider forwards them via options.WithProviders.
Providers map[string]latest.ProviderConfig
// Env is the environment provider handed to provider.New by NewProvider.
Env environment.Provider
// ModelsGateway is forwarded to provider.New via options.WithGateway by NewProvider.
ModelsGateway string
RespectVCS bool // Whether to respect VCS ignore files (e.g., .gitignore) when collecting files
}

// NewProvider builds a model provider for cfg by delegating to provider.New.
// The build context's Env is passed through, and its ModelsGateway and
// Providers are supplied as the gateway and custom-provider options.
// The provider and error returned by provider.New are returned unchanged.
func (c BuildContext) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) {
	gatewayOpt := options.WithGateway(c.ModelsGateway)
	providersOpt := options.WithProviders(c.Providers)

	return provider.New(ctx, cfg, c.Env, gatewayOpt, providersOpt)
}

// BuildStrategy builds a strategy from config
// Explicitly dispatches to the appropriate constructor based on type
func BuildStrategy(ctx context.Context, cfg latest.RAGStrategyConfig, buildCtx BuildContext, events chan<- types.Event) (*Config, error) {
Expand Down
1 change: 1 addition & 0 deletions pkg/teamloader/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ func createRAGTool(ctx context.Context, toolset latest.Toolset, parentDir string
ModelsGateway: runConfig.ModelsGateway,
Env: runConfig.EnvProvider(),
Models: runConfig.Models,
Providers: runConfig.Providers,
})
if err != nil {
return nil, fmt.Errorf("failed to create RAG manager: %w", err)
Expand Down
1 change: 1 addition & 0 deletions pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c

// Make model and custom provider definitions available to toolset creators (e.g., RAG embeddings and reranking)
runConfig.Models = cfg.Models
runConfig.Providers = cfg.Providers

// Load agents
parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir)
Expand Down
Loading