From 664390cc3e318ebbbed42191b8f1e34b6de34dab Mon Sep 17 00:00:00 2001 From: qkunio Date: Fri, 29 May 2026 16:42:05 +0800 Subject: [PATCH] feat(tui): add session model picker --- src/agent/__tests__/agent-model.test.ts | 55 +++++++++++ src/agent/agent.ts | 21 +++- src/cli/index.tsx | 35 +++---- src/cli/model-factory.ts | 27 ++++++ .../tui/__tests__/command-registry.test.ts | 10 ++ src/cli/tui/__tests__/model-command.test.ts | 58 +++++++++++ src/cli/tui/app.tsx | 13 ++- src/cli/tui/command-registry.ts | 5 + src/cli/tui/components/command-list.tsx | 16 +--- .../tui/components/model-selection-prompt.tsx | 86 +++++++++++++++++ src/cli/tui/hooks/use-agent-loop.ts | 95 ++++++++++++++++++- src/cli/tui/model-command.ts | 41 ++++++++ src/cli/tui/visible-window.ts | 14 +++ 13 files changed, 428 insertions(+), 48 deletions(-) create mode 100644 src/agent/__tests__/agent-model.test.ts create mode 100644 src/cli/model-factory.ts create mode 100644 src/cli/tui/__tests__/model-command.test.ts create mode 100644 src/cli/tui/components/model-selection-prompt.tsx create mode 100644 src/cli/tui/model-command.ts create mode 100644 src/cli/tui/visible-window.ts diff --git a/src/agent/__tests__/agent-model.test.ts b/src/agent/__tests__/agent-model.test.ts new file mode 100644 index 0000000..2ed3624 --- /dev/null +++ b/src/agent/__tests__/agent-model.test.ts @@ -0,0 +1,55 @@ +import { describe, expect, it } from "bun:test"; + +import type { AssistantMessage, ModelProvider, ModelProviderInvokeParams, UserMessage } from "@/foundation"; +import { Model } from "@/foundation"; + +import { Agent } from "../agent"; + +class EchoModelProvider implements ModelProvider { + async invoke(params: ModelProviderInvokeParams): Promise { + return modelMessage(params.model); + } + + async *stream(params: ModelProviderInvokeParams): AsyncGenerator { + yield modelMessage(params.model); + } +} + +describe("Agent model switching", () => { + it("uses the new model for future steps without clearing messages", async () => { + const agent = new Agent({ + model: new Model("first-model", new EchoModelProvider()), + prompt: "test", + }); + + const first = await runOnce(agent, "hello"); + expect(first).toBe("first-model"); + expect(agent.messages).toHaveLength(2); + + agent.setModel(new Model("second-model", new EchoModelProvider())); + expect(agent.model.name).toBe("second-model"); + expect(agent.messages).toHaveLength(2); + + const second = await runOnce(agent, "again"); + expect(second).toBe("second-model"); + expect(agent.messages).toHaveLength(4); + }); +}); + +async function runOnce(agent: Agent, text: string): Promise { + const userMessage: UserMessage = { role: "user", content: [{ type: "text", text }] }; + let finalText = ""; + for await (const event of agent.stream(userMessage)) { + if (event.type !== "message" || event.message.role !== "assistant") continue; + const content = event.message.content.find((item) => item.type === "text"); + finalText = content?.text ?? ""; + } + return finalText; +} + +function modelMessage(modelName: string): AssistantMessage { + return { + role: "assistant", + content: [{ type: "text", text: modelName }], + }; +} diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 4535d82..eb6f065 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -47,11 +47,11 @@ export interface AgentOptions { */ export class Agent { private readonly _context: AgentContext; + private _model: Model; private _streaming = false; private _abortController: AbortController | null = null; readonly name?: string; - readonly model: Model; readonly options: Required; readonly middlewares: AgentMiddleware[]; @@ -80,7 +80,7 @@ export class Agent { maxSteps?: number; }) { this.name = name; - this.model = model; + this._model = model; this._context = { prompt, tools, @@ -97,6 +97,20 @@ export class Agent { return this._context.messages; } + /** + * Gets the model used for future agent steps. + */ + get model() { + return this._model; + } + + /** + * Sets the model used for future agent steps. + */ + setModel(model: Model) { + this._model = model; + } + /** * Gets or sets the prompt for the agent. */ @@ -187,7 +201,7 @@ export class Agent { await this._beforeModel(modelContext); let latest: AssistantMessage | null = null; - for await (const snapshot of this.model.stream(modelContext)) { + for await (const snapshot of this._model.stream(modelContext)) { latest = snapshot; if (snapshot.streaming) { yield this._deriveProgress(snapshot); @@ -359,4 +373,3 @@ export class Agent { } } } - diff --git a/src/cli/index.tsx b/src/cli/index.tsx index 006c6f4..e1f279a 100644 --- a/src/cli/index.tsx +++ b/src/cli/index.tsx @@ -6,12 +6,9 @@ import { render } from "ink"; import { validateIntegrity } from "@/cli/bootstrap"; import { registerCommands } from "@/cli/commands"; import { loadConfig } from "@/cli/config"; +import { buildModelFromEntry } from "@/cli/model-factory"; import { SettingsLoader, SettingsWriter } from "@/cli/settings"; import { createCodingAgent, globalApprovalManager, globalAskUserQuestionManager } from "@/coding"; -import { AnthropicModelProvider } from "@/community/anthropic"; -import { OpenAIModelProvider } from "@/community/openai"; -import type { ModelProvider } from "@/foundation"; -import { Model } from "@/foundation"; import { App } from "./tui"; import { loadAvailableCommands, type SlashCommand } from "./tui/command-registry"; @@ -41,25 +38,7 @@ if (args.length > 0) { throw new Error("No models configured. Run `helixent config model add` to add one."); } - let provider: ModelProvider; - if (entry.provider === "anthropic") { - provider = new AnthropicModelProvider({ - baseURL: entry.baseURL, - apiKey: entry.APIKey, - }); - } else { - provider = new OpenAIModelProvider({ - baseURL: entry.baseURL, - apiKey: entry.APIKey, - }); - } - - const model = new Model(entry.name, provider, { - max_tokens: 16 * 1024, - thinking: { - type: "enabled", - }, - }); + const model = buildModelFromEntry(entry); const skillsDirs = [ join(process.cwd(), "skills"), @@ -84,7 +63,15 @@ if (args.length > 0) { const commands: SlashCommand[] = await loadAvailableCommands(skillsDirs); render( - + , { patchConsole: false }, diff --git a/src/cli/model-factory.ts b/src/cli/model-factory.ts new file mode 100644 index 0000000..335f83d --- /dev/null +++ b/src/cli/model-factory.ts @@ -0,0 +1,27 @@ +import type { ModelEntry } from "@/cli/config"; +import { AnthropicModelProvider } from "@/community/anthropic"; +import { OpenAIModelProvider } from "@/community/openai"; +import type { ModelProvider } from "@/foundation"; +import { Model } from "@/foundation"; + +export function buildModelFromEntry(entry: ModelEntry): Model { + let provider: ModelProvider; + if (entry.provider === "anthropic") { + provider = new AnthropicModelProvider({ + baseURL: entry.baseURL, + apiKey: entry.APIKey, + }); + } else { + provider = new OpenAIModelProvider({ + baseURL: entry.baseURL, + apiKey: entry.APIKey, + }); + } + + return new Model(entry.name, provider, { + max_tokens: 16 * 1024, + thinking: { + type: "enabled", + }, + }); +} diff --git a/src/cli/tui/__tests__/command-registry.test.ts b/src/cli/tui/__tests__/command-registry.test.ts index 806ed0f..994fffd 100644 --- a/src/cli/tui/__tests__/command-registry.test.ts +++ b/src/cli/tui/__tests__/command-registry.test.ts @@ -7,10 +7,12 @@ describe("resolveBuiltinCommand", () => { expect(resolveBuiltinCommand("/clear")).toEqual({ name: "clear", args: "" }); expect(resolveBuiltinCommand("/exit")).toEqual({ name: "exit", args: "" }); expect(resolveBuiltinCommand("/help")).toEqual({ name: "help", args: "" }); + expect(resolveBuiltinCommand("/model")).toEqual({ name: "model", args: "" }); }); it("captures trailing args after a builtin", () => { expect(resolveBuiltinCommand("/help clear")).toEqual({ name: "help", args: "clear" }); + expect(resolveBuiltinCommand("/model deepseek-v4-pro")).toEqual({ name: "model", args: "deepseek-v4-pro" }); expect(resolveBuiltinCommand("/help skill-creator")).toEqual({ name: "help", args: "skill-creator", @@ -39,6 +41,7 @@ describe("formatHelp", () => { expect(text).toContain("Available slash commands"); expect(text).toContain("/clear"); expect(text).toContain("/help"); + expect(text).toContain("/model"); expect(text).toContain("/skill-creator"); expect(text).toContain("Create new skills"); }); @@ -50,6 +53,13 @@ describe("formatHelp", () => { expect(text).toContain("Clear the current conversation history"); }); + it("renders details for the model command", () => { + const text = formatHelp(commands, "model"); + expect(text).toContain("/model"); + expect(text).toContain("Built-in command"); + expect(text).toContain("Choose the model"); + }); + it("tolerates a leading slash and case in target", () => { const text = formatHelp(commands, "/CLEAR"); expect(text).toContain("/clear"); diff --git a/src/cli/tui/__tests__/model-command.test.ts b/src/cli/tui/__tests__/model-command.test.ts new file mode 100644 index 0000000..deec147 --- /dev/null +++ b/src/cli/tui/__tests__/model-command.test.ts @@ -0,0 +1,58 @@ +import { describe, expect, it } from "bun:test"; + +import type { ModelEntry } from "@/cli/config"; + +import { resolveModelSelection } from "../model-command"; + +const models: ModelEntry[] = [ + { + name: "deepseek-v4-flash", + baseURL: "https://api.deepseek.com/v1", + APIKey: "key", + provider: "openai", + }, + { + name: "deepseek-v4-pro", + baseURL: "https://api.deepseek.com/v1", + APIKey: "key", + provider: "openai", + }, +]; + +describe("resolveModelSelection", () => { + it("selects a different configured model", () => { + const result = resolveModelSelection({ + models, + currentModelName: "deepseek-v4-flash", + targetName: "deepseek-v4-pro", + }); + + expect(result.ok).toBe(true); + expect(result.message).toContain("Switched model"); + if (result.ok) { + expect(result.entry.name).toBe("deepseek-v4-pro"); + } + }); + + it("rejects an unknown model without selecting one", () => { + const result = resolveModelSelection({ + models, + currentModelName: "deepseek-v4-flash", + targetName: "unknown", + }); + + expect(result.ok).toBe(false); + expect(result.message).toContain("not found"); + }); + + it("rejects the current model", () => { + const result = resolveModelSelection({ + models, + currentModelName: "deepseek-v4-flash", + targetName: "deepseek-v4-flash", + }); + + expect(result.ok).toBe(false); + expect(result.message).toContain("Already using"); + }); +}); diff --git a/src/cli/tui/app.tsx b/src/cli/tui/app.tsx index 4b8ad6c..bd0130d 100644 --- a/src/cli/tui/app.tsx +++ b/src/cli/tui/app.tsx @@ -10,6 +10,7 @@ import { Footer } from "./components/footer"; import { Header } from "./components/header"; import { InputBox } from "./components/input-box"; import { MessageHistoryItem } from "./components/message-history"; +import { ModelSelectionPrompt } from "./components/model-selection-prompt"; import { StreamingIndicator } from "./components/streaming-indicator"; import { TodoPanel } from "./components/todo-panel"; import { useAgentLoop } from "./hooks/use-agent-loop"; @@ -29,7 +30,7 @@ export function App({ commands: SlashCommand[]; supportProjectWideAllow?: boolean; }) { - const { streaming, messages, onSubmit, abort } = useAgentLoop(); + const { streaming, messages, onSubmit, abort, modelPicker, selectModel, cancelModelSelection } = useAgentLoop(); const { approvalRequest, respondToApproval } = useApprovalManager(); const { askUserQuestionRequest, respondWithAnswers } = useAskUserQuestionManager(); const { latestTodos, todoSnapshots } = useMemo(() => buildTodoViewState(messages), [messages]); @@ -57,7 +58,7 @@ export function App({ todoSnapshots={todoSnapshots} /> )} - {approvalRequest || askUserQuestionRequest ? null : ( + {approvalRequest || askUserQuestionRequest || modelPicker ? null : ( )} {!hideTodos && } @@ -72,6 +73,14 @@ export function App({ questions={askUserQuestionRequest.params.questions} onSubmit={respondWithAnswers} /> + ) : modelPicker ? ( + ) : ( )} diff --git a/src/cli/tui/command-registry.ts b/src/cli/tui/command-registry.ts index 514598a..03faae9 100644 --- a/src/cli/tui/command-registry.ts +++ b/src/cli/tui/command-registry.ts @@ -28,6 +28,11 @@ export const BUILTIN_COMMANDS: SlashCommand[] = [ description: "List available slash commands, or show details for one (`/help `)", type: "builtin", }, + { + name: "model", + description: "Choose the model for this TUI session", + type: "builtin", + }, { name: "quit", description: "Exit the TUI session", diff --git a/src/cli/tui/components/command-list.tsx b/src/cli/tui/components/command-list.tsx index fd49313..9172331 100644 --- a/src/cli/tui/components/command-list.tsx +++ b/src/cli/tui/components/command-list.tsx @@ -2,6 +2,7 @@ import { Box, Text } from "ink"; import type { SlashCommand } from "../command-registry"; import { currentTheme } from "../themes"; +import { getVisibleWindow } from "../visible-window"; const MAX_VISIBLE_COMMANDS = 5; @@ -64,18 +65,3 @@ function summarizeDescription(description: string, maxLength = 72): string { if (normalized.length <= maxLength) return normalized; return `${normalized.slice(0, maxLength - 3)}...`; } - -function getVisibleWindow(total: number, selectedIndex: number, maxVisible: number) { - if (total <= maxVisible) { - return { startIndex: 0, endIndex: total }; - } - - const halfWindow = Math.floor(maxVisible / 2); - const maxStartIndex = total - maxVisible; - const startIndex = Math.max(0, Math.min(selectedIndex - halfWindow, maxStartIndex)); - - return { - startIndex, - endIndex: startIndex + maxVisible, - }; -} diff --git a/src/cli/tui/components/model-selection-prompt.tsx b/src/cli/tui/components/model-selection-prompt.tsx new file mode 100644 index 0000000..0ad0495 --- /dev/null +++ b/src/cli/tui/components/model-selection-prompt.tsx @@ -0,0 +1,86 @@ +import { Box, Text, useInput } from "ink"; +import { memo, useState } from "react"; + +import type { ModelEntry } from "@/cli/config"; + +import { currentTheme } from "../themes"; +import { getVisibleWindow } from "../visible-window"; + +const MAX_VISIBLE_MODELS = 8; + +export const ModelSelectionPrompt = memo(function ModelSelectionPrompt({ + models, + currentModelName, + defaultModelName, + onCancel, + onSelect, +}: { + models: ModelEntry[]; + currentModelName: string; + defaultModelName?: string; + onCancel: () => void; + // eslint-disable-next-line no-unused-vars + onSelect: (modelName: string) => void; +}) { + const initialIndex = Math.max(0, models.findIndex((model) => model.name === currentModelName)); + const [selectedIndex, setSelectedIndex] = useState(initialIndex); + + useInput((input, key) => { + if (key.escape || input === "q") { + onCancel(); + return; + } + + if (key.upArrow) { + setSelectedIndex((index) => (index > 0 ? index - 1 : models.length - 1)); + return; + } + + if (key.downArrow) { + setSelectedIndex((index) => (index < models.length - 1 ? index + 1 : 0)); + return; + } + + if (key.return) { + const selected = models[selectedIndex]; + if (selected) { + onSelect(selected.name); + } + } + }); + + const { endIndex, startIndex } = getVisibleWindow(models.length, selectedIndex, MAX_VISIBLE_MODELS); + const visibleModels = models.slice(startIndex, endIndex); + + return ( + + + Select model + + {visibleModels.map((model, visibleIndex) => { + const index = startIndex + visibleIndex; + const selected = index === selectedIndex; + const markers = [ + model.name === currentModelName ? "current" : null, + model.name === defaultModelName ? "default" : null, + ].filter(Boolean); + const suffix = markers.length > 0 ? ` (${markers.join(", ")})` : ""; + + return ( + + + {selected ? "❯ " : " "} + + + {model.name} + + + {suffix} - {model.provider ?? "openai"} - {model.baseURL} + + + ); + })} + ↑/↓ move · Enter switch · Esc cancel + + ); +}); diff --git a/src/cli/tui/hooks/use-agent-loop.ts b/src/cli/tui/hooks/use-agent-loop.ts index e47e407..1249fcf 100644 --- a/src/cli/tui/hooks/use-agent-loop.ts +++ b/src/cli/tui/hooks/use-agent-loop.ts @@ -2,12 +2,27 @@ import { createContext, createElement, useCallback, useContext, useEffect, useMe import type { ReactNode } from "react"; import type { Agent } from "@/agent"; -import type { AssistantMessage, NonSystemMessage, UserMessage } from "@/foundation"; +import type { ModelEntry } from "@/cli/config"; +import type { AssistantMessage, Model, NonSystemMessage, UserMessage } from "@/foundation"; import type { PromptSubmission, SlashCommand } from "../command-registry"; import { formatHelp, resolveBuiltinCommand } from "../command-registry"; +import { resolveModelSelection } from "../model-command"; import { calculateTokenUsage, type TokenUsageSummary } from "../token-usage"; +export type ModelSelectionOptions = { + models: ModelEntry[]; + defaultModelName?: string; + // eslint-disable-next-line no-unused-vars + buildModel: (entry: ModelEntry) => Model; +}; + +export type ModelPickerState = { + models: ModelEntry[]; + currentModelName: string; + defaultModelName?: string; +}; + type AgentLoopState = { agent: Agent; streaming: boolean; @@ -16,6 +31,10 @@ type AgentLoopState = { onSubmit: (submission: PromptSubmission) => Promise; abort: () => void; tokenUsage: TokenUsageSummary; + modelPicker: ModelPickerState | null; + // eslint-disable-next-line no-unused-vars + selectModel: (modelName: string) => void; + cancelModelSelection: () => void; }; const AgentLoopContext = createContext(null); @@ -23,14 +42,17 @@ const AgentLoopContext = createContext(null); export function AgentLoopProvider({ agent, commands = [], + modelSelection, children, }: { agent: Agent; commands?: SlashCommand[]; + modelSelection?: ModelSelectionOptions; children: ReactNode; }) { const [streaming, setStreaming] = useState(false); const [messages, setMessages] = useState([]); + const [modelPicker, setModelPicker] = useState(null); const streamingRef = useRef(streaming); const pendingMessagesRef = useRef([]); @@ -80,6 +102,27 @@ export function AgentLoopProvider({ return calculateTokenUsage(messages); }, [messages]); + const appendAssistantText = useCallback((text: string) => { + const assistantMessage: AssistantMessage = { + role: "assistant", + content: [{ type: "text", text }], + }; + setMessages((prev) => [...prev, assistantMessage]); + }, []); + + const selectModel = useCallback( + (modelName: string) => { + const message = handleModelCommand(agent, modelSelection, modelName); + setModelPicker(null); + appendAssistantText(message); + }, + [agent, appendAssistantText, modelSelection], + ); + + const cancelModelSelection = useCallback(() => { + setModelPicker(null); + }, []); + const onSubmit = useCallback( async (submission: PromptSubmission) => { const { text, requestedSkillName } = submission; @@ -116,6 +159,27 @@ export function AgentLoopProvider({ return; } + if (invocation?.name === "model") { + flushPendingMessages(); + const userMessage: UserMessage = { role: "user", content: [{ type: "text", text }] }; + if (!invocation.args && modelSelection && modelSelection.models.length > 0) { + setMessages((prev) => [...prev, userMessage]); + setModelPicker({ + models: modelSelection.models, + currentModelName: agent.model.name, + defaultModelName: modelSelection.defaultModelName, + }); + return; + } + + const assistantMessage: AssistantMessage = { + role: "assistant", + content: [{ type: "text", text: handleModelCommand(agent, modelSelection, invocation.args) }], + }; + setMessages((prev) => [...prev, userMessage, assistantMessage]); + return; + } + setStreaming(true); try { @@ -146,7 +210,7 @@ export function AgentLoopProvider({ setStreaming(false); } }, - [agent, commands, enqueueMessage, flushPendingMessages], + [agent, commands, enqueueMessage, flushPendingMessages, modelSelection], ); const value = useMemo( @@ -157,8 +221,11 @@ export function AgentLoopProvider({ onSubmit, abort, tokenUsage, + modelPicker, + selectModel, + cancelModelSelection, }), - [abort, agent, messages, onSubmit, streaming, tokenUsage], + [abort, agent, cancelModelSelection, messages, modelPicker, onSubmit, selectModel, streaming, tokenUsage], ); return createElement(AgentLoopContext.Provider, { value }, children); @@ -176,6 +243,28 @@ export function useAgentLoop() { return useAgentLoopState(); } +function handleModelCommand( + agent: Agent, + modelSelection: ModelSelectionOptions | undefined, + args: string, +): string { + if (!modelSelection) { + return "Model selection is unavailable in this session."; + } + + const selection = resolveModelSelection({ + models: modelSelection.models, + currentModelName: agent.model.name, + targetName: args, + }); + if (!selection.ok) { + return selection.message; + } + + agent.setModel(modelSelection.buildModel(selection.entry)); + return selection.message; +} + function isAbortError(error: unknown): boolean { if (error instanceof DOMException && error.name === "AbortError") return true; if (error instanceof Error && error.name === "AbortError") return true; diff --git a/src/cli/tui/model-command.ts b/src/cli/tui/model-command.ts new file mode 100644 index 0000000..2e234fa --- /dev/null +++ b/src/cli/tui/model-command.ts @@ -0,0 +1,41 @@ +import type { ModelEntry } from "@/cli/config"; + +export type ModelSelectionResult = + | { ok: true; entry: ModelEntry; message: string } + | { ok: false; message: string }; + +export function resolveModelSelection({ + models, + currentModelName, + targetName, +}: { + models: ModelEntry[]; + currentModelName: string; + targetName: string; +}): ModelSelectionResult { + const target = targetName.trim(); + if (!target) { + return { ok: false, message: "No models configured. Run `helixent config model add` to add one." }; + } + + const entry = models.find((model) => model.name === target); + if (!entry) { + return { + ok: false, + message: `Model "${target}" not found. Run \`/model\` to choose from configured models.`, + }; + } + + if (entry.name === currentModelName) { + return { + ok: false, + message: `Already using model "${entry.name}".`, + }; + } + + return { + ok: true, + entry, + message: `Switched model to "${entry.name}" for this session.`, + }; +} diff --git a/src/cli/tui/visible-window.ts b/src/cli/tui/visible-window.ts new file mode 100644 index 0000000..0d314e0 --- /dev/null +++ b/src/cli/tui/visible-window.ts @@ -0,0 +1,14 @@ +export function getVisibleWindow(total: number, selectedIndex: number, maxVisible: number) { + if (total <= maxVisible) { + return { startIndex: 0, endIndex: total }; + } + + const halfWindow = Math.floor(maxVisible / 2); + const maxStartIndex = total - maxVisible; + const startIndex = Math.max(0, Math.min(selectedIndex - halfWindow, maxStartIndex)); + + return { + startIndex, + endIndex: startIndex + maxVisible, + }; +}