Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/agent/__tests__/agent-model.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { describe, expect, it } from "bun:test";

import type { AssistantMessage, ModelProvider, ModelProviderInvokeParams, UserMessage } from "@/foundation";
import { Model } from "@/foundation";

import { Agent } from "../agent";

class EchoModelProvider implements ModelProvider {
async invoke(params: ModelProviderInvokeParams): Promise<AssistantMessage> {
return modelMessage(params.model);
}

async *stream(params: ModelProviderInvokeParams): AsyncGenerator<AssistantMessage> {
yield modelMessage(params.model);
}
}

describe("Agent model switching", () => {
it("uses the new model for future steps without clearing messages", async () => {
const agent = new Agent({
model: new Model("first-model", new EchoModelProvider()),
prompt: "test",
});

const first = await runOnce(agent, "hello");
expect(first).toBe("first-model");
expect(agent.messages).toHaveLength(2);

agent.setModel(new Model("second-model", new EchoModelProvider()));
expect(agent.model.name).toBe("second-model");
expect(agent.messages).toHaveLength(2);

const second = await runOnce(agent, "again");
expect(second).toBe("second-model");
expect(agent.messages).toHaveLength(4);
});
});

async function runOnce(agent: Agent, text: string): Promise<string> {
const userMessage: UserMessage = { role: "user", content: [{ type: "text", text }] };
let finalText = "";
for await (const event of agent.stream(userMessage)) {
if (event.type !== "message" || event.message.role !== "assistant") continue;
const content = event.message.content.find((item) => item.type === "text");
finalText = content?.text ?? "";
}
return finalText;
}

function modelMessage(modelName: string): AssistantMessage {
return {
role: "assistant",
content: [{ type: "text", text: modelName }],
};
}
21 changes: 17 additions & 4 deletions src/agent/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ export interface AgentOptions {
*/
export class Agent {
private readonly _context: AgentContext;
private _model: Model;
private _streaming = false;
private _abortController: AbortController | null = null;

readonly name?: string;
readonly model: Model;
readonly options: Required<AgentOptions>;
readonly middlewares: AgentMiddleware[];

Expand Down Expand Up @@ -80,7 +80,7 @@ export class Agent {
maxSteps?: number;
}) {
this.name = name;
this.model = model;
this._model = model;
this._context = {
prompt,
tools,
Expand All @@ -97,6 +97,20 @@ export class Agent {
return this._context.messages;
}

/**
* Gets the model used for future agent steps.
*/
get model() {
return this._model;
}

/**
* Sets the model used for future agent steps.
*/
setModel(model: Model) {
this._model = model;
}

/**
* Gets or sets the prompt for the agent.
*/
Expand Down Expand Up @@ -187,7 +201,7 @@ export class Agent {
await this._beforeModel(modelContext);

let latest: AssistantMessage | null = null;
for await (const snapshot of this.model.stream(modelContext)) {
for await (const snapshot of this._model.stream(modelContext)) {
latest = snapshot;
if (snapshot.streaming) {
yield this._deriveProgress(snapshot);
Expand Down Expand Up @@ -359,4 +373,3 @@ export class Agent {
}
}
}

35 changes: 11 additions & 24 deletions src/cli/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ import { render } from "ink";
import { validateIntegrity } from "@/cli/bootstrap";
import { registerCommands } from "@/cli/commands";
import { loadConfig } from "@/cli/config";
import { buildModelFromEntry } from "@/cli/model-factory";
import { SettingsLoader, SettingsWriter } from "@/cli/settings";
import { createCodingAgent, globalApprovalManager, globalAskUserQuestionManager } from "@/coding";
import { AnthropicModelProvider } from "@/community/anthropic";
import { OpenAIModelProvider } from "@/community/openai";
import type { ModelProvider } from "@/foundation";
import { Model } from "@/foundation";

import { App } from "./tui";
import { loadAvailableCommands, type SlashCommand } from "./tui/command-registry";
Expand Down Expand Up @@ -41,25 +38,7 @@ if (args.length > 0) {
throw new Error("No models configured. Run `helixent config model add` to add one.");
}

let provider: ModelProvider;
if (entry.provider === "anthropic") {
provider = new AnthropicModelProvider({
baseURL: entry.baseURL,
apiKey: entry.APIKey,
});
} else {
provider = new OpenAIModelProvider({
baseURL: entry.baseURL,
apiKey: entry.APIKey,
});
}

const model = new Model(entry.name, provider, {
max_tokens: 16 * 1024,
thinking: {
type: "enabled",
},
});
const model = buildModelFromEntry(entry);

const skillsDirs = [
join(process.cwd(), "skills"),
Expand All @@ -84,7 +63,15 @@ if (args.length > 0) {
const commands: SlashCommand[] = await loadAvailableCommands(skillsDirs);

render(
<AgentLoopProvider agent={agent} commands={commands}>
<AgentLoopProvider
agent={agent}
commands={commands}
modelSelection={{
models: config.models,
defaultModelName,
buildModel: buildModelFromEntry,
}}
>
<App commands={commands} supportProjectWideAllow />
</AgentLoopProvider>,
{ patchConsole: false },
Expand Down
27 changes: 27 additions & 0 deletions src/cli/model-factory.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import type { ModelEntry } from "@/cli/config";
import { AnthropicModelProvider } from "@/community/anthropic";
import { OpenAIModelProvider } from "@/community/openai";
import type { ModelProvider } from "@/foundation";
import { Model } from "@/foundation";

export function buildModelFromEntry(entry: ModelEntry): Model {
let provider: ModelProvider;
if (entry.provider === "anthropic") {
provider = new AnthropicModelProvider({
baseURL: entry.baseURL,
apiKey: entry.APIKey,
});
} else {
provider = new OpenAIModelProvider({
baseURL: entry.baseURL,
apiKey: entry.APIKey,
});
}

return new Model(entry.name, provider, {
max_tokens: 16 * 1024,
thinking: {
type: "enabled",
},
});
}
10 changes: 10 additions & 0 deletions src/cli/tui/__tests__/command-registry.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ describe("resolveBuiltinCommand", () => {
expect(resolveBuiltinCommand("/clear")).toEqual({ name: "clear", args: "" });
expect(resolveBuiltinCommand("/exit")).toEqual({ name: "exit", args: "" });
expect(resolveBuiltinCommand("/help")).toEqual({ name: "help", args: "" });
expect(resolveBuiltinCommand("/model")).toEqual({ name: "model", args: "" });
});

it("captures trailing args after a builtin", () => {
expect(resolveBuiltinCommand("/help clear")).toEqual({ name: "help", args: "clear" });
expect(resolveBuiltinCommand("/model deepseek-v4-pro")).toEqual({ name: "model", args: "deepseek-v4-pro" });
expect(resolveBuiltinCommand("/help skill-creator")).toEqual({
name: "help",
args: "skill-creator",
Expand Down Expand Up @@ -39,6 +41,7 @@ describe("formatHelp", () => {
expect(text).toContain("Available slash commands");
expect(text).toContain("/clear");
expect(text).toContain("/help");
expect(text).toContain("/model");
expect(text).toContain("/skill-creator");
expect(text).toContain("Create new skills");
});
Expand All @@ -50,6 +53,13 @@ describe("formatHelp", () => {
expect(text).toContain("Clear the current conversation history");
});

it("renders details for the model command", () => {
const text = formatHelp(commands, "model");
expect(text).toContain("/model");
expect(text).toContain("Built-in command");
expect(text).toContain("Choose the model");
});

it("tolerates a leading slash and case in target", () => {
const text = formatHelp(commands, "/CLEAR");
expect(text).toContain("/clear");
Expand Down
58 changes: 58 additions & 0 deletions src/cli/tui/__tests__/model-command.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { describe, expect, it } from "bun:test";

import type { ModelEntry } from "@/cli/config";

import { resolveModelSelection } from "../model-command";

const models: ModelEntry[] = [
{
name: "deepseek-v4-flash",
baseURL: "https://api.deepseek.com/v1",
APIKey: "key",
provider: "openai",
},
{
name: "deepseek-v4-pro",
baseURL: "https://api.deepseek.com/v1",
APIKey: "key",
provider: "openai",
},
];

describe("resolveModelSelection", () => {
it("selects a different configured model", () => {
const result = resolveModelSelection({
models,
currentModelName: "deepseek-v4-flash",
targetName: "deepseek-v4-pro",
});

expect(result.ok).toBe(true);
expect(result.message).toContain("Switched model");
if (result.ok) {
expect(result.entry.name).toBe("deepseek-v4-pro");
}
});

it("rejects an unknown model without selecting one", () => {
const result = resolveModelSelection({
models,
currentModelName: "deepseek-v4-flash",
targetName: "unknown",
});

expect(result.ok).toBe(false);
expect(result.message).toContain("not found");
});

it("rejects the current model", () => {
const result = resolveModelSelection({
models,
currentModelName: "deepseek-v4-flash",
targetName: "deepseek-v4-flash",
});

expect(result.ok).toBe(false);
expect(result.message).toContain("Already using");
});
});
13 changes: 11 additions & 2 deletions src/cli/tui/app.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { Footer } from "./components/footer";
import { Header } from "./components/header";
import { InputBox } from "./components/input-box";
import { MessageHistoryItem } from "./components/message-history";
import { ModelSelectionPrompt } from "./components/model-selection-prompt";
import { StreamingIndicator } from "./components/streaming-indicator";
import { TodoPanel } from "./components/todo-panel";
import { useAgentLoop } from "./hooks/use-agent-loop";
Expand All @@ -29,7 +30,7 @@ export function App({
commands: SlashCommand[];
supportProjectWideAllow?: boolean;
}) {
const { streaming, messages, onSubmit, abort } = useAgentLoop();
const { streaming, messages, onSubmit, abort, modelPicker, selectModel, cancelModelSelection } = useAgentLoop();
const { approvalRequest, respondToApproval } = useApprovalManager();
const { askUserQuestionRequest, respondWithAnswers } = useAskUserQuestionManager();
const { latestTodos, todoSnapshots } = useMemo(() => buildTodoViewState(messages), [messages]);
Expand Down Expand Up @@ -57,7 +58,7 @@ export function App({
todoSnapshots={todoSnapshots}
/>
)}
{approvalRequest || askUserQuestionRequest ? null : (
{approvalRequest || askUserQuestionRequest || modelPicker ? null : (
<StreamingIndicator streaming={streaming} nextTodo={nextTodo} />
)}
{!hideTodos && <TodoPanel todos={latestTodos} />}
Expand All @@ -72,6 +73,14 @@ export function App({
questions={askUserQuestionRequest.params.questions}
onSubmit={respondWithAnswers}
/>
) : modelPicker ? (
<ModelSelectionPrompt
models={modelPicker.models}
currentModelName={modelPicker.currentModelName}
defaultModelName={modelPicker.defaultModelName}
onSelect={selectModel}
onCancel={cancelModelSelection}
/>
) : (
<InputBox commands={commands} onSubmit={onSubmit} onAbort={abort} />
)}
Expand Down
5 changes: 5 additions & 0 deletions src/cli/tui/command-registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ export const BUILTIN_COMMANDS: SlashCommand[] = [
description: "List available slash commands, or show details for one (`/help <name>`)",
type: "builtin",
},
{
name: "model",
description: "Choose the model for this TUI session",
type: "builtin",
},
{
name: "quit",
description: "Exit the TUI session",
Expand Down
16 changes: 1 addition & 15 deletions src/cli/tui/components/command-list.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Box, Text } from "ink";

import type { SlashCommand } from "../command-registry";
import { currentTheme } from "../themes";
import { getVisibleWindow } from "../visible-window";

const MAX_VISIBLE_COMMANDS = 5;

Expand Down Expand Up @@ -64,18 +65,3 @@ function summarizeDescription(description: string, maxLength = 72): string {
if (normalized.length <= maxLength) return normalized;
return `${normalized.slice(0, maxLength - 3)}...`;
}

function getVisibleWindow(total: number, selectedIndex: number, maxVisible: number) {
if (total <= maxVisible) {
return { startIndex: 0, endIndex: total };
}

const halfWindow = Math.floor(maxVisible / 2);
const maxStartIndex = total - maxVisible;
const startIndex = Math.max(0, Math.min(selectedIndex - halfWindow, maxStartIndex));

return {
startIndex,
endIndex: startIndex + maxVisible,
};
}
Loading