From b602037eaa32fdd9ca5095b4e49ee25aa19eeeaf Mon Sep 17 00:00:00 2001 From: Christopher Petito Date: Wed, 8 Apr 2026 18:48:17 +0200 Subject: [PATCH] improve DMR support - adds 'context_size' provider_opt for DMR usage instead of giving 'max_tokens' double responsibility to avoid confusion - improves how flags are sent to the DMR model/runtime configuration endpoint Signed-off-by: Christopher Petito --- docs/providers/dmr/index.md | 44 ++- pkg/model/provider/dmr/client.go | 25 +- pkg/model/provider/dmr/client_test.go | 434 +++++++++++++++++++++----- pkg/model/provider/dmr/configure.go | 301 +++++++++++------- 4 files changed, 597 insertions(+), 207 deletions(-) diff --git a/docs/providers/dmr/index.md b/docs/providers/dmr/index.md index 71f6e46e4..f465de247 100644 --- a/docs/providers/dmr/index.md +++ b/docs/providers/dmr/index.md @@ -64,29 +64,49 @@ models: model: ai/qwen3 max_tokens: 8192 provider_opts: - runtime_flags: ["--ngl=33", "--top-p=0.9"] + runtime_flags: ["--threads", "8"] ``` Runtime flags also accept a single string: ```yaml provider_opts: - runtime_flags: "--ngl=33 --top-p=0.9" + runtime_flags: "--threads 8" ``` -## Parameter Mapping +Use only flags your Model Runner backend allows (see `docker model configure --help` and backend docs). **Do not** put sampling parameters (`temperature`, `top_p`, penalties) in `runtime_flags` — set them on the model (`temperature`, `top_p`, etc.); they are sent **per request** via the OpenAI-compatible chat API. -docker-agent model config fields map to llama.cpp flags automatically: +## Context size -| Config | llama.cpp Flag | -| ------------------- | --------------------- | -| `temperature` | `--temp` | -| `top_p` | `--top-p` | -| `frequency_penalty` | `--frequency-penalty` | -| `presence_penalty` | `--presence-penalty` | -| `max_tokens` | `--context-size` | +`max_tokens` controls the **maximum output tokens** per chat completion request. 
To set the engine's **total context window**, use `provider_opts.context_size`: -`runtime_flags` always take priority over derived flags on conflict. +```yaml +models: + local: + provider: dmr + model: ai/qwen3 + max_tokens: 4096 # max output tokens (per-request) + provider_opts: + context_size: 32768 # total context window (sent via _configure) +``` + +If `context_size` is omitted, Model Runner uses its default. `max_tokens` is **not** used as the context window. + +## Thinking / reasoning budget + +When using the **llama.cpp** backend, `thinking_budget` is sent as structured `llamacpp.reasoning-budget` on `_configure` (maps to `--reasoning-budget`). String efforts use the same token mapping as other providers; `adaptive` maps to unlimited (`-1`). + +When using the **vLLM** backend, `thinking_budget` is sent as `thinking_token_budget` in each chat completion request. Effort levels map to token counts using the same scale as other providers; `adaptive` maps to unlimited (`-1`). + +```yaml +models: + local: + provider: dmr + model: ai/qwen3 + thinking_budget: medium # llama.cpp: reasoning-budget=8192; vLLM: thinking_token_budget=8192 +``` + +On **MLX** and **SGLang** backends, `thinking_budget` is silently ignored — those engines do not currently expose a per-request reasoning token budget knob. 
## Speculative Decoding diff --git a/pkg/model/provider/dmr/client.go b/pkg/model/provider/dmr/client.go index 439e31918..cee62eea0 100644 --- a/pkg/model/provider/dmr/client.go +++ b/pkg/model/provider/dmr/client.go @@ -54,6 +54,7 @@ type Client struct { client openai.Client baseURL string httpClient *http.Client + engine string } // NewClient creates a new DMR client from the provided configuration @@ -103,18 +104,13 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt clientOptions = append(clientOptions, option.WithBaseURL(baseURL), option.WithAPIKey("")) // DMR doesn't need auth - // Build runtime flags from ModelConfig and engine - contextSize, providerRuntimeFlags, specOpts := parseDMRProviderOpts(cfg) - configFlags := buildRuntimeFlagsFromModelConfig(engine, cfg) - finalFlags, warnings := mergeRuntimeFlagsPreferUser(configFlags, providerRuntimeFlags) - for _, w := range warnings { - slog.Warn(w) - } - slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", contextSize, "runtime_flags", finalFlags, "speculative_opts", specOpts, "engine", engine) + contextSize, runtimeFlags, specOpts, llamaCpp := parseDMRProviderOpts(engine, cfg) + backendCfg := buildConfigureBackendConfig(contextSize, runtimeFlags, specOpts, llamaCpp) + slog.Debug("DMR provider_opts parsed", "model", cfg.Model, "context_size", derefInt64(contextSize), "runtime_flags", runtimeFlags, "speculative_opts", specOpts, "llamacpp", llamaCpp, "engine", engine) // Skip model configuration when generating titles to avoid reconfiguring the model // with different settings (e.g., smaller max_tokens) that would affect the main agent. 
if !globalOptions.GeneratingTitle() { - if err := configureModel(ctx, httpClient, baseURL, cfg.Model, contextSize, finalFlags, specOpts); err != nil { + if err := configureModel(ctx, httpClient, baseURL, cfg.Model, backendCfg); err != nil { slog.Debug("model configure via API skipped or failed", "error", err) } } @@ -129,6 +125,7 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, opts ...options.Opt client: openai.NewClient(clientOptions...), baseURL: baseURL, httpClient: httpClient, + engine: engine, }, nil } @@ -214,6 +211,14 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat } } + // For vLLM, apply engine-specific per-request fields (e.g. thinking_token_budget). + if c.engine == engineVLLM { + if fields := buildVLLMRequestFields(&c.ModelConfig); fields != nil { + params.SetExtraFields(fields) + slog.Debug("DMR vLLM extra request fields applied", "fields", fields) + } + } + // Log the request in JSON format for debugging if requestJSON, err := json.Marshal(params); err == nil { slog.Debug("DMR chat completion request", "request", string(requestJSON)) @@ -222,7 +227,7 @@ func (c *Client) CreateChatCompletionStream(ctx context.Context, messages []chat } if structuredOutput := c.ModelOptions.StructuredOutput(); structuredOutput != nil { - slog.Debug("Adding structured output to DMR request", "structured_output", structuredOutput) + slog.Debug("Adding structured output to DMR request", "name", structuredOutput.Name, "strict", structuredOutput.Strict) params.ResponseFormat.OfJSONSchema = &openai.ResponseFormatJSONSchemaParam{ JSONSchema: openai.ResponseFormatJSONSchemaJSONSchemaParam{ diff --git a/pkg/model/provider/dmr/client_test.go b/pkg/model/provider/dmr/client_test.go index cfe9de28c..fc55d6465 100644 --- a/pkg/model/provider/dmr/client_test.go +++ b/pkg/model/provider/dmr/client_test.go @@ -201,30 +201,40 @@ func TestBuildConfigureRequest(t *testing.T) { acceptanceRate: 0.8, } contextSize := int64(8192) + 
backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, specOpts, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", &contextSize, []string{"--temp", "0.7", "--top-p", "0.9"}, specOpts) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg) assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) require.NotNil(t, req.ContextSize) assert.Equal(t, int32(8192), *req.ContextSize) - assert.Equal(t, []string{"--temp", "0.7", "--top-p", "0.9"}, req.RuntimeFlags) + assert.Equal(t, []string{"--threads", "8"}, req.RuntimeFlags) require.NotNil(t, req.Speculative) assert.Equal(t, "ai/qwen3:1B", req.Speculative.DraftModel) assert.Equal(t, 5, req.Speculative.NumTokens) assert.InEpsilon(t, 0.8, req.Speculative.MinAcceptanceRate, 0.001) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("without speculative options", func(t *testing.T) { t.Parallel() contextSize := int64(4096) + backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, nil, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", &contextSize, []string{"--threads", "8"}, nil) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg) assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) require.NotNil(t, req.ContextSize) assert.Equal(t, int32(4096), *req.ContextSize) assert.Equal(t, []string{"--threads", "8"}, req.RuntimeFlags) assert.Nil(t, req.Speculative) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("without context size", func(t *testing.T) { @@ -233,8 +243,9 @@ func TestBuildConfigureRequest(t *testing.T) { draftModel: "ai/qwen3:1B", numTokens: 5, } + backendCfg := buildConfigureBackendConfig(nil, nil, specOpts, nil) - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", nil, nil, specOpts) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg) assert.Equal(t, "ai/qwen3:14B-Q6_K", 
req.Model) assert.Nil(t, req.ContextSize) @@ -242,16 +253,41 @@ func TestBuildConfigureRequest(t *testing.T) { require.NotNil(t, req.Speculative) assert.Equal(t, "ai/qwen3:1B", req.Speculative.DraftModel) assert.Equal(t, 5, req.Speculative.NumTokens) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) t.Run("minimal config", func(t *testing.T) { t.Parallel() - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", nil, nil, nil) + backendCfg := buildConfigureBackendConfig(nil, nil, nil, nil) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg) assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) assert.Nil(t, req.ContextSize) assert.Nil(t, req.RuntimeFlags) assert.Nil(t, req.Speculative) + assert.Nil(t, req.LlamaCpp) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) + }) + + t.Run("with llama.cpp reasoning budget", func(t *testing.T) { + t.Parallel() + rb := int32(16384) + llama := &llamaCppConfig{ReasoningBudget: &rb} + backendCfg := buildConfigureBackendConfig(nil, nil, nil, llama) + req := buildConfigureRequest("ai/qwen3:14B-Q6_K", backendCfg) + require.NotNil(t, req.LlamaCpp) + require.NotNil(t, req.LlamaCpp.ReasoningBudget) + assert.Equal(t, int32(16384), *req.LlamaCpp.ReasoningBudget) + assert.Nil(t, req.Mode) + assert.Empty(t, req.RawRuntimeFlags) + assert.Nil(t, req.KeepAlive) + assert.Nil(t, req.VLLM) }) } @@ -289,15 +325,16 @@ func TestConfigureModelViaAPI(t *testing.T) { numTokens: 5, acceptanceRate: 0.8, } + backendCfg := buildConfigureBackendConfig(&contextSize, []string{"--threads", "8"}, specOpts, nil) - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", &contextSize, []string{"--temp", "0.7"}, specOpts) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", backendCfg) require.NoError(t, err) // Verify request body assert.Equal(t, "ai/qwen3:14B", 
receivedRequest.Model) require.NotNil(t, receivedRequest.ContextSize) assert.Equal(t, int32(8192), *receivedRequest.ContextSize) - assert.Equal(t, []string{"--temp", "0.7"}, receivedRequest.RuntimeFlags) + assert.Equal(t, []string{"--threads", "8"}, receivedRequest.RuntimeFlags) require.NotNil(t, receivedRequest.Speculative) assert.Equal(t, "ai/qwen3:1B", receivedRequest.Speculative.DraftModel) assert.Equal(t, 5, receivedRequest.Speculative.NumTokens) @@ -314,7 +351,7 @@ func TestConfigureModelViaAPI(t *testing.T) { defer server.Close() baseURL := server.URL + "/engines/v1/" - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", nil, nil, nil) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", buildConfigureBackendConfig(nil, nil, nil, nil)) require.Error(t, err) assert.Contains(t, err.Error(), "500") assert.Contains(t, err.Error(), "internal error") @@ -330,71 +367,20 @@ func TestConfigureModelViaAPI(t *testing.T) { defer server.Close() baseURL := server.URL + "/engines/v1/" - err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", nil, nil, nil) + err := configureModel(t.Context(), server.Client(), baseURL, "ai/qwen3:14B", buildConfigureBackendConfig(nil, nil, nil, nil)) require.Error(t, err) assert.Contains(t, err.Error(), "409") assert.Contains(t, err.Error(), "runner already active") }) } -func TestBuildRuntimeFlagsFromModelConfig_LlamaCpp(t *testing.T) { - t.Parallel() - - flags := buildRuntimeFlagsFromModelConfig("llama.cpp", &latest.ModelConfig{ - Temperature: new(0.6), - TopP: new(0.95), - FrequencyPenalty: new(0.2), - PresencePenalty: new(0.1), - }) - - assert.Equal(t, []string{"--temp", "0.6", "--top-p", "0.95", "--frequency-penalty", "0.2", "--presence-penalty", "0.1"}, flags) -} - -func TestIntegrateFlagsWithProviderOptsOrder(t *testing.T) { - t.Parallel() - - cfg := &latest.ModelConfig{ - Temperature: new(0.6), - TopP: new(0.9), - MaxTokens: new(int64(4096)), - ProviderOpts: 
map[string]any{ - "runtime_flags": []string{"--threads", "6"}, - }, - } - // derive config flags first, then merge provider opts (simulating NewClient path) - derived := buildRuntimeFlagsFromModelConfig("llama.cpp", cfg) - // provider opts should be appended after derived flags so they can override by order - merged := append(derived, []string{"--threads", "6"}...) - - req := buildConfigureRequest("ai/qwen3:14B-Q6_K", cfg.MaxTokens, merged, nil) - assert.Equal(t, "ai/qwen3:14B-Q6_K", req.Model) - require.NotNil(t, req.ContextSize) - assert.Equal(t, int32(4096), *req.ContextSize) - assert.Equal(t, []string{"--temp", "0.6", "--top-p", "0.9", "--threads", "6"}, req.RuntimeFlags) -} - -func TestMergeRuntimeFlagsPreferUser_WarnsAndPrefersUser(t *testing.T) { - t.Parallel() - - // Derived suggests temp/top-p, user overrides both and adds threads - derived := []string{"--temp", "0.5", "--top-p", "0.8"} - user := []string{"--temp", "0.7", "--threads", "8"} - - merged, warnings := mergeRuntimeFlagsPreferUser(derived, user) - - // Expect 1 warnings for --temp overriding - require.Len(t, warnings, 1) - - // Derived conflicting flags should be dropped, user ones kept and appended - assert.Equal(t, []string{"--top-p", "0.8", "--temp", "0.7", "--threads", "8"}, merged) -} - func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) { t.Parallel() cfg := &latest.ModelConfig{ MaxTokens: new(int64(4096)), ProviderOpts: map[string]any{ + "context_size": int64(16384), "speculative_draft_model": "ai/qwen3:1B", "speculative_num_tokens": "5", "speculative_acceptance_rate": "0.75", @@ -402,14 +388,16 @@ func TestParseDMRProviderOptsWithSpeculativeDecoding(t *testing.T) { }, } - contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg) + contextSize, runtimeFlags, specOpts, llamaCpp := parseDMRProviderOpts("llama.cpp", cfg) - assert.Equal(t, int64(4096), *contextSize) + require.NotNil(t, contextSize) + assert.Equal(t, int64(16384), *contextSize) assert.Equal(t, 
[]string{"--threads", "8"}, runtimeFlags) require.NotNil(t, specOpts) assert.Equal(t, "ai/qwen3:1B", specOpts.draftModel) assert.Equal(t, 5, specOpts.numTokens) assert.InEpsilon(t, 0.75, specOpts.acceptanceRate, 0.001) + assert.Nil(t, llamaCpp) } func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) { @@ -422,11 +410,256 @@ func TestParseDMRProviderOptsWithoutSpeculativeDecoding(t *testing.T) { }, } - contextSize, runtimeFlags, specOpts := parseDMRProviderOpts(cfg) + contextSize, runtimeFlags, specOpts, llamaCpp := parseDMRProviderOpts("llama.cpp", cfg) - assert.Equal(t, int64(4096), *contextSize) + assert.Nil(t, contextSize, "context_size not in provider_opts, should be nil regardless of max_tokens") assert.Equal(t, []string{"--threads", "8"}, runtimeFlags) assert.Nil(t, specOpts) + assert.Nil(t, llamaCpp) +} + +func TestParseDMRProviderOptsContextSizeFromProviderOpts(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + MaxTokens: new(int64(4096)), + ProviderOpts: map[string]any{ + "context_size": int64(32768), + }, + } + + contextSize, rf, spec, ll := parseDMRProviderOpts("llama.cpp", cfg) + require.NotNil(t, contextSize) + assert.Equal(t, int64(32768), *contextSize) + assert.Nil(t, rf) + assert.Nil(t, spec) + assert.Nil(t, ll) +} + +func TestParseDMRProviderOptsContextSizeNeitherSet(t *testing.T) { + t.Parallel() + + cfg := &latest.ModelConfig{ + Provider: "dmr", + Model: "ai/qwen3", + } + + contextSize, rf, spec, ll := parseDMRProviderOpts("llama.cpp", cfg) + assert.Nil(t, contextSize) + assert.Nil(t, rf) + assert.Nil(t, spec) + assert.Nil(t, ll) +} + +func TestParseDMRProviderOptsThinkingBudget(t *testing.T) { + t.Parallel() + + t.Run("llama.cpp: effort maps to token budget", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "medium"}, + } + _, _, _, llamaCpp := parseDMRProviderOpts("llama.cpp", cfg) + require.NotNil(t, llamaCpp) + require.NotNil(t, 
llamaCpp.ReasoningBudget) + assert.Equal(t, int32(8192), *llamaCpp.ReasoningBudget) + }) + + t.Run("llama.cpp: explicit tokens", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 2048}, + } + _, _, _, llamaCpp := parseDMRProviderOpts("llama.cpp", cfg) + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(2048), *llamaCpp.ReasoningBudget) + }) + + t.Run("llama.cpp: disabled", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "none"}, + } + _, _, _, llamaCpp := parseDMRProviderOpts("llama.cpp", cfg) + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(0), *llamaCpp.ReasoningBudget) + }) + + t.Run("empty engine defaults to llama.cpp", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 4096}, + } + _, _, _, llamaCpp := parseDMRProviderOpts("", cfg) + require.NotNil(t, llamaCpp) + require.NotNil(t, llamaCpp.ReasoningBudget) + assert.Equal(t, int32(4096), *llamaCpp.ReasoningBudget) + }) + + t.Run("vllm engine: no llamacpp config (thinking handled per-request)", func(t *testing.T) { + t.Parallel() + cfg := &latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "high"}, + } + _, _, _, llamaCpp := parseDMRProviderOpts("vllm", cfg) + assert.Nil(t, llamaCpp, "vllm engine should not produce llamacpp config; thinking_budget is sent per-request instead") + }) +} + +func TestBuildVLLMRequestFields(t *testing.T) { + t.Parallel() + + t.Run("nil config returns nil", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(nil) + assert.Nil(t, fields) + }) + + t.Run("nil budget returns nil", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{}) + assert.Nil(t, fields) + }) + + t.Run("disabled (effort none) returns 0", func(t 
*testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "none"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(0), fields["thinking_token_budget"]) + }) + + t.Run("explicit token count", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Tokens: 4096}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(4096), fields["thinking_token_budget"]) + }) + + t.Run("effort medium maps to 8192", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "medium"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(8192), fields["thinking_token_budget"]) + }) + + t.Run("effort high maps to 16384", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "high"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(16384), fields["thinking_token_budget"]) + }) + + t.Run("adaptive returns -1 (unlimited)", func(t *testing.T) { + t.Parallel() + fields := buildVLLMRequestFields(&latest.ModelConfig{ + ThinkingBudget: &latest.ThinkingBudget{Effort: "adaptive"}, + }) + require.NotNil(t, fields) + assert.Equal(t, int64(-1), fields["thinking_token_budget"]) + }) +} + +func TestResolveReasoningBudget(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input *latest.ThinkingBudget + wantBudget int64 + wantOK bool + }{ + { + name: "nil → (0, false)", + input: nil, + wantBudget: 0, + wantOK: false, + }, + { + name: "disabled via Tokens:0 → (0, true)", + input: &latest.ThinkingBudget{Tokens: 0}, + wantBudget: 0, + wantOK: true, + }, + { + name: "disabled via Effort:none → (0, true)", + input: &latest.ThinkingBudget{Effort: "none"}, + wantBudget: 0, + wantOK: true, + }, + { + name: "explicit Tokens:4096 → (4096, true)", + 
input: &latest.ThinkingBudget{Tokens: 4096}, + wantBudget: 4096, + wantOK: true, + }, + { + name: "explicit Tokens:-1 (dynamic) → (-1, true)", + input: &latest.ThinkingBudget{Tokens: -1}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:minimal → (1024, true)", + input: &latest.ThinkingBudget{Effort: "minimal"}, + wantBudget: 1024, + wantOK: true, + }, + { + name: "Effort:low → (2048, true)", + input: &latest.ThinkingBudget{Effort: "low"}, + wantBudget: 2048, + wantOK: true, + }, + { + name: "Effort:medium → (8192, true)", + input: &latest.ThinkingBudget{Effort: "medium"}, + wantBudget: 8192, + wantOK: true, + }, + { + name: "Effort:high → (16384, true)", + input: &latest.ThinkingBudget{Effort: "high"}, + wantBudget: 16384, + wantOK: true, + }, + { + name: "Effort:adaptive → (-1, true)", + input: &latest.ThinkingBudget{Effort: "adaptive"}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:adaptive/low → (-1, true)", + input: &latest.ThinkingBudget{Effort: "adaptive/low"}, + wantBudget: -1, + wantOK: true, + }, + { + name: "Effort:unknown → (-1, true)", + input: &latest.ThinkingBudget{Effort: "unknown"}, + wantBudget: -1, + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotBudget, gotOK := resolveReasoningBudget(tt.input) + assert.Equal(t, tt.wantBudget, gotBudget) + assert.Equal(t, tt.wantOK, gotOK) + }) + } } func TestConfigureRequestJSONSerialization(t *testing.T) { @@ -435,14 +668,18 @@ func TestConfigureRequestJSONSerialization(t *testing.T) { t.Run("full request serializes correctly", func(t *testing.T) { t.Parallel() contextSize := int32(8192) + reasoning := int32(-1) req := configureRequest{ - Model: "ai/qwen3:14B", - ContextSize: &contextSize, - RuntimeFlags: []string{"--temp", "0.7"}, - Speculative: &speculativeDecodingRequest{ - DraftModel: "ai/qwen3:1B", - NumTokens: 5, - MinAcceptanceRate: 0.8, + Model: "ai/qwen3:14B", + configureBackendConfig: configureBackendConfig{ + 
ContextSize: &contextSize, + RuntimeFlags: []string{"--keep-alive", "5m"}, + Speculative: &speculativeDecodingRequest{ + DraftModel: "ai/qwen3:1B", + NumTokens: 5, + MinAcceptanceRate: 0.8, + }, + LlamaCpp: &llamaCppConfig{ReasoningBudget: &reasoning}, }, } @@ -455,13 +692,17 @@ func TestConfigureRequestJSONSerialization(t *testing.T) { assert.Equal(t, "ai/qwen3:14B", parsed["model"]) assert.InEpsilon(t, float64(8192), parsed["context-size"].(float64), 0.001) - assert.Equal(t, []any{"--temp", "0.7"}, parsed["runtime-flags"]) + assert.Equal(t, []any{"--keep-alive", "5m"}, parsed["runtime-flags"]) spec, ok := parsed["speculative"].(map[string]any) require.True(t, ok) assert.Equal(t, "ai/qwen3:1B", spec["draft_model"]) assert.InEpsilon(t, float64(5), spec["num_tokens"].(float64), 0.001) assert.InEpsilon(t, 0.8, spec["min_acceptance_rate"].(float64), 0.001) + + llama, ok := parsed["llamacpp"].(map[string]any) + require.True(t, ok) + assert.InEpsilon(t, float64(-1), llama["reasoning-budget"].(float64), 0.001) }) t.Run("minimal request omits nil fields", func(t *testing.T) { @@ -484,5 +725,52 @@ func TestConfigureRequestJSONSerialization(t *testing.T) { assert.False(t, hasRuntimeFlags, "runtime-flags should be omitted when nil") _, hasSpeculative := parsed["speculative"] assert.False(t, hasSpeculative, "speculative should be omitted when nil") + _, hasLlamaCpp := parsed["llamacpp"] + assert.False(t, hasLlamaCpp, "llamacpp should be omitted when nil") + _, hasMode := parsed["mode"] + assert.False(t, hasMode, "mode should be omitted when nil") + _, hasRawRuntimeFlags := parsed["raw-runtime-flags"] + assert.False(t, hasRawRuntimeFlags, "raw-runtime-flags should be omitted when empty") + _, hasKeepAlive := parsed["keep_alive"] + assert.False(t, hasKeepAlive, "keep_alive should be omitted when nil") + _, hasVLLM := parsed["vllm"] + assert.False(t, hasVLLM, "vllm should be omitted when nil") + }) + + t.Run("schema parity fields serialize with expected keys", func(t *testing.T) 
{ + t.Parallel() + mode := "completion" + keepAlive := "5m" + gpu := 0.9 + req := configureRequest{ + Model: "ai/qwen3:14B", + Mode: &mode, + RawRuntimeFlags: "--foo --bar", + configureBackendConfig: configureBackendConfig{ + KeepAlive: &keepAlive, + VLLM: &vllmConfig{ + HFOverrides: map[string]any{"foo": "bar"}, + GPUMemoryUtilization: &gpu, + }, + }, + } + + data, err := json.Marshal(req) + require.NoError(t, err) + + var parsed map[string]any + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, "completion", parsed["mode"]) + assert.Equal(t, "--foo --bar", parsed["raw-runtime-flags"]) + assert.Equal(t, "5m", parsed["keep_alive"]) + + vllm, ok := parsed["vllm"].(map[string]any) + require.True(t, ok) + hfOverrides, ok := vllm["hf-overrides"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "bar", hfOverrides["foo"]) + assert.InEpsilon(t, 0.9, vllm["gpu-memory-utilization"].(float64), 0.001) }) } diff --git a/pkg/model/provider/dmr/configure.go b/pkg/model/provider/dmr/configure.go index 17573013f..5a0150f43 100644 --- a/pkg/model/provider/dmr/configure.go +++ b/pkg/model/provider/dmr/configure.go @@ -2,7 +2,6 @@ package dmr import ( "bytes" - "cmp" "context" "encoding/json" "fmt" @@ -16,13 +15,45 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" ) -// configureRequest mirrors the model-runner's scheduling.ConfigureRequest structure. -// It specifies per-model runtime configuration options sent via POST /engines/_configure. +// configureRequest mirrors model-runner's scheduling.ConfigureRequest. type configureRequest struct { - Model string `json:"model"` + configureBackendConfig + + Model string `json:"model"` + Mode *string `json:"mode,omitempty"` + RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` +} + +// configureBackendConfig mirrors model-runner's inference.BackendConfiguration. 
+type configureBackendConfig struct { ContextSize *int32 `json:"context-size,omitempty"` RuntimeFlags []string `json:"runtime-flags,omitempty"` Speculative *speculativeDecodingRequest `json:"speculative,omitempty"` + KeepAlive *string `json:"keep_alive,omitempty"` + VLLM *vllmConfig `json:"vllm,omitempty"` + LlamaCpp *llamaCppConfig `json:"llamacpp,omitempty"` +} + +// vllmConfig mirrors model-runner's inference.VLLMConfig for POST /engines/_configure. +type vllmConfig struct { + HFOverrides map[string]any `json:"hf-overrides,omitempty"` + GPUMemoryUtilization *float64 `json:"gpu-memory-utilization,omitempty"` +} + +// llamaCppConfig mirrors model-runner's inference.LlamaCppConfig for POST /engines/_configure. +type llamaCppConfig struct { + ReasoningBudget *int32 `json:"reasoning-budget,omitempty"` +} + +func (c *llamaCppConfig) LogValue() slog.Value { + if c == nil { + return slog.AnyValue(nil) + } + var rb any + if c.ReasoningBudget != nil { + rb = *c.ReasoningBudget + } + return slog.GroupValue(slog.Any("reasoning-budget", rb)) } // speculativeDecodingRequest mirrors model-runner's inference.SpeculativeDecodingConfig. @@ -38,14 +69,25 @@ type speculativeDecodingOpts struct { acceptanceRate float64 } +func (so *speculativeDecodingOpts) LogValue() slog.Value { + if so == nil { + return slog.AnyValue(nil) + } + return slog.GroupValue( + slog.String("draft-model", so.draftModel), + slog.Int("num-tokens", so.numTokens), + slog.Float64("acceptance-rate", so.acceptanceRate), + ) +} + // configureModel sends model configuration to Model Runner via POST /engines/_configure. 
-func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) error { +func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model string, backend configureBackendConfig) error { if httpClient == nil { httpClient = &http.Client{} } configureURL := buildConfigureURL(baseURL) - reqData, err := json.Marshal(buildConfigureRequest(model, contextSize, runtimeFlags, specOpts)) + reqData, err := json.Marshal(buildConfigureRequest(model, backend)) if err != nil { return fmt.Errorf("failed to marshal configure request: %w", err) } @@ -62,9 +104,12 @@ func configureModel(ctx context.Context, httpClient *http.Client, baseURL, model slog.Debug("Sending model configure request", "model", model, "url", configureURL, - "context_size", contextSize, - "runtime_flags", runtimeFlags, - "speculative_opts", specOpts) + "context_size", derefInt32(backend.ContextSize), + "runtime_flags", backend.RuntimeFlags, + "speculative_opts", backend.Speculative, + "llamacpp", backend.LlamaCpp, + "keep_alive", backend.KeepAlive, + "vllm", backend.VLLM) resp, err := httpClient.Do(req) if err != nil { @@ -97,116 +142,31 @@ func buildConfigureURL(baseURL string) string { return u.String() } -// buildConfigureRequest constructs the JSON request body for POST /engines/_configure. 
-func buildConfigureRequest(model string, contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) configureRequest { - req := configureRequest{ - Model: model, +func buildConfigureBackendConfig(contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts, llamaCpp *llamaCppConfig) configureBackendConfig { + cfg := configureBackendConfig{ RuntimeFlags: runtimeFlags, + LlamaCpp: llamaCpp, } - if contextSize != nil { cs := int32(*contextSize) - req.ContextSize = &cs + cfg.ContextSize = &cs } - if specOpts != nil { - req.Speculative = &speculativeDecodingRequest{ + cfg.Speculative = &speculativeDecodingRequest{ DraftModel: specOpts.draftModel, NumTokens: specOpts.numTokens, MinAcceptanceRate: specOpts.acceptanceRate, } } - - return req -} - -// mergeRuntimeFlagsPreferUser merges derived engine flags (from model config fields like -// `temperature`) and user-provided runtime flags (from `provider_opts.runtime_flags`). -// When both specify the same flag key (e.g. --temp), the user value wins and a warning -// is returned. Order: non-conflicting derived flags first, then all user flags. -func mergeRuntimeFlagsPreferUser(derived, user []string) (merged, warnings []string) { - // parsedFlag holds a parsed flag token (e.g. "--temp 0.5" → key="--temp", tokens=["--temp","0.5"]). 
- type parsedFlag struct { - key string - tokens []string - } - - parse := func(args []string) []parsedFlag { - var out []parsedFlag - for i := 0; i < len(args); i++ { - tok := args[i] - if !strings.HasPrefix(tok, "-") { - out = append(out, parsedFlag{key: tok, tokens: []string{tok}}) - continue - } - // --key=value - if k, _, found := strings.Cut(tok, "="); found { - out = append(out, parsedFlag{key: k, tokens: []string{tok}}) - continue - } - // --key value (next token is the value if it doesn't start with -) - if i+1 < len(args) && !strings.HasPrefix(args[i+1], "-") { - out = append(out, parsedFlag{key: tok, tokens: []string{tok, args[i+1]}}) - i++ - } else { - out = append(out, parsedFlag{key: tok, tokens: []string{tok}}) - } - } - return out - } - - derFlags := parse(derived) - usrFlags := parse(user) - - // Build a set of flag keys the user explicitly provides. - userKeys := make(map[string]bool, len(usrFlags)) - for _, f := range usrFlags { - if strings.HasPrefix(f.key, "-") { - userKeys[f.key] = true - } - } - - // Emit non-conflicting derived flags; warn on conflicts. - for _, f := range derFlags { - if strings.HasPrefix(f.key, "-") && userKeys[f.key] { - warnings = append(warnings, "Overriding runtime flag "+f.key+" with value from provider_opts.runtime_flags") - continue - } - merged = append(merged, f.tokens...) - } - for _, f := range usrFlags { - merged = append(merged, f.tokens...) - } - return merged, warnings + return cfg } -// buildRuntimeFlagsFromModelConfig converts standard ModelConfig fields into backend-specific -// runtime flags that the model-runner understands when launching the engine. -// Currently supports "llama.cpp". Unknown engines produce no flags. 
-func buildRuntimeFlagsFromModelConfig(engine string, cfg *latest.ModelConfig) []string { - if cfg == nil { - return nil - } - - eng := cmp.Or(strings.TrimSpace(engine), "llama.cpp") - if eng != "llama.cpp" { - return nil - } - - var flags []string - if cfg.Temperature != nil { - flags = append(flags, "--temp", strconv.FormatFloat(*cfg.Temperature, 'f', -1, 64)) - } - if cfg.TopP != nil { - flags = append(flags, "--top-p", strconv.FormatFloat(*cfg.TopP, 'f', -1, 64)) - } - if cfg.FrequencyPenalty != nil { - flags = append(flags, "--frequency-penalty", strconv.FormatFloat(*cfg.FrequencyPenalty, 'f', -1, 64)) - } - if cfg.PresencePenalty != nil { - flags = append(flags, "--presence-penalty", strconv.FormatFloat(*cfg.PresencePenalty, 'f', -1, 64)) +// buildConfigureRequest constructs the JSON request body for POST /engines/_configure. +func buildConfigureRequest(model string, backend configureBackendConfig) configureRequest { + return configureRequest{ + Model: model, + configureBackendConfig: backend, } - return flags } // parseFloat64 attempts to parse a value as float64 from various types. @@ -240,25 +200,142 @@ func parseInt(v any) (int, bool) { return 0, false } +// parseInt64Value parses an int64 from YAML/JSON-decoded values (int, float64, string). +func parseInt64Value(v any) (int64, bool) { + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case float64: + return int64(t), true + case string: + s := strings.TrimSpace(t) + if s == "" { + return 0, false + } + n, err := strconv.ParseInt(s, 10, 64) + return n, err == nil + default: + return 0, false + } +} + +// parseContextSize extracts context_size from provider_opts. +// Returns nil when unset, letting model-runner use its default. 
+func parseContextSize(opts map[string]any) *int64 { + if len(opts) == 0 { + return nil + } + v, ok := opts["context_size"] + if !ok { + return nil + } + if n, ok := parseInt64Value(v); ok { + return &n + } + return nil +} + +// resolveReasoningBudget normalizes a ThinkingBudget to a token count understood by model-runner backends: +// - nil → (0, false) — budget unset, caller should omit the field entirely +// - disabled → (0, true) — budget explicitly disabled, caller should send 0 +// - tokens > 0 → (n, true) — explicit token count +// - adaptive / unknown effort → (-1, true) — unlimited +// - named effort → mapped token count +func resolveReasoningBudget(tb *latest.ThinkingBudget) (budget int64, ok bool) { + if tb == nil { + return 0, false + } + if tb.IsDisabled() { + return 0, true + } + if tb.Tokens != 0 || tb.Effort == "" { + return int64(tb.Tokens), true + } + if tb.IsAdaptive() { + return -1, true + } + if tok, ok := tb.EffortTokens(); ok { + return int64(tok), true + } + return -1, true // unknown effort → unlimited +} + +// buildLlamaCppConfig constructs the llamacpp engine configuration from the model config. +// Currently maps thinking_budget to model-runner's llamacpp.reasoning-budget. +// Returns nil when no relevant config is set. +func buildLlamaCppConfig(cfg *latest.ModelConfig) *llamaCppConfig { + if cfg == nil { + return nil + } + budget, ok := resolveReasoningBudget(cfg.ThinkingBudget) + if !ok { + return nil + } + v := int32(budget) + return &llamaCppConfig{ReasoningBudget: &v} +} + +// buildVLLMRequestFields constructs per-request extra fields for the vLLM engine. +// Currently maps thinking_budget to vLLM's thinking_token_budget sampling parameter. +// Returns nil when no extra fields are needed. 
+func buildVLLMRequestFields(cfg *latest.ModelConfig) map[string]any {
+	if cfg == nil {
+		return nil
+	}
+	budget, ok := resolveReasoningBudget(cfg.ThinkingBudget)
+	if !ok {
+		return nil
+	}
+	return map[string]any{"thinking_token_budget": budget}
+}
+
+func derefInt32(p *int32) any {
+	if p == nil {
+		return nil
+	}
+	return *p
+}
+
+// derefInt64 safely dereferences a *int64 for logging. Returns nil for nil pointers
+// so slog renders "<nil>" instead of a memory address.
+func derefInt64(p *int64) any {
+	if p == nil {
+		return nil
+	}
+	return *p
+}
+
+const (
+	engineLlamaCpp = "llama.cpp"
+	engineVLLM     = "vllm"
+)
+
 // parseDMRProviderOpts extracts DMR-specific provider options from the model config:
-// context size, runtime flags, and speculative decoding settings.
-func parseDMRProviderOpts(cfg *latest.ModelConfig) (contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts) {
+// context size, runtime flags, speculative decoding settings, and engine-specific structured options.
+// engine is the active model-runner backend (e.g. "llama.cpp", "vllm", "mlx", "sglang").
+func parseDMRProviderOpts(engine string, cfg *latest.ModelConfig) (contextSize *int64, runtimeFlags []string, specOpts *speculativeDecodingOpts, llamaCpp *llamaCppConfig) { if cfg == nil { - return nil, nil, nil + return nil, nil, nil, nil } - contextSize = cfg.MaxTokens + contextSize = parseContextSize(cfg.ProviderOpts) + + if engine == "" || engine == engineLlamaCpp { + llamaCpp = buildLlamaCppConfig(cfg) + } - slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts) + slog.Debug("DMR provider opts", "provider_opts", cfg.ProviderOpts, "engine", engine) if len(cfg.ProviderOpts) == 0 { - return contextSize, nil, nil + return contextSize, nil, nil, llamaCpp } runtimeFlags = parseRuntimeFlags(cfg.ProviderOpts) specOpts = parseSpeculativeOpts(cfg.ProviderOpts) - return contextSize, runtimeFlags, specOpts + return contextSize, runtimeFlags, specOpts, llamaCpp } // parseRuntimeFlags extracts the "runtime_flags" key from provider opts.