Skip to content

Commit a64c49e

Browse files
committed
feat: add dynamic model effort selection
1 parent 3876d71 commit a64c49e

5 files changed

Lines changed: 201 additions & 12 deletions

File tree

OpenCodeClient/OpenCodeClient/AppState.swift

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ final class AppState {
166166
private static let aiBuilderLastOKTestedAtKey = "aiBuilderLastOKTestedAt"
167167
private static let draftInputsBySessionKey = "draftInputsBySession"
168168
private static let selectedModelBySessionKey = "selectedModelBySession"
169+
private static let selectedVariantBySessionKey = "selectedVariantBySession"
169170
private static let showArchivedSessionsKey = "showArchivedSessions"
170171
private static let selectedProjectWorktreeKey = "selectedProjectWorktree"
171172
private static let customProjectPathKey = "customProjectPath"
@@ -211,13 +212,19 @@ final class AppState {
211212
let decoded = try? JSONDecoder().decode([String: String].self, from: data) {
212213
selectedModelIDBySessionID = decoded
213214
}
215+
216+
if let data = UserDefaults.standard.data(forKey: Self.selectedVariantBySessionKey),
217+
let decoded = try? JSONDecoder().decode([String: String].self, from: data) {
218+
selectedVariantBySessionID = decoded
219+
}
214220
}
215221

216222
// Unsent composer drafts per session.
217223
private var draftInputsBySessionID: [String: String] = [:]
218224

219225
// Selected model (providerID/modelID) per session.
220226
private var selectedModelIDBySessionID: [String: String] = [:]
227+
private var selectedVariantBySessionID: [String: String] = [:]
221228

222229
private func persistSelectedModelMap() {
223230
if selectedModelIDBySessionID.isEmpty {
@@ -229,6 +236,16 @@ final class AppState {
229236
}
230237
}
231238

239+
private func persistSelectedVariantMap() {
240+
if selectedVariantBySessionID.isEmpty {
241+
UserDefaults.standard.removeObject(forKey: Self.selectedVariantBySessionKey)
242+
return
243+
}
244+
if let data = try? JSONEncoder().encode(selectedVariantBySessionID) {
245+
UserDefaults.standard.set(data, forKey: Self.selectedVariantBySessionKey)
246+
}
247+
}
248+
232249
func draftText(for sessionID: String?) -> String {
233250
guard let sessionID else { return "" }
234251
return draftInputsBySessionID[sessionID] ?? ""
@@ -538,6 +555,25 @@ final class AppState {
538555
guard modelPresets.indices.contains(selectedModelIndex) else { return nil }
539556
return modelPresets[selectedModelIndex]
540557
}
558+
559+
var selectedModelVariants: [String] {
560+
guard let model = selectedModel else { return [] }
561+
let key = "\(model.providerID)/\(model.modelID)"
562+
return Self.sortedVariants(providerModelsIndex[key]?.variants ?? [])
563+
}
564+
565+
var selectedVariant: String? {
566+
guard let sessionID = currentSessionID else { return nil }
567+
guard let saved = selectedVariantBySessionID[sessionID]?.trimmingCharacters(in: .whitespacesAndNewlines),
568+
!saved.isEmpty else { return nil }
569+
let available = selectedModelVariants
570+
if available.isEmpty { return saved }
571+
return available.contains(saved) ? saved : nil
572+
}
573+
574+
var selectedVariantDisplayName: String {
575+
Self.displayName(forVariant: selectedVariant)
576+
}
541577

542578
var selectedAgent: AgentInfo? {
543579
let visibleAgents = agents.filter { $0.isVisible }
@@ -627,10 +663,24 @@ final class AppState {
627663
func setSelectedModelIndex(_ index: Int) {
628664
guard modelPresets.indices.contains(index) else { return }
629665
selectedModelIndex = index
666+
normalizeSelectedVariantForCurrentSession()
630667
guard let sessionID = currentSessionID else { return }
631668
selectedModelIDBySessionID[sessionID] = modelPresets[index].id
632669
persistSelectedModelMap()
633670
}
671+
672+
func setSelectedVariant(_ variant: String?) {
673+
guard let sessionID = currentSessionID else { return }
674+
let cleaned = variant?.trimmingCharacters(in: .whitespacesAndNewlines)
675+
if let cleaned, !cleaned.isEmpty {
676+
let available = selectedModelVariants
677+
guard available.isEmpty || available.contains(cleaned) else { return }
678+
selectedVariantBySessionID[sessionID] = cleaned
679+
} else {
680+
selectedVariantBySessionID[sessionID] = nil
681+
}
682+
persistSelectedVariantMap()
683+
}
634684

635685
func setSelectedAgentIndex(_ index: Int) {
636686
let visibleAgents = agents.filter { $0.isVisible }
@@ -643,18 +693,94 @@ final class AppState {
643693
guard let saved = selectedModelIDBySessionID[sessionID] else { return }
644694
guard let idx = modelPresets.firstIndex(where: { $0.id == saved }) else { return }
645695
selectedModelIndex = idx
696+
normalizeSelectedVariantForCurrentSession()
646697
}
647698

648699
private func inferAndStoreModelForCurrentSessionIfMissing() {
649700
guard let sessionID = currentSessionID else { return }
650-
guard selectedModelIDBySessionID[sessionID] == nil else { return }
651-
652-
guard let info = messages.reversed().compactMap({ $0.info.resolvedModel }).first else { return }
653-
guard let idx = modelPresets.firstIndex(where: { $0.providerID == info.providerID && $0.modelID == info.modelID }) else { return }
654-
655-
selectedModelIndex = idx
656-
selectedModelIDBySessionID[sessionID] = modelPresets[idx].id
657-
persistSelectedModelMap()
701+
if selectedModelIDBySessionID[sessionID] == nil,
702+
let info = messages.reversed().compactMap({ $0.info.resolvedModel }).first,
703+
let idx = modelPresets.firstIndex(where: { $0.providerID == info.providerID && $0.modelID == info.modelID }) {
704+
selectedModelIndex = idx
705+
selectedModelIDBySessionID[sessionID] = modelPresets[idx].id
706+
persistSelectedModelMap()
707+
}
708+
709+
if selectedVariantBySessionID[sessionID] == nil,
710+
let variant = messages.reversed()
711+
.compactMap({ $0.info.variant?.trimmingCharacters(in: .whitespacesAndNewlines) })
712+
.first(where: { !$0.isEmpty }) {
713+
selectedVariantBySessionID[sessionID] = variant
714+
persistSelectedVariantMap()
715+
}
716+
717+
normalizeSelectedVariantForCurrentSession()
718+
}
719+
720+
private func normalizeSelectedVariantForCurrentSession() {
721+
guard let sessionID = currentSessionID,
722+
let saved = selectedVariantBySessionID[sessionID]?.trimmingCharacters(in: .whitespacesAndNewlines),
723+
!saved.isEmpty else { return }
724+
let available = selectedModelVariants
725+
if available.isEmpty {
726+
if selectedVariantBySessionID[sessionID] != saved {
727+
selectedVariantBySessionID[sessionID] = saved
728+
persistSelectedVariantMap()
729+
}
730+
return
731+
}
732+
guard available.contains(saved) else {
733+
selectedVariantBySessionID[sessionID] = nil
734+
persistSelectedVariantMap()
735+
return
736+
}
737+
if selectedVariantBySessionID[sessionID] != saved {
738+
selectedVariantBySessionID[sessionID] = saved
739+
persistSelectedVariantMap()
740+
}
741+
}
742+
743+
nonisolated private static func sortedVariants(_ values: [String]) -> [String] {
744+
let order: [String: Int] = [
745+
"none": 0,
746+
"minimal": 1,
747+
"low": 2,
748+
"medium": 3,
749+
"high": 4,
750+
"xhigh": 5,
751+
"max": 6,
752+
]
753+
return values.sorted {
754+
let lhs = order[$0.lowercased()] ?? Int.max
755+
let rhs = order[$1.lowercased()] ?? Int.max
756+
if lhs == rhs { return $0.localizedCaseInsensitiveCompare($1) == .orderedAscending }
757+
return lhs < rhs
758+
}
759+
}
760+
761+
nonisolated static func displayName(forVariant variant: String?) -> String {
762+
guard let variant else { return "Auto" }
763+
switch variant.lowercased() {
764+
case "none":
765+
return "None"
766+
case "minimal":
767+
return "Minimal"
768+
case "low":
769+
return "Low"
770+
case "medium":
771+
return "Medium"
772+
case "high":
773+
return "High"
774+
case "xhigh":
775+
return "Extra High"
776+
case "max":
777+
return "Max"
778+
default:
779+
return variant
780+
.replacingOccurrences(of: "_", with: " ")
781+
.replacingOccurrences(of: "-", with: " ")
782+
.capitalized
783+
}
658784
}
659785

660786
var currentSession: Session? {
@@ -827,6 +953,7 @@ final class AppState {
827953
do {
828954
let session = try await apiClient.createSession()
829955
guard sessionLoadingID == loadingID else { return }
956+
let variant = selectedVariant
830957

831958
Self.logger.debug("createSession: created id=\(session.id, privacy: .public) directory=\(session.directory, privacy: .public) effectiveProjectDir=\(self.effectiveProjectDirectory ?? "nil", privacy: .public)")
832959

@@ -836,6 +963,10 @@ final class AppState {
836963
selectedModelIDBySessionID[session.id] = m.id
837964
persistSelectedModelMap()
838965
}
966+
if let variant {
967+
selectedVariantBySessionID[session.id] = variant
968+
persistSelectedVariantMap()
969+
}
839970
messages = []
840971
partsByMessage = [:]
841972
} catch {
@@ -1150,9 +1281,10 @@ final class AppState {
11501281
}
11511282
let tempMessageID = appendOptimisticUserMessage(text)
11521283
let model = selectedModel.map { Message.ModelInfo(providerID: $0.providerID, modelID: $0.modelID) }
1284+
let variant = selectedVariant
11531285
let agentName = selectedAgent?.name ?? "build"
11541286
do {
1155-
try await apiClient.promptAsync(sessionID: sessionID, text: text, agent: agentName, model: model)
1287+
try await apiClient.promptAsync(sessionID: sessionID, text: text, agent: agentName, model: model, variant: variant)
11561288
return true
11571289
} catch {
11581290
let recovered = await recoverFromMissingCurrentSessionIfNeeded(error: error, requestedSessionID: sessionID)
@@ -1176,6 +1308,7 @@ final class AppState {
11761308
providerID: nil,
11771309
modelID: nil,
11781310
model: nil,
1311+
variant: nil,
11791312
error: nil,
11801313
time: Message.TimeInfo(created: now, completed: now),
11811314
finish: nil,
@@ -1654,6 +1787,7 @@ final class AppState {
16541787
providerID: nil,
16551788
modelID: nil,
16561789
model: nil,
1790+
variant: nil,
16571791
error: nil,
16581792
time: Message.TimeInfo(created: now, completed: now),
16591793
finish: nil,
@@ -1713,6 +1847,9 @@ final class AppState {
17131847

17141848
selectedModelIDBySessionID[sessionID] = nil
17151849
persistSelectedModelMap()
1850+
1851+
selectedVariantBySessionID[sessionID] = nil
1852+
persistSelectedVariantMap()
17161853
}
17171854

17181855
private func isSessionNotFoundError(_ error: Error) -> Bool {
@@ -1801,6 +1938,7 @@ final class AppState {
18011938
}
18021939
}
18031940
providerModelsIndex = idx
1941+
normalizeSelectedVariantForCurrentSession()
18041942
} catch {
18051943
providerConfigError = error.localizedDescription
18061944
}

OpenCodeClient/OpenCodeClient/Models/Message.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct Message: Codable, Identifiable {
1414
let providerID: String?
1515
let modelID: String?
1616
let model: ModelInfo?
17+
let variant: String?
1718
let error: MessageError?
1819
let time: TimeInfo
1920
let finish: String?

OpenCodeClient/OpenCodeClient/Services/APIClient.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,12 @@ actor APIClient {
233233
return try? decoder.decode(type, from: data)
234234
}
235235

236-
func promptAsync(sessionID: String, text: String, agent: String = "build", model: Message.ModelInfo?) async throws {
236+
func promptAsync(sessionID: String, text: String, agent: String = "build", model: Message.ModelInfo?, variant: String?) async throws {
237237
struct PromptBody: Encodable {
238238
let parts: [PartInput]
239239
let agent: String
240240
let model: ModelInput?
241+
let variant: String?
241242
struct PartInput: Encodable {
242243
let type = "text"
243244
let text: String
@@ -250,7 +251,8 @@ actor APIClient {
250251
let body = PromptBody(
251252
parts: [.init(text: text)],
252253
agent: agent,
253-
model: model.map { .init(providerID: $0.providerID, modelID: $0.modelID) }
254+
model: model.map { .init(providerID: $0.providerID, modelID: $0.modelID) },
255+
variant: variant
254256
)
255257
let bodyData = try JSONEncoder().encode(body)
256258
let (_, response) = try await makeRequest(path: "/session/\(sessionID)/prompt_async", method: "POST", body: bodyData)
@@ -547,20 +549,23 @@ struct ProviderModel: Decodable {
547549
let name: String?
548550
let providerID: String?
549551
let limit: ProviderModelLimit?
552+
let variants: [String]
550553

551554
private enum CodingKeys: String, CodingKey {
552555
case id
553556
case name
554557
case providerID
555558
case providerId
556559
case limit
560+
case variants
557561
}
558562

559-
init(id: String, name: String?, providerID: String?, limit: ProviderModelLimit?) {
563+
init(id: String, name: String?, providerID: String?, limit: ProviderModelLimit?, variants: [String] = []) {
560564
self.id = id
561565
self.name = name
562566
self.providerID = providerID
563567
self.limit = limit
568+
self.variants = variants
564569
}
565570

566571
init(from decoder: Decoder) throws {
@@ -569,6 +574,7 @@ struct ProviderModel: Decodable {
569574
name = try? c.decode(String.self, forKey: .name)
570575
providerID = (try? c.decode(String.self, forKey: .providerID)) ?? (try? c.decode(String.self, forKey: .providerId))
571576
limit = try? c.decode(ProviderModelLimit.self, forKey: .limit)
577+
variants = (try? c.decode([String: AnyCodable].self, forKey: .variants).keys.sorted()) ?? []
572578
}
573579
}
574580

OpenCodeClient/OpenCodeClient/Views/Chat/ChatToolbarView.swift

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ struct ChatToolbarView: View {
8686
private var rightButtons: some View {
8787
HStack(spacing: LayoutConstants.Toolbar.modelButtonSpacing) {
8888
modelMenu
89+
if !state.selectedModelVariants.isEmpty {
90+
effortMenu
91+
}
8992
agentMenu
9093
ContextUsageButton(state: state)
9194

@@ -131,6 +134,47 @@ struct ChatToolbarView: View {
131134
}
132135
.menuStyle(.borderlessButton)
133136
}
137+
138+
private var effortMenu: some View {
139+
Menu {
140+
Button {
141+
state.setSelectedVariant(nil)
142+
} label: {
143+
HStack {
144+
Text("Auto")
145+
if state.selectedVariant == nil {
146+
Image(systemName: "checkmark")
147+
}
148+
}
149+
}
150+
151+
ForEach(state.selectedModelVariants, id: \.self) { variant in
152+
Button {
153+
state.setSelectedVariant(variant)
154+
} label: {
155+
HStack {
156+
Text(AppState.displayName(forVariant: variant))
157+
if state.selectedVariant == variant {
158+
Image(systemName: "checkmark")
159+
}
160+
}
161+
}
162+
}
163+
} label: {
164+
HStack(spacing: 4) {
165+
Text(state.selectedVariantDisplayName)
166+
.font(.caption.weight(.semibold))
167+
Image(systemName: "chevron.down")
168+
.font(.caption2)
169+
}
170+
.padding(.horizontal, 12)
171+
.padding(.vertical, 7)
172+
.background(Color(.systemGray5))
173+
.foregroundColor(.primary)
174+
.clipShape(Capsule())
175+
}
176+
.menuStyle(.borderlessButton)
177+
}
134178

135179
// MARK: - Agent Selection Menu
136180
private var agentMenu: some View {
509 KB
Loading

0 commit comments

Comments
 (0)