From d27310d04268d0ae3464cb2db44a3ca1a0911061 Mon Sep 17 00:00:00 2001 From: noorbhatia Date: Fri, 13 Mar 2026 13:30:33 +0530 Subject: [PATCH 1/9] Add additionalContext support to MLXLanguageModel --- .../Models/MLXLanguageModel.swift | 55 +++++++++++++++++-- .../MLXLanguageModelTests.swift | 22 ++++++++ 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 8f23b2f..f405fb5 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -183,6 +183,16 @@ import Foundation /// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit") /// ``` public struct MLXLanguageModel: LanguageModel { + /// Custom generation options for MLX models. + public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions { + /// Additional key-value pairs injected into the chat template rendering context. + public var additionalContext: [String: MLXLMCommon.JSONValue]? + + public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) { + self.additionalContext = additionalContext + } + } + /// The reason the model is unavailable. public enum UnavailableReason: Sendable, Equatable, Hashable { /// The model has not been loaded into memory yet. @@ -813,6 +823,11 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) + // Extract additional context from custom options + let additionalContext: [String: any Sendable]? 
= options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + // Build chat history from full transcript var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) @@ -828,7 +843,8 @@ import Foundation let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: toolSpecs + tools: toolSpecs, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) let resolved = resolveCache( @@ -991,10 +1007,17 @@ import Foundation // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) - let userInput = makeUserInput( - session: session, - fallbackPrompt: prompt.description, - tools: nil + let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + + let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + + let userInput = MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: nil, + additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) let resolved = resolveCache( @@ -1529,10 +1552,16 @@ import Foundation let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) + + let additionalContext: [String: any Sendable]? 
= options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + tools: nil, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1773,4 +1802,18 @@ import Foundation return sampledToken.item(Int.self) } } + extension MLXLMCommon.JSONValue { + /// Recursively converts a `JSONValue` to its primitive Swift equivalent. + func toSendable() -> any Sendable { + switch self { + case .string(let s): return s + case .int(let i): return i + case .double(let d): return d + case .bool(let b): return b + case .null: return NSNull() + case .array(let arr): return arr.map { $0.toSendable() } + case .object(let obj): return obj.mapValues { $0.toSendable() } + } + } + } #endif // MLX diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 937b32e..435d290 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -255,6 +255,28 @@ import Testing #expect([Priority.low, Priority.medium, Priority.high].contains(response.content)) } + @Test func withAdditionalContext() async throws { + let session = LanguageModelSession(model: model) + + var options = GenerationOptions( + temperature: 0.7, + maximumResponseTokens: 32 + ) + options[custom: MLXLanguageModel.self] = .init( + additionalContext: [ + "user_name": .string("Alice"), + "turn_count": .int(3), + "verbose": .bool(true), + ] + ) + + let response = try await session.respond( + to: "Say hello", + options: options + ) + #expect(!response.content.isEmpty) + } + @Test func unavailableForNonexistentModel() async { let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test") await model.removeFromCache() From 
a08700a0bcedc2f70f6ceaa0b8f963de838709aa Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Mar 2026 11:40:44 -0700 Subject: [PATCH 2/9] Fix merge conflict resolution --- .../Models/MLXLanguageModel.swift | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f405fb5..0b611d5 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -183,16 +183,6 @@ import Foundation /// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit") /// ``` public struct MLXLanguageModel: LanguageModel { - /// Custom generation options for MLX models. - public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions { - /// Additional key-value pairs injected into the chat template rendering context. - public var additionalContext: [String: MLXLMCommon.JSONValue]? - - public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) { - self.additionalContext = additionalContext - } - } - /// The reason the model is unavailable. public enum UnavailableReason: Sendable, Equatable, Hashable { /// The model has not been loaded into memory yet. @@ -218,6 +208,12 @@ import Foundation public var kvGroupSize: Int /// Sets the token offset where quantized KV storage starts. public var quantizedKVStart: Int + /// Additional key-value pairs injected into the chat template rendering context. + public var additionalContext: [String: MLXLMCommon.JSONValue]? + + var additionalContextForUserInput: [String: any Sendable]? { + additionalContext?.mapValues { $0.toSendable() } + } /// Creates MLX-specific generation options. /// @@ -228,16 +224,20 @@ import Foundation /// Pass `nil` to disable KV quantization. /// - kvGroupSize: The token group size used for KV quantization. 
/// - quantizedKVStart: The token index where quantized KV storage begins. + /// - additionalContext: Additional key-value pairs injected into the chat + /// template rendering context. public init( maxKVSize: Int? = nil, kvBits: Int? = nil, kvGroupSize: Int = 64, - quantizedKVStart: Int = 0 + quantizedKVStart: Int = 0, + additionalContext: [String: MLXLMCommon.JSONValue]? = nil ) { self.maxKVSize = maxKVSize self.kvBits = kvBits self.kvGroupSize = kvGroupSize self.quantizedKVStart = quantizedKVStart + self.additionalContext = additionalContext } } @@ -824,9 +824,7 @@ import Foundation let generateParameters = toGenerateParameters(options) // Extract additional context from custom options - let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self] - .flatMap { $0.additionalContext } - .map { $0.mapValues { $0.toSendable() } } + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput // Build chat history from full transcript var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) @@ -1009,9 +1007,7 @@ import Foundation let generateParameters = toGenerateParameters(options) let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) - let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self] - .flatMap { $0.additionalContext } - .map { $0.mapValues { $0.toSendable() } } + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput let userInput = MLXLMCommon.UserInput( chat: chat, @@ -1553,9 +1549,7 @@ import Foundation let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) - let additionalContext: [String: any Sendable]? 
= options[custom: MLXLanguageModel.self] - .flatMap { $0.additionalContext } - .map { $0.mapValues { $0.toSendable() } } + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput let userInput = MLXLMCommon.UserInput( chat: chat, From edcc9868a6d3e0e7afc00fdf0802961d3841a165 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 00:30:58 -0700 Subject: [PATCH 3/9] Incorporate feedback from review --- .../Models/MLXLanguageModel.swift | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 0b611d5..e6a8b45 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -209,7 +209,7 @@ import Foundation /// Sets the token offset where quantized KV storage starts. public var quantizedKVStart: Int /// Additional key-value pairs injected into the chat template rendering context. - public var additionalContext: [String: MLXLMCommon.JSONValue]? + public var additionalContext: [String: JSONValue]? var additionalContextForUserInput: [String: any Sendable]? { additionalContext?.mapValues { $0.toSendable() } @@ -231,7 +231,7 @@ import Foundation kvBits: Int? = nil, kvGroupSize: Int = 64, quantizedKVStart: Int = 0, - additionalContext: [String: MLXLMCommon.JSONValue]? = nil + additionalContext: [String: JSONValue]? = nil ) { self.maxKVSize = maxKVSize self.kvBits = kvBits @@ -773,13 +773,23 @@ import Foundation private func makeUserInput( session: LanguageModelSession, fallbackPrompt: String, - tools: [ToolSpec]? + tools: [ToolSpec]?, + additionalContext: [String: any Sendable]? 
= nil ) -> MLXLMCommon.UserInput { let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt) + return makeUserInput(chat: chat, tools: tools, additionalContext: additionalContext) + } + + private func makeUserInput( + chat: [MLXLMCommon.Chat.Message], + tools: [ToolSpec]?, + additionalContext: [String: any Sendable]? = nil + ) -> MLXLMCommon.UserInput { return MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: tools + tools: tools, + additionalContext: additionalContext, ) } @@ -838,11 +848,10 @@ import Foundation // Loop until no more tool calls while true { // Build user input with current chat history and tools - let userInput = MLXLMCommon.UserInput( + let userInput = makeUserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), tools: toolSpecs, - additionalContext: additionalContext, + additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) let resolved = resolveCache( @@ -1005,13 +1014,11 @@ import Foundation // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) - let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) - let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput - let userInput = MLXLMCommon.UserInput( - chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), + let userInput = makeUserInput( + session: session, + fallbackPrompt: prompt.description, tools: nil, additionalContext: additionalContext ) @@ -1551,11 +1558,10 @@ import Foundation let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput - let userInput = MLXLMCommon.UserInput( + let userInput = makeUserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), tools: nil, - additionalContext: additionalContext, + additionalContext: 
additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1796,7 +1802,7 @@ import Foundation return sampledToken.item(Int.self) } } - extension MLXLMCommon.JSONValue { + extension JSONValue { /// Recursively converts a `JSONValue` to its primitive Swift equivalent. func toSendable() -> any Sendable { switch self { From 38ecb34675b995c6a41a4e108d5696e8955fe7c9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 00:51:52 -0700 Subject: [PATCH 4/9] Add userInputProcessing property to MLX custom generation options --- README.md | 13 +++- .../Models/MLXLanguageModel.swift | 63 ++++++++++++++----- 2 files changed, 60 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c54846f..c7b6ff9 100644 --- a/README.md +++ b/README.md @@ -488,7 +488,8 @@ let response = try await session.respond { } ``` -You can tune MLX KV-cache behavior per request with model-specific options: +You can tune MLX request behavior per call with model-specific options, +including KV-cache settings and optional media preprocessing: ```swift var options = GenerationOptions(temperature: 0.7) @@ -496,7 +497,9 @@ options[custom: MLXLanguageModel.self] = .init( maxKVSize: 4096, kvBits: 4, kvGroupSize: 64, - quantizedKVStart: 128 + quantizedKVStart: 128, + // Apply a deterministic preprocessing step for image inputs. + userInputProcessing: .init(resize: .init(width: 512, height: 512)) ) let response = try await session.respond( @@ -505,6 +508,12 @@ let response = try await session.respond( ) ``` +You can specify `userInputProcessing` to enforce a consistent image +preprocessing step +(for example, fixed dimensions for predictable latency, memory usage, and vision behavior). +By default, images are passed through without an explicit resize override +(`resize: nil`), so MLX applies its default media processing behavior. 
+ GPU cache behavior can be configured when creating the model: ```swift diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index e6a8b45..5afed52 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -196,6 +196,30 @@ import Foundation /// Set these values through ``GenerationOptions`` using /// `GenerationOptions[custom: MLXLanguageModel.self]`. public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable { + /// Configures media preprocessing applied before model input. + public struct UserInputProcessing: Codable, Equatable, Sendable { + /// Optional resize target applied to media before tokenization. + public var resize: CGSize? + + /// Creates user-input processing configuration. + /// + /// - Parameter resize: Optional target size for media resizing. + public init(resize: CGSize? = nil) { + self.resize = resize + } + + var mlxValue: MLXLMCommon.UserInput.Processing { + .init(resize: resize) + } + } + /// Processing to apply to user media before input preparation. + public var userInputProcessing: UserInputProcessing? + + var processingForUserInput: MLXLMCommon.UserInput.Processing { + userInputProcessing?.mlxValue + ?? .init(resize: nil) + } + /// Limits how many tokens the KV cache retains. /// /// Set this to `nil` to use the backend default. @@ -208,6 +232,7 @@ import Foundation public var kvGroupSize: Int /// Sets the token offset where quantized KV storage starts. public var quantizedKVStart: Int + /// Additional key-value pairs injected into the chat template rendering context. public var additionalContext: [String: JSONValue]? @@ -226,11 +251,14 @@ import Foundation /// - quantizedKVStart: The token index where quantized KV storage begins. /// - additionalContext: Additional key-value pairs injected into the chat /// template rendering context. 
+ /// - userInputProcessing: Processing to apply to user media before input preparation. + /// Defaults to `nil`, which lets MLX use its default media handling. public init( maxKVSize: Int? = nil, kvBits: Int? = nil, kvGroupSize: Int = 64, quantizedKVStart: Int = 0, + userInputProcessing: UserInputProcessing? = nil, additionalContext: [String: JSONValue]? = nil ) { self.maxKVSize = maxKVSize @@ -238,6 +266,7 @@ import Foundation self.kvGroupSize = kvGroupSize self.quantizedKVStart = quantizedKVStart self.additionalContext = additionalContext + self.userInputProcessing = userInputProcessing } } @@ -770,24 +799,15 @@ import Foundation session.tools.isEmpty ? nil : session.tools.map { convertToolToMLXSpec($0) } } - private func makeUserInput( - session: LanguageModelSession, - fallbackPrompt: String, - tools: [ToolSpec]?, - additionalContext: [String: any Sendable]? = nil - ) -> MLXLMCommon.UserInput { - let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt) - return makeUserInput(chat: chat, tools: tools, additionalContext: additionalContext) - } - private func makeUserInput( chat: [MLXLMCommon.Chat.Message], tools: [ToolSpec]?, + processing: MLXLMCommon.UserInput.Processing = .init(resize: nil), additionalContext: [String: any Sendable]? = nil ) -> MLXLMCommon.UserInput { return MLXLMCommon.UserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), + processing: processing, tools: tools, additionalContext: additionalContext, ) @@ -835,6 +855,9 @@ import Foundation // Extract additional context from custom options let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ?? 
.init(resize: nil) // Build chat history from full transcript var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) @@ -851,6 +874,7 @@ import Foundation let userInput = makeUserInput( chat: chat, tools: toolSpecs, + processing: userInputProcessing, additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1015,11 +1039,18 @@ import Foundation // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ?? .init(resize: nil) + let chat = convertTranscriptToMLXChat( + session: session, + fallbackPrompt: prompt.description + ) let userInput = makeUserInput( - session: session, - fallbackPrompt: prompt.description, + chat: chat, tools: nil, + processing: userInputProcessing, additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1118,7 +1149,7 @@ import Foundation let newCache = context.model.newCache(parameters: params) let userInput = MLXLMCommon.UserInput( chat: [.init(role: .system, content: instructions)], - processing: .init(resize: .init(width: 512, height: 512)), + processing: .init(resize: nil), tools: toolSpecs ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1557,10 +1588,14 @@ import Foundation let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ?? 
.init(resize: nil) let userInput = makeUserInput( chat: chat, tools: nil, + processing: userInputProcessing, additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) From 9091dac3fdd23246eda39554b6f99f39c885f901 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 01:08:40 -0700 Subject: [PATCH 5/9] Group KV cache generation options into struct --- README.md | 16 +-- .../Models/MLXLanguageModel.swift | 114 ++++++++++++------ .../CustomGenerationOptionsTests.swift | 52 +++++--- .../MLXLanguageModelTests.swift | 14 +-- 4 files changed, 124 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index c7b6ff9..49ddc4a 100644 --- a/README.md +++ b/README.md @@ -493,14 +493,16 @@ including KV-cache settings and optional media preprocessing: ```swift var options = GenerationOptions(temperature: 0.7) -options[custom: MLXLanguageModel.self] = .init( - maxKVSize: 4096, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 128, - // Apply a deterministic preprocessing step for image inputs. - userInputProcessing: .init(resize: .init(width: 512, height: 512)) +var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default +mlxOptions.kvCache = .init( + maxSize: 4096, + bits: 4, + groupSize: 64, + quantizedStart: 128 ) +// Apply a deterministic preprocessing step for image inputs. 
+mlxOptions.userInputProcessing = .init(resize: .init(width: 512, height: 512)) +options[custom: MLXLanguageModel.self] = mlxOptions let response = try await session.respond( to: "Summarize this transcript", diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 5afed52..c6ebdfb 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -196,6 +196,59 @@ import Foundation /// Set these values through ``GenerationOptions`` using /// `GenerationOptions[custom: MLXLanguageModel.self]`. public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable { + /// Configures KV-cache behavior for MLX generation. + public struct KVCache: Codable, Equatable, Sendable { + /// Limits how many tokens the KV cache retains. + /// + /// Set this to `nil` to use the backend default. + public var maxSize: Int? + + /// Sets the KV-cache quantization bit width. + /// + /// Set this to `nil` to disable KV quantization. + public var bits: Int? + + /// Sets the token group size used for KV quantization. + public var groupSize: Int + + /// Sets the token offset where quantized KV storage starts. + public var quantizedStart: Int + + /// Default KV-cache options used when none are provided at runtime. + /// By default, the token group size is 64 and the quantized start is 0. + public static var `default`: Self { + .init( + maxSize: nil, + bits: nil, + groupSize: 64, + quantizedStart: 0 + ) + } + + /// Creates KV-cache configuration for MLX generation. + /// + /// - Parameters: + /// - maxSize: The maximum number of tokens to retain in KV cache storage. + /// Pass `nil` to use the backend default. + /// - bits: The KV-cache quantization bit width. + /// Pass `nil` to disable KV quantization. + /// - groupSize: The token group size used for KV quantization. 
+ /// - quantizedStart: The token index where quantized KV storage begins. + public init( + maxSize: Int?, + bits: Int?, + groupSize: Int, + quantizedStart: Int + ) { + self.maxSize = maxSize + self.bits = bits + self.groupSize = groupSize + self.quantizedStart = quantizedStart + } + } + /// KV-cache configuration used for generation. + public var kvCache: KVCache + /// Configures media preprocessing applied before model input. public struct UserInputProcessing: Codable, Equatable, Sendable { /// Optional resize target applied to media before tokenization. @@ -204,7 +257,7 @@ import Foundation /// Creates user-input processing configuration. /// /// - Parameter resize: Optional target size for media resizing. - public init(resize: CGSize? = nil) { + public init(resize: CGSize?) { self.resize = resize } @@ -220,19 +273,6 @@ import Foundation ?? .init(resize: nil) } - /// Limits how many tokens the KV cache retains. - /// - /// Set this to `nil` to use the backend default. - public var maxKVSize: Int? - /// Sets the KV-cache quantization bit width. - /// - /// Set this to `nil` to disable KV quantization. - public var kvBits: Int? - /// Sets the token group size used for KV quantization. - public var kvGroupSize: Int - /// Sets the token offset where quantized KV storage starts. - public var quantizedKVStart: Int - /// Additional key-value pairs injected into the chat template rendering context. public var additionalContext: [String: JSONValue]? @@ -243,31 +283,29 @@ import Foundation /// Creates MLX-specific generation options. /// /// - Parameters: - /// - maxKVSize: The maximum number of tokens to retain in KV cache storage. - /// Pass `nil` to use the backend default. - /// - kvBits: The KV-cache quantization bit width. - /// Pass `nil` to disable KV quantization. - /// - kvGroupSize: The token group size used for KV quantization. - /// - quantizedKVStart: The token index where quantized KV storage begins. 
+ /// - kvCache: KV-cache configuration used for generation. /// - additionalContext: Additional key-value pairs injected into the chat /// template rendering context. /// - userInputProcessing: Processing to apply to user media before input preparation. /// Defaults to `nil`, which lets MLX use its default media handling. public init( - maxKVSize: Int? = nil, - kvBits: Int? = nil, - kvGroupSize: Int = 64, - quantizedKVStart: Int = 0, - userInputProcessing: UserInputProcessing? = nil, - additionalContext: [String: JSONValue]? = nil + kvCache: KVCache, + userInputProcessing: UserInputProcessing?, + additionalContext: [String: JSONValue]? ) { - self.maxKVSize = maxKVSize - self.kvBits = kvBits - self.kvGroupSize = kvGroupSize - self.quantizedKVStart = quantizedKVStart + self.kvCache = kvCache self.additionalContext = additionalContext self.userInputProcessing = userInputProcessing } + + /// Default MLX generation options used when none are provided at runtime. + public static var `default`: Self { + .init( + kvCache: .default, + userInputProcessing: nil, + additionalContext: nil + ) + } } /// Controls GPU buffer-pool limits during active and idle phases. @@ -1173,10 +1211,10 @@ import Foundation let custom = options[custom: MLXLanguageModel.self] return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: custom?.maxKVSize, - kvBits: custom?.kvBits, - kvGroupSize: custom?.kvGroupSize ?? 64, - quantizedKVStart: custom?.quantizedKVStart ?? 0, + maxKVSize: custom?.kvCache.maxSize, + kvBits: custom?.kvCache.bits, + kvGroupSize: custom?.kvCache.groupSize ?? 64, + quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, temperature: Float(options.temperature ?? 
0.6), topP: 1.0, repetitionPenalty: nil, @@ -1189,10 +1227,10 @@ import Foundation let custom = options[custom: MLXLanguageModel.self] return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: custom?.maxKVSize, - kvBits: custom?.kvBits, - kvGroupSize: custom?.kvGroupSize ?? 64, - quantizedKVStart: custom?.quantizedKVStart ?? 0, + maxKVSize: custom?.kvCache.maxSize, + kvBits: custom?.kvCache.bits, + kvGroupSize: custom?.kvCache.groupSize ?? 64, + quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, temperature: Float(options.temperature ?? 0.2), topP: 0.95, repetitionPenalty: 1.1, diff --git a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift index 89539a1..3a4a812 100644 --- a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift +++ b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift @@ -867,38 +867,50 @@ struct GeminiCustomOptionsTests { struct MLXCustomOptionsTests { @Test func initialization() { let options = MLXLanguageModel.CustomGenerationOptions( - maxKVSize: 4096, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 128 + kvCache: .init( + maxSize: 4096, + bits: 4, + groupSize: 64, + quantizedStart: 128 + ), + userInputProcessing: nil, + additionalContext: nil ) - #expect(options.maxKVSize == 4096) - #expect(options.kvBits == 4) - #expect(options.kvGroupSize == 64) - #expect(options.quantizedKVStart == 128) + #expect(options.kvCache.maxSize == 4096) + #expect(options.kvCache.bits == 4) + #expect(options.kvCache.groupSize == 64) + #expect(options.kvCache.quantizedStart == 128) } @Test func integrationWithGenerationOptions() { var options = GenerationOptions(temperature: 0.7) options[custom: MLXLanguageModel.self] = .init( - maxKVSize: 2048, - kvBits: 8, - kvGroupSize: 32, - quantizedKVStart: 256 + kvCache: .init( + maxSize: 2048, + bits: 8, + groupSize: 32, + quantizedStart: 256 + ), + userInputProcessing: nil, + 
additionalContext: nil ) let retrieved = options[custom: MLXLanguageModel.self] - #expect(retrieved?.maxKVSize == 2048) - #expect(retrieved?.kvBits == 8) - #expect(retrieved?.kvGroupSize == 32) - #expect(retrieved?.quantizedKVStart == 256) + #expect(retrieved?.kvCache.maxSize == 2048) + #expect(retrieved?.kvCache.bits == 8) + #expect(retrieved?.kvCache.groupSize == 32) + #expect(retrieved?.kvCache.quantizedStart == 256) } @Test func codable() throws { let options = MLXLanguageModel.CustomGenerationOptions( - maxKVSize: 8192, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 0 + kvCache: .init( + maxSize: 8192, + bits: 4, + groupSize: 64, + quantizedStart: 0 + ), + userInputProcessing: nil, + additionalContext: nil ) let data = try JSONEncoder().encode(options) let decoded = try JSONDecoder().decode(MLXLanguageModel.CustomGenerationOptions.self, from: data) diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 435d290..3837514 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -262,13 +262,13 @@ import Testing temperature: 0.7, maximumResponseTokens: 32 ) - options[custom: MLXLanguageModel.self] = .init( - additionalContext: [ - "user_name": .string("Alice"), - "turn_count": .int(3), - "verbose": .bool(true), - ] - ) + var custom = MLXLanguageModel.CustomGenerationOptions.default + custom.additionalContext = [ + "user_name": .string("Alice"), + "turn_count": .int(3), + "verbose": .bool(true), + ] + options[custom: MLXLanguageModel.self] = custom let response = try await session.respond( to: "Say hello", From 0646e795264c0f814dbb4c1362c8d25db66e96e5 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 01:13:24 -0700 Subject: [PATCH 6/9] Improve ergonomics of resize processor at call site --- README.md | 2 +- Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | 9 ++++++++- 2 files changed, 9 
insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 49ddc4a..f43d3f6 100644 --- a/README.md +++ b/README.md @@ -501,7 +501,7 @@ mlxOptions.kvCache = .init( quantizedStart: 128 ) // Apply a deterministic preprocessing step for image inputs. -mlxOptions.userInputProcessing = .init(resize: .init(width: 512, height: 512)) +mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) options[custom: MLXLanguageModel.self] = mlxOptions let response = try await session.respond( diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index c6ebdfb..5a4cad8 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -257,10 +257,17 @@ import Foundation /// Creates user-input processing configuration. /// /// - Parameter resize: Optional target size for media resizing. - public init(resize: CGSize?) { + init(resize: CGSize?) { self.resize = resize } + /// Creates processing that resizes media to a fixed size. + /// + /// - Parameter size: Target size used for resizing media inputs. 
+ public static func resize(to size: CGSize) -> Self { + .init(resize: size) + } + var mlxValue: MLXLMCommon.UserInput.Processing { .init(resize: resize) } From e90669b5c33a12ce08fd2245ddb5869cf3721ca8 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 01:19:24 -0700 Subject: [PATCH 7/9] Fix compiler errors due to JSONValue ambiguity --- .../AnyLanguageModel/Models/MLXLanguageModel.swift | 14 +++++++------- .../MLXLanguageModelTests.swift | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 5a4cad8..70bf551 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -281,7 +281,7 @@ import Foundation } /// Additional key-value pairs injected into the chat template rendering context. - public var additionalContext: [String: JSONValue]? + public var additionalContext: [String: MLXLMCommon.JSONValue]? var additionalContextForUserInput: [String: any Sendable]? { additionalContext?.mapValues { $0.toSendable() } @@ -298,7 +298,7 @@ import Foundation public init( kvCache: KVCache, userInputProcessing: UserInputProcessing?, - additionalContext: [String: JSONValue]? + additionalContext: [String: MLXLMCommon.JSONValue]? ) { self.kvCache = kvCache self.additionalContext = additionalContext @@ -1591,7 +1591,7 @@ import Foundation { header += ". Expected value: \(constString)" } else if let enumValues = jsonSchema.enum, !enumValues.isEmpty, - let data = try? encoder.encode(JSONValue.array(enumValues)), + let data = try? encoder.encode(enumValues), let enumString = String(data: data, encoding: .utf8) { header += ". Allowed values: \(enumString)" @@ -1637,11 +1637,11 @@ import Foundation options[custom: MLXLanguageModel.self]?.processingForUserInput ?? 
.init(resize: nil) - let userInput = makeUserInput( + let userInput = MLXLMCommon.UserInput( chat: chat, - tools: nil, processing: userInputProcessing, - additionalContext: additionalContext + tools: nil, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1882,7 +1882,7 @@ import Foundation return sampledToken.item(Int.self) } } - extension JSONValue { + extension MLXLMCommon.JSONValue { /// Recursively converts a `JSONValue` to its primitive Swift equivalent. func toSendable() -> any Sendable { switch self { diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 3837514..f82b16f 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -264,9 +264,9 @@ import Testing ) var custom = MLXLanguageModel.CustomGenerationOptions.default custom.additionalContext = [ - "user_name": .string("Alice"), - "turn_count": .int(3), - "verbose": .bool(true), + "user_name": MLXLMCommon.JSONValue.string("Alice"), + "turn_count": MLXLMCommon.JSONValue.int(3), + "verbose": MLXLMCommon.JSONValue.bool(true), ] options[custom: MLXLanguageModel.self] = custom From bea4b35ef717739dd0175bde625977a1a92d7ced Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 01:33:46 -0700 Subject: [PATCH 8/9] Update expectation for MLX image processing test --- .../MLXLanguageModelTests.swift | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index f82b16f..6a9047d 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -4,6 +4,7 @@ import Testing @testable import AnyLanguageModel #if MLX + import MLXLMCommon private let shouldRunMLXTests = { // Enable when 
explicitly requested via environment variable if ProcessInfo.processInfo.environment["ENABLE_MLX_TESTS"] != nil { @@ -154,7 +155,11 @@ import Testing ) ]) let session = LanguageModelSession(model: visionModel, transcript: transcript) - let response = try await session.respond(to: "") + var options = GenerationOptions() + var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default + mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) + options[custom: MLXLanguageModel.self] = mlxOptions + let response = try await session.respond(to: "", options: options) #expect(!response.content.isEmpty) } @@ -168,7 +173,11 @@ import Testing ) ]) let session = LanguageModelSession(model: visionModel, transcript: transcript) - let response = try await session.respond(to: "") + var options = GenerationOptions() + var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default + mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) + options[custom: MLXLanguageModel.self] = mlxOptions + let response = try await session.respond(to: "", options: options) #expect(!response.content.isEmpty) } From b8348915d9d5a159f426aad6da7fb9e4e6344e0d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 24 Mar 2026 01:44:06 -0700 Subject: [PATCH 9/9] Incorporate feedback from review --- README.md | 11 ++++++++++- .../AnyLanguageModel/Models/MLXLanguageModel.swift | 6 +++--- Sources/AnyLanguageModel/Shared/JSONValue.swift | 4 ++++ .../AnyLanguageModelTests/MLXLanguageModelTests.swift | 7 +++---- 4 files changed, 20 insertions(+), 8 deletions(-) create mode 100644 Sources/AnyLanguageModel/Shared/JSONValue.swift diff --git a/README.md b/README.md index f43d3f6..d6f53e9 100644 --- a/README.md +++ b/README.md @@ -502,6 +502,12 @@ mlxOptions.kvCache = .init( ) // Apply a deterministic preprocessing step for image inputs. 
mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) +// Inject extra template context consumed by model-specific chat templates. +mlxOptions.additionalContext = [ + "user_name": .string("Alice"), + "turn_count": .int(3), + "verbose": .bool(true), +] options[custom: MLXLanguageModel.self] = mlxOptions let response = try await session.respond( @@ -511,11 +517,14 @@ let response = try await session.respond( ``` You can specify `userInputProcessing` to enforce a consistent image -preprocessing step +preprocessing step (for example, fixed dimensions for predictable latency, memory usage, and vision behavior). By default, images are passed through without an explicit resize override (`resize: nil`), so MLX applies its default media processing behavior. +You can also set `additionalContext` to provide extra JSON template variables +for model-specific chat templates. + GPU cache behavior can be configured when creating the model: ```swift diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 70bf551..0ef37ef 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -281,7 +281,7 @@ import Foundation } /// Additional key-value pairs injected into the chat template rendering context. - public var additionalContext: [String: MLXLMCommon.JSONValue]? + public var additionalContext: [String: AnyLanguageModel.JSONValue]? var additionalContextForUserInput: [String: any Sendable]? { additionalContext?.mapValues { $0.toSendable() } @@ -298,7 +298,7 @@ import Foundation public init( kvCache: KVCache, userInputProcessing: UserInputProcessing?, - additionalContext: [String: MLXLMCommon.JSONValue]? + additionalContext: [String: AnyLanguageModel.JSONValue]? 
) { self.kvCache = kvCache self.additionalContext = additionalContext @@ -1882,7 +1882,7 @@ import Foundation return sampledToken.item(Int.self) } } - extension MLXLMCommon.JSONValue { + extension AnyLanguageModel.JSONValue { /// Recursively converts a `JSONValue` to its primitive Swift equivalent. func toSendable() -> any Sendable { switch self { diff --git a/Sources/AnyLanguageModel/Shared/JSONValue.swift b/Sources/AnyLanguageModel/Shared/JSONValue.swift new file mode 100644 index 0000000..5d3c315 --- /dev/null +++ b/Sources/AnyLanguageModel/Shared/JSONValue.swift @@ -0,0 +1,4 @@ +import enum JSONSchema.JSONValue + +/// A type-safe representation of JSON values used by AnyLanguageModel APIs. +public typealias JSONValue = JSONSchema.JSONValue diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 6a9047d..bb048d3 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -4,7 +4,6 @@ import Testing @testable import AnyLanguageModel #if MLX - import MLXLMCommon private let shouldRunMLXTests = { // Enable when explicitly requested via environment variable if ProcessInfo.processInfo.environment["ENABLE_MLX_TESTS"] != nil { @@ -273,9 +272,9 @@ import Testing ) var custom = MLXLanguageModel.CustomGenerationOptions.default custom.additionalContext = [ - "user_name": MLXLMCommon.JSONValue.string("Alice"), - "turn_count": MLXLMCommon.JSONValue.int(3), - "verbose": MLXLMCommon.JSONValue.bool(true), + "user_name": JSONValue.string("Alice"), + "turn_count": JSONValue.int(3), + "verbose": JSONValue.bool(true), ] options[custom: MLXLanguageModel.self] = custom