Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,23 +488,43 @@ let response = try await session.respond {
}
```

You can tune MLX KV-cache behavior per request with model-specific options:
You can tune MLX request behavior per call with model-specific options,
including KV-cache settings and optional media preprocessing:

```swift
var options = GenerationOptions(temperature: 0.7)
options[custom: MLXLanguageModel.self] = .init(
maxKVSize: 4096,
kvBits: 4,
kvGroupSize: 64,
quantizedKVStart: 128
var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default
mlxOptions.kvCache = .init(
maxSize: 4096,
bits: 4,
groupSize: 64,
quantizedStart: 128
)
// Apply a deterministic preprocessing step for image inputs.
mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512))
// Inject extra template context consumed by model-specific chat templates.
mlxOptions.additionalContext = [
"user_name": .string("Alice"),
"turn_count": .int(3),
"verbose": .bool(true),
]
options[custom: MLXLanguageModel.self] = mlxOptions

let response = try await session.respond(
to: "Summarize this transcript",
options: options
)
```

You can specify `userInputProcessing` to enforce a consistent image preprocessing
step — for example, fixed dimensions for predictable latency, memory usage, and
vision behavior.
By default, images are passed through without an explicit resize override
(`resize: nil`), so MLX applies its default media processing behavior.

You can also set `additionalContext` to provide extra JSON template variables
for model-specific chat templates.

GPU cache behavior can be configured when creating the model:

```swift
Expand Down
223 changes: 173 additions & 50 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,122 @@ import Foundation
/// Set these values through ``GenerationOptions`` using
/// `GenerationOptions[custom: MLXLanguageModel.self]`.
public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable {
/// Limits how many tokens the KV cache retains.
///
/// Set this to `nil` to use the backend default.
public var maxKVSize: Int?
/// Sets the KV-cache quantization bit width.
///
/// Set this to `nil` to disable KV quantization.
public var kvBits: Int?
/// Sets the token group size used for KV quantization.
public var kvGroupSize: Int
/// Sets the token offset where quantized KV storage starts.
public var quantizedKVStart: Int
/// Configures KV-cache behavior for MLX generation.
public struct KVCache: Codable, Equatable, Sendable {
    /// Limits how many tokens the KV cache retains.
    ///
    /// Set this to `nil` to use the backend default.
    public var maxSize: Int?

    /// Sets the KV-cache quantization bit width.
    ///
    /// Set this to `nil` to disable KV quantization.
    public var bits: Int?

    /// Sets the token group size used for KV quantization.
    public var groupSize: Int

    /// Sets the token offset where quantized KV storage starts.
    public var quantizedStart: Int

    /// Default KV-cache options used when none are provided at runtime.
    /// By default, the token group size is 64 and the quantized start is 0.
    public static var `default`: Self {
        // With defaulted initializer parameters, the canonical defaults live
        // in exactly one place: the init signature below.
        .init()
    }

    /// Creates KV-cache configuration for MLX generation.
    ///
    /// - Parameters:
    ///   - maxSize: The maximum number of tokens to retain in KV cache storage.
    ///     Pass `nil` (the default) to use the backend default.
    ///   - bits: The KV-cache quantization bit width.
    ///     Pass `nil` (the default) to disable KV quantization.
    ///   - groupSize: The token group size used for KV quantization.
    ///     Defaults to 64, matching ``default``.
    ///   - quantizedStart: The token index where quantized KV storage begins.
    ///     Defaults to 0, matching ``default``.
    public init(
        maxSize: Int? = nil,
        bits: Int? = nil,
        groupSize: Int = 64,
        quantizedStart: Int = 0
    ) {
        self.maxSize = maxSize
        self.bits = bits
        self.groupSize = groupSize
        self.quantizedStart = quantizedStart
    }
}
/// KV-cache configuration used for generation.
public var kvCache: KVCache

/// Configures media preprocessing applied before model input.
public struct UserInputProcessing: Codable, Equatable, Sendable {
    /// Optional resize target applied to media before tokenization.
    ///
    /// When `nil`, no explicit resize override is passed through `mlxValue`,
    /// so MLX applies its default media handling.
    public var resize: CGSize?

    /// Creates user-input processing configuration.
    ///
    /// - Parameter resize: Optional target size for media resizing.
    ///
    /// NOTE(review): this initializer is not `public`, so external callers can
    /// only construct values via the ``resize(to:)`` factory — confirm that
    /// restricting the API surface this way is intentional.
    init(resize: CGSize?) {
        self.resize = resize
    }

    /// Creates processing that resizes media to a fixed size.
    ///
    /// - Parameter size: Target size used for resizing media inputs.
    public static func resize(to size: CGSize) -> Self {
        .init(resize: size)
    }

    /// Bridges this configuration to the MLX-native processing type consumed
    /// by `MLXLMCommon.UserInput`.
    var mlxValue: MLXLMCommon.UserInput.Processing {
        .init(resize: resize)
    }
}
/// Processing to apply to user media before input preparation.
public var userInputProcessing: UserInputProcessing?

var processingForUserInput: MLXLMCommon.UserInput.Processing {
userInputProcessing?.mlxValue
?? .init(resize: nil)
}

/// Additional key-value pairs injected into the chat template rendering context.
public var additionalContext: [String: AnyLanguageModel.JSONValue]?

var additionalContextForUserInput: [String: any Sendable]? {
additionalContext?.mapValues { $0.toSendable() }
}

/// Creates MLX-specific generation options.
///
/// - Parameters:
/// - maxKVSize: The maximum number of tokens to retain in KV cache storage.
/// Pass `nil` to use the backend default.
/// - kvBits: The KV-cache quantization bit width.
/// Pass `nil` to disable KV quantization.
/// - kvGroupSize: The token group size used for KV quantization.
/// - quantizedKVStart: The token index where quantized KV storage begins.
/// - kvCache: KV-cache configuration used for generation.
/// - additionalContext: Additional key-value pairs injected into the chat
/// template rendering context.
/// - userInputProcessing: Processing to apply to user media before input preparation.
/// Defaults to `nil`, which lets MLX use its default media handling.
public init(
maxKVSize: Int? = nil,
kvBits: Int? = nil,
kvGroupSize: Int = 64,
quantizedKVStart: Int = 0
kvCache: KVCache,
userInputProcessing: UserInputProcessing?,
additionalContext: [String: AnyLanguageModel.JSONValue]?
) {
self.maxKVSize = maxKVSize
self.kvBits = kvBits
self.kvGroupSize = kvGroupSize
self.quantizedKVStart = quantizedKVStart
self.kvCache = kvCache
self.additionalContext = additionalContext
self.userInputProcessing = userInputProcessing
}

/// Default MLX generation options used when none are provided at runtime.
public static var `default`: Self {
.init(
kvCache: .default,
userInputProcessing: nil,
additionalContext: nil
)
}
}

Expand Down Expand Up @@ -761,15 +845,16 @@ import Foundation
}

private func makeUserInput(
session: LanguageModelSession,
fallbackPrompt: String,
tools: [ToolSpec]?
chat: [MLXLMCommon.Chat.Message],
tools: [ToolSpec]?,
processing: MLXLMCommon.UserInput.Processing = .init(resize: nil),
additionalContext: [String: any Sendable]? = nil
) -> MLXLMCommon.UserInput {
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt)
return MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: tools
processing: processing,
tools: tools,
additionalContext: additionalContext,
)
}

Expand Down Expand Up @@ -813,6 +898,12 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Extract additional context from custom options
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
let userInputProcessing =
options[custom: MLXLanguageModel.self]?.processingForUserInput
?? .init(resize: nil)

// Build chat history from full transcript
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

Expand All @@ -825,10 +916,11 @@ import Foundation
// Loop until no more tool calls
while true {
// Build user input with current chat history and tools
let userInput = MLXLMCommon.UserInput(
let userInput = makeUserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: toolSpecs
tools: toolSpecs,
processing: userInputProcessing,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -991,10 +1083,20 @@ import Foundation

// Build chat inside task to avoid Sendable issues
let generateParameters = toGenerateParameters(options)
let userInput = makeUserInput(
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
let userInputProcessing =
options[custom: MLXLanguageModel.self]?.processingForUserInput
?? .init(resize: nil)
let chat = convertTranscriptToMLXChat(
session: session,
fallbackPrompt: prompt.description,
tools: nil
fallbackPrompt: prompt.description
)

let userInput = makeUserInput(
chat: chat,
tools: nil,
processing: userInputProcessing,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -1092,7 +1194,7 @@ import Foundation
let newCache = context.model.newCache(parameters: params)
let userInput = MLXLMCommon.UserInput(
chat: [.init(role: .system, content: instructions)],
processing: .init(resize: .init(width: 512, height: 512)),
processing: .init(resize: nil),
tools: toolSpecs
)
let lmInput = try await context.processor.prepare(input: userInput)
Expand All @@ -1116,10 +1218,10 @@ import Foundation
let custom = options[custom: MLXLanguageModel.self]
return MLXLMCommon.GenerateParameters(
maxTokens: options.maximumResponseTokens,
maxKVSize: custom?.maxKVSize,
kvBits: custom?.kvBits,
kvGroupSize: custom?.kvGroupSize ?? 64,
quantizedKVStart: custom?.quantizedKVStart ?? 0,
maxKVSize: custom?.kvCache.maxSize,
kvBits: custom?.kvCache.bits,
kvGroupSize: custom?.kvCache.groupSize ?? 64,
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
temperature: Float(options.temperature ?? 0.6),
topP: 1.0,
repetitionPenalty: nil,
Expand All @@ -1132,10 +1234,10 @@ import Foundation
let custom = options[custom: MLXLanguageModel.self]
return MLXLMCommon.GenerateParameters(
maxTokens: options.maximumResponseTokens,
maxKVSize: custom?.maxKVSize,
kvBits: custom?.kvBits,
kvGroupSize: custom?.kvGroupSize ?? 64,
quantizedKVStart: custom?.quantizedKVStart ?? 0,
maxKVSize: custom?.kvCache.maxSize,
kvBits: custom?.kvCache.bits,
kvGroupSize: custom?.kvCache.groupSize ?? 64,
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
temperature: Float(options.temperature ?? 0.2),
topP: 0.95,
repetitionPenalty: 1.1,
Expand Down Expand Up @@ -1489,7 +1591,7 @@ import Foundation
{
header += ". Expected value: \(constString)"
} else if let enumValues = jsonSchema.enum, !enumValues.isEmpty,
let data = try? encoder.encode(JSONValue.array(enumValues)),
let data = try? encoder.encode(enumValues),
let enumString = String(data: data, encoding: .utf8)
{
header += ". Allowed values: \(enumString)"
Expand Down Expand Up @@ -1529,10 +1631,17 @@ import Foundation
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)

let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput
let userInputProcessing =
options[custom: MLXLanguageModel.self]?.processingForUserInput
?? .init(resize: nil)

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
processing: userInputProcessing,
tools: nil,
additionalContext: additionalContext,
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -1773,4 +1882,18 @@ import Foundation
return sampledToken.item(Int.self)
}
}
extension AnyLanguageModel.JSONValue {
    /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
    ///
    /// Scalars map to their underlying values, `null` maps to `NSNull()`, and
    /// containers are converted element-by-element so the result is a plain
    /// `Sendable` tree suitable for chat-template context injection.
    func toSendable() -> any Sendable {
        switch self {
        case .null:
            return NSNull()
        case .bool(let flag):
            return flag
        case .int(let integer):
            return integer
        case .double(let number):
            return number
        case .string(let text):
            return text
        case .array(let elements):
            return elements.map { $0.toSendable() }
        case .object(let fields):
            return fields.mapValues { $0.toSendable() }
        }
    }
}
#endif // MLX
4 changes: 4 additions & 0 deletions Sources/AnyLanguageModel/Shared/JSONValue.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import enum JSONSchema.JSONValue

/// A type-safe representation of JSON values used by AnyLanguageModel APIs.
public typealias JSONValue = JSONSchema.JSONValue
Loading