Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 144 additions & 8 deletions OpenCodeClient/OpenCodeClient/AppState.swift
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ final class AppState {
private static let aiBuilderLastOKTestedAtKey = "aiBuilderLastOKTestedAt"
private static let draftInputsBySessionKey = "draftInputsBySession"
private static let selectedModelBySessionKey = "selectedModelBySession"
private static let selectedVariantBySessionKey = "selectedVariantBySession"
private static let showArchivedSessionsKey = "showArchivedSessions"
private static let selectedProjectWorktreeKey = "selectedProjectWorktree"
private static let customProjectPathKey = "customProjectPath"
Expand Down Expand Up @@ -218,13 +219,19 @@ final class AppState {
let decoded = try? JSONDecoder().decode([String: String].self, from: data) {
selectedModelIDBySessionID = decoded
}

if let data = UserDefaults.standard.data(forKey: Self.selectedVariantBySessionKey),
let decoded = try? JSONDecoder().decode([String: String].self, from: data) {
selectedVariantBySessionID = decoded
}
}

// Unsent composer drafts per session.
private var draftInputsBySessionID: [String: String] = [:]

// Selected model (providerID/modelID) per session.
private var selectedModelIDBySessionID: [String: String] = [:]
private var selectedVariantBySessionID: [String: String] = [:]

private func persistSelectedModelMap() {
if selectedModelIDBySessionID.isEmpty {
Expand All @@ -236,6 +243,16 @@ final class AppState {
}
}

private func persistSelectedVariantMap() {
if selectedVariantBySessionID.isEmpty {
UserDefaults.standard.removeObject(forKey: Self.selectedVariantBySessionKey)
return
}
if let data = try? JSONEncoder().encode(selectedVariantBySessionID) {
UserDefaults.standard.set(data, forKey: Self.selectedVariantBySessionKey)
}
}

func draftText(for sessionID: String?) -> String {
guard let sessionID else { return "" }
return draftInputsBySessionID[sessionID] ?? ""
Expand Down Expand Up @@ -556,6 +573,25 @@ final class AppState {
guard modelPresets.indices.contains(selectedModelIndex) else { return nil }
return modelPresets[selectedModelIndex]
}

var selectedModelVariants: [String] {
guard let model = selectedModel else { return [] }
let key = "\(model.providerID)/\(model.modelID)"
return Self.sortedVariants(providerModelsIndex[key]?.variants ?? [])
}

var selectedVariant: String? {
guard let sessionID = currentSessionID else { return nil }
guard let saved = selectedVariantBySessionID[sessionID]?.trimmingCharacters(in: .whitespacesAndNewlines),
!saved.isEmpty else { return nil }
let available = selectedModelVariants
if available.isEmpty { return saved }
return available.contains(saved) ? saved : nil
}

var selectedVariantDisplayName: String {
Self.displayName(forVariant: selectedVariant)
}

var selectedAgent: AgentInfo? {
let visibleAgents = agents.filter { $0.isVisible }
Expand Down Expand Up @@ -660,10 +696,24 @@ final class AppState {
func setSelectedModelIndex(_ index: Int) {
guard modelPresets.indices.contains(index) else { return }
selectedModelIndex = index
normalizeSelectedVariantForCurrentSession()
guard let sessionID = currentSessionID else { return }
selectedModelIDBySessionID[sessionID] = modelPresets[index].id
persistSelectedModelMap()
}

func setSelectedVariant(_ variant: String?) {
guard let sessionID = currentSessionID else { return }
let cleaned = variant?.trimmingCharacters(in: .whitespacesAndNewlines)
if let cleaned, !cleaned.isEmpty {
let available = selectedModelVariants
guard available.isEmpty || available.contains(cleaned) else { return }
selectedVariantBySessionID[sessionID] = cleaned
} else {
selectedVariantBySessionID[sessionID] = nil
}
persistSelectedVariantMap()
}

func setSelectedAgentIndex(_ index: Int) {
let visibleAgents = agents.filter { $0.isVisible }
Expand All @@ -676,20 +726,95 @@ final class AppState {
guard let saved = selectedModelIDBySessionID[sessionID] else { return }
guard let idx = modelPresets.firstIndex(where: { $0.id == saved }) else { return }
selectedModelIndex = idx
normalizeSelectedVariantForCurrentSession()
}

private func syncModelFromMessageHistory() {
guard let sessionID = currentSessionID else { return }

guard let info = messages.reversed().compactMap({ $0.info.resolvedModel }).first else { return }
guard let idx = modelPresets.firstIndex(where: { $0.providerID == info.providerID && $0.modelID == info.modelID }) else {
if let info = messages.reversed().compactMap({ $0.info.resolvedModel }).first,
let idx = modelPresets.firstIndex(where: { $0.providerID == info.providerID && $0.modelID == info.modelID }) {
selectedModelIndex = idx
selectedModelIDBySessionID[sessionID] = modelPresets[idx].id
persistSelectedModelMap()
} else if let info = messages.reversed().compactMap({ $0.info.resolvedModel }).first {
Self.logger.warning("syncModelFromMessageHistory: model \(info.providerID, privacy: .public)/\(info.modelID, privacy: .public) not in presets, keeping current selection")
return
}

selectedModelIndex = idx
selectedModelIDBySessionID[sessionID] = modelPresets[idx].id
persistSelectedModelMap()
if selectedVariantBySessionID[sessionID] == nil,
let variant = messages.reversed()
.compactMap({ $0.info.variant?.trimmingCharacters(in: .whitespacesAndNewlines) })
.first(where: { !$0.isEmpty }) {
selectedVariantBySessionID[sessionID] = variant
persistSelectedVariantMap()
}

normalizeSelectedVariantForCurrentSession()
}

private func normalizeSelectedVariantForCurrentSession() {
guard let sessionID = currentSessionID,
let saved = selectedVariantBySessionID[sessionID]?.trimmingCharacters(in: .whitespacesAndNewlines),
!saved.isEmpty else { return }
let available = selectedModelVariants
if available.isEmpty {
if selectedVariantBySessionID[sessionID] != saved {
selectedVariantBySessionID[sessionID] = saved
persistSelectedVariantMap()
}
return
}
guard available.contains(saved) else {
selectedVariantBySessionID[sessionID] = nil
persistSelectedVariantMap()
return
}
if selectedVariantBySessionID[sessionID] != saved {
selectedVariantBySessionID[sessionID] = saved
persistSelectedVariantMap()
}
}

nonisolated private static func sortedVariants(_ values: [String]) -> [String] {
let order: [String: Int] = [
"none": 0,
"minimal": 1,
"low": 2,
"medium": 3,
"high": 4,
"xhigh": 5,
"max": 6,
]
return values.sorted {
let lhs = order[$0.lowercased()] ?? Int.max
let rhs = order[$1.lowercased()] ?? Int.max
if lhs == rhs { return $0.localizedCaseInsensitiveCompare($1) == .orderedAscending }
return lhs < rhs
}
}

nonisolated static func displayName(forVariant variant: String?) -> String {
guard let variant else { return "Auto" }
switch variant.lowercased() {
case "none":
return "None"
case "minimal":
return "Minimal"
case "low":
return "Low"
case "medium":
return "Medium"
case "high":
return "High"
case "xhigh":
return "Extra High"
case "max":
return "Max"
default:
return variant
.replacingOccurrences(of: "_", with: " ")
.replacingOccurrences(of: "-", with: " ")
.capitalized
}
}

var currentSession: Session? {
Expand Down Expand Up @@ -882,6 +1007,7 @@ final class AppState {
do {
let session = try await apiClient.createSession(title: nil)
guard sessionLoadingID == loadingID else { return }
let variant = selectedVariant

Self.logger.debug("createSession: created id=\(session.id, privacy: .public) directory=\(session.directory, privacy: .public) effectiveProjectDir=\(self.effectiveProjectDirectory ?? "nil", privacy: .public)")

Expand All @@ -891,6 +1017,10 @@ final class AppState {
selectedModelIDBySessionID[session.id] = m.id
persistSelectedModelMap()
}
if let variant {
selectedVariantBySessionID[session.id] = variant
persistSelectedVariantMap()
}
messageStore.resetStreaming()
messages = []
partsByMessage = [:]
Expand Down Expand Up @@ -1243,9 +1373,10 @@ final class AppState {
}
let tempMessageID = appendOptimisticUserMessage(text)
let model = selectedModel.map { Message.ModelInfo(providerID: $0.providerID, modelID: $0.modelID) }
let variant = selectedVariant
let agentName = selectedAgent?.name ?? "build"
do {
try await apiClient.promptAsync(sessionID: sessionID, text: text, agent: agentName, model: model)
try await apiClient.promptAsync(sessionID: sessionID, text: text, agent: agentName, model: model, variant: variant)
return true
} catch {
let recovered = await recoverFromMissingCurrentSessionIfNeeded(error: error, requestedSessionID: sessionID)
Expand All @@ -1269,6 +1400,7 @@ final class AppState {
providerID: nil,
modelID: nil,
model: nil,
variant: nil,
error: nil,
time: Message.TimeInfo(created: now, completed: now),
finish: nil,
Expand Down Expand Up @@ -1690,6 +1822,9 @@ final class AppState {

selectedModelIDBySessionID[sessionID] = nil
persistSelectedModelMap()

selectedVariantBySessionID[sessionID] = nil
persistSelectedVariantMap()
}

private func isSessionNotFoundError(_ error: Error) -> Bool {
Expand Down Expand Up @@ -1778,6 +1913,7 @@ final class AppState {
}
}
providerModelsIndex = idx
normalizeSelectedVariantForCurrentSession()
} catch {
providerConfigError = error.localizedDescription
}
Expand Down
1 change: 1 addition & 0 deletions OpenCodeClient/OpenCodeClient/Models/Message.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct Message: Codable, Identifiable {
let providerID: String?
let modelID: String?
let model: ModelInfo?
let variant: String?
let error: MessageError?
let time: TimeInfo
let finish: String?
Expand Down
14 changes: 10 additions & 4 deletions OpenCodeClient/OpenCodeClient/Services/APIClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,12 @@ actor APIClient {
return try? decoder.decode(type, from: data)
}

func promptAsync(sessionID: String, text: String, agent: String = "build", model: Message.ModelInfo?) async throws {
func promptAsync(sessionID: String, text: String, agent: String = "build", model: Message.ModelInfo?, variant: String?) async throws {
struct PromptBody: Encodable {
let parts: [PartInput]
let agent: String
let model: ModelInput?
let variant: String?
struct PartInput: Encodable {
let type = "text"
let text: String
Expand All @@ -250,7 +251,8 @@ actor APIClient {
let body = PromptBody(
parts: [.init(text: text)],
agent: agent,
model: model.map { .init(providerID: $0.providerID, modelID: $0.modelID) }
model: model.map { .init(providerID: $0.providerID, modelID: $0.modelID) },
variant: variant
)
let bodyData = try JSONEncoder().encode(body)
let (_, response) = try await makeRequest(path: "/session/\(sessionID)/prompt_async", method: "POST", body: bodyData)
Expand Down Expand Up @@ -561,20 +563,23 @@ struct ProviderModel: Decodable {
let name: String?
let providerID: String?
let limit: ProviderModelLimit?
let variants: [String]

private enum CodingKeys: String, CodingKey {
case id
case name
case providerID
case providerId
case limit
case variants
}

init(id: String, name: String?, providerID: String?, limit: ProviderModelLimit?) {
init(id: String, name: String?, providerID: String?, limit: ProviderModelLimit?, variants: [String] = []) {
self.id = id
self.name = name
self.providerID = providerID
self.limit = limit
self.variants = variants
}

init(from decoder: Decoder) throws {
Expand All @@ -583,6 +588,7 @@ struct ProviderModel: Decodable {
name = try? c.decode(String.self, forKey: .name)
providerID = (try? c.decode(String.self, forKey: .providerID)) ?? (try? c.decode(String.self, forKey: .providerId))
limit = try? c.decode(ProviderModelLimit.self, forKey: .limit)
variants = (try? c.decode([String: AnyCodable].self, forKey: .variants).keys.sorted()) ?? []
}
}

Expand Down Expand Up @@ -617,7 +623,7 @@ protocol APIClientProtocol: Actor {
func updateSession(sessionID: String, title: String) async throws -> Session
func deleteSession(sessionID: String) async throws
func messages(sessionID: String, limit: Int?) async throws -> [MessageWithParts]
func promptAsync(sessionID: String, text: String, agent: String, model: Message.ModelInfo?) async throws
func promptAsync(sessionID: String, text: String, agent: String, model: Message.ModelInfo?, variant: String?) async throws
func abort(sessionID: String) async throws
func sessionStatus() async throws -> [String: SessionStatus]
func pendingPermissions() async throws -> [APIClient.PermissionRequest]
Expand Down
1 change: 1 addition & 0 deletions OpenCodeClient/OpenCodeClient/Stores/MessageStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ final class MessageStore {
providerID: nil,
modelID: nil,
model: nil,
variant: nil,
error: nil,
time: Message.TimeInfo(created: now, completed: now),
finish: nil,
Expand Down
44 changes: 44 additions & 0 deletions OpenCodeClient/OpenCodeClient/Views/Chat/ChatToolbarView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ struct ChatToolbarView: View {
private var rightButtons: some View {
HStack(spacing: LayoutConstants.Toolbar.modelButtonSpacing) {
modelMenu
if !state.selectedModelVariants.isEmpty {
effortMenu
}
agentMenu
ContextUsageButton(state: state)

Expand Down Expand Up @@ -133,6 +136,47 @@ struct ChatToolbarView: View {
}
.menuStyle(.borderlessButton)
}

private var effortMenu: some View {
Menu {
Button {
state.setSelectedVariant(nil)
} label: {
HStack {
Text("Auto")
if state.selectedVariant == nil {
Image(systemName: "checkmark")
}
}
}

ForEach(state.selectedModelVariants, id: \.self) { variant in
Button {
state.setSelectedVariant(variant)
} label: {
HStack {
Text(AppState.displayName(forVariant: variant))
if state.selectedVariant == variant {
Image(systemName: "checkmark")
}
}
}
}
} label: {
HStack(spacing: 4) {
Text(state.selectedVariantDisplayName)
.font(.caption.weight(.semibold))
Image(systemName: "chevron.down")
.font(.caption2)
}
.padding(.horizontal, 12)
.padding(.vertical, 7)
.background(Color(.systemGray5))
.foregroundColor(.primary)
.clipShape(Capsule())
}
.menuStyle(.borderlessButton)
}

// MARK: - Agent Selection Menu
private var agentMenu: some View {
Expand Down
Loading