
Commit 10bc02d

support for LLMBasic (mlx-swift-examples)
- ml-explore/mlx-swift-examples#454
- fixes #27
- move ChatSession integration tests into new test target so we can more easily control when it runs
- make a ChatSession _unit_ (more or less) test
- fix Sendable / thread safety issues uncovered by LLMBasic
1 parent d9f46e3 commit 10bc02d
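
The Sendable / thread-safety fixes mentioned above are not among the hunks shown below. As a rough, generic illustration of the kind of change that phrase usually implies (all names here are hypothetical, not types from this commit), shared mutable state can be put behind a lock so the type can safely be marked Sendable:

import Foundation

// Hypothetical example only: a class with mutable state that is touched from
// multiple threads is not Sendable as-is. Guarding the state with a lock and
// declaring @unchecked Sendable is one common fix; converting to an actor is another.
final class GenerationState: @unchecked Sendable {
    private let lock = NSLock()
    private var tokens: [Int] = []

    func append(_ token: Int) {
        lock.lock()
        defer { lock.unlock() }
        tokens.append(token)
    }

    func snapshot() -> [Int] {
        lock.lock()
        defer { lock.unlock() }
        return tokens
    }
}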

15 files changed: +451 −211 lines changed


Libraries/MLXLLM/Models/Gemma3Text.swift

Lines changed: 66 additions & 48 deletions
@@ -23,12 +23,40 @@ public struct Gemma3TextConfiguration: Codable {
     let rmsNormEps: Float
     let vocabularySize: Int
     let kvHeads: Int
-    let ropeGlobalBaseFreq: Float
+    let ropeTheta: Float
     let ropeLocalBaseFreq: Float
     let ropeTraditional: Bool
     let queryPreAttnScalar: Float
     let slidingWindow: Int
     let slidingWindowPattern: Int
+    let maxPositionEmbeddings: Int
+    let ropeScaling: [String: StringOrNumber]?
+
+    public init(
+        modelType: String, hiddenSize: Int, hiddenLayers: Int, intermediateSize: Int,
+        attentionHeads: Int, headDim: Int, rmsNormEps: Float, vocabularySize: Int, kvHeads: Int,
+        ropeTheta: Float, ropeLocalBaseFreq: Float, ropeTraditional: Bool,
+        queryPreAttnScalar: Float, slidingWindow: Int, slidingWindowPattern: Int,
+        maxPositionEmbeddings: Int, ropeScaling: [String: StringOrNumber]? = nil
+    ) {
+        self.modelType = modelType
+        self.hiddenSize = hiddenSize
+        self.hiddenLayers = hiddenLayers
+        self.intermediateSize = intermediateSize
+        self.attentionHeads = attentionHeads
+        self.headDim = headDim
+        self.rmsNormEps = rmsNormEps
+        self.vocabularySize = vocabularySize
+        self.kvHeads = kvHeads
+        self.ropeTheta = ropeTheta
+        self.ropeLocalBaseFreq = ropeLocalBaseFreq
+        self.ropeTraditional = ropeTraditional
+        self.queryPreAttnScalar = queryPreAttnScalar
+        self.slidingWindow = slidingWindow
+        self.slidingWindowPattern = slidingWindowPattern
+        self.maxPositionEmbeddings = maxPositionEmbeddings
+        self.ropeScaling = ropeScaling
+    }

     enum CodingKeys: String, CodingKey {
         case modelType = "model_type"
@@ -40,12 +68,14 @@ public struct Gemma3TextConfiguration: Codable {
         case rmsNormEps = "rms_norm_eps"
         case vocabularySize = "vocab_size"
         case kvHeads = "num_key_value_heads"
-        case ropeGlobalBaseFreq = "rope_global_base_freq"
+        case ropeTheta = "rope_theta"
         case ropeLocalBaseFreq = "rope_local_base_freq"
         case ropeTraditional = "rope_traditional"
         case queryPreAttnScalar = "query_pre_attn_scalar"
         case slidingWindow = "sliding_window"
         case slidingWindowPattern = "sliding_window_pattern"
+        case maxPositionEmbeddings = "max_position_embeddings"
+        case ropeScaling = "rope_scaling"
     }

     enum VLMCodingKeys: String, CodingKey {
@@ -65,16 +95,17 @@ public struct Gemma3TextConfiguration: Codable {
         }

         modelType = try container.decode(String.self, forKey: .modelType)
-        hiddenSize = try container.decode(Int.self, forKey: .hiddenSize)
-        hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers)
-        intermediateSize = try container.decode(Int.self, forKey: .intermediateSize)
+        hiddenSize = try container.decodeIfPresent(Int.self, forKey: .hiddenSize) ?? 1152
+        hiddenLayers = try container.decodeIfPresent(Int.self, forKey: .hiddenLayers) ?? 26
+        intermediateSize =
+            try container.decodeIfPresent(Int.self, forKey: .intermediateSize) ?? 6912
         attentionHeads = try container.decodeIfPresent(Int.self, forKey: .attentionHeads) ?? 4
         headDim = try container.decodeIfPresent(Int.self, forKey: .headDim) ?? 256
         rmsNormEps = try container.decodeIfPresent(Float.self, forKey: .rmsNormEps) ?? 1.0e-6
         vocabularySize = try container.decodeIfPresent(Int.self, forKey: .vocabularySize) ?? 262144
         kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? 1
-        ropeGlobalBaseFreq =
-            try container.decodeIfPresent(Float.self, forKey: .ropeGlobalBaseFreq) ?? 1_000_000.0
+        ropeTheta =
+            try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 1_000_000.0
         ropeLocalBaseFreq =
             try container.decodeIfPresent(Float.self, forKey: .ropeLocalBaseFreq) ?? 10_000.0
         ropeTraditional =
@@ -84,6 +115,10 @@ public struct Gemma3TextConfiguration: Codable {
         slidingWindow = try container.decodeIfPresent(Int.self, forKey: .slidingWindow) ?? 512
         slidingWindowPattern =
             try container.decodeIfPresent(Int.self, forKey: .slidingWindowPattern) ?? 6
+        maxPositionEmbeddings =
+            try container.decodeIfPresent(Int.self, forKey: .maxPositionEmbeddings) ?? 32768
+        ropeScaling =
+            try container.decodeIfPresent([String: StringOrNumber].self, forKey: .ropeScaling)
     }
 }

@@ -106,7 +141,7 @@ private class Attention: Module {
     @ModuleInfo(key: "q_norm") var queryNorm: Gemma.RMSNorm
     @ModuleInfo(key: "k_norm") var keyNorm: Gemma.RMSNorm

-    @ModuleInfo var rope: RoPE
+    @ModuleInfo var rope: OffsetLayer

     init(_ config: Gemma3TextConfiguration, layerIdx: Int) {
         let dim = config.hiddenSize
@@ -131,12 +166,16 @@ private class Attention: Module {

         self.isSliding = (layerIdx + 1) % config.slidingWindowPattern != 0

-        let baseFreq = isSliding ? config.ropeLocalBaseFreq : config.ropeGlobalBaseFreq
-        self._rope.wrappedValue = RoPE(
-            dimensions: headDim,
-            traditional: config.ropeTraditional,
-            base: baseFreq
-        )
+        if isSliding {
+            self.rope = initializeRope(
+                dims: headDim, base: config.ropeLocalBaseFreq, traditional: false,
+                scalingConfig: nil, maxPositionEmbeddings: nil)
+        } else {
+            self.rope = initializeRope(
+                dims: headDim, base: config.ropeTheta, traditional: false,
+                scalingConfig: config.ropeScaling,
+                maxPositionEmbeddings: config.maxPositionEmbeddings)
+        }

         super.init()
     }
@@ -163,18 +202,8 @@ private class Attention: Module {
             queries = rope(queries, offset: cache.offset)
             keys = rope(keys, offset: cache.offset)
         } else {
-            queries = rope(queries)
-            keys = rope(keys)
-        }
-
-        // Sliding window masking
-        var finalMask = mask
-        if case .array(let maskArray) = mask {
-            let keySeqLen = keys.shape[2]
-            if maskArray.shape.last! != keySeqLen {
-                let slicedMask = maskArray[.ellipsis, (-keySeqLen)...]
-                finalMask = .array(slicedMask)
-            }
+            queries = rope(queries, offset: 0)
+            keys = rope(keys, offset: 0)
         }

         let output = attentionWithCacheUpdate(
@@ -183,7 +212,7 @@ private class Attention: Module {
             values: values,
             cache: cache,
             scale: scale,
-            mask: finalMask
+            mask: mask
         )
         .transposed(0, 2, 1, 3)
         .reshaped(B, L, -1)
@@ -295,30 +324,19 @@ private class Gemma3Model: Module {
         if layerCache == nil {
            layerCache = Array(repeating: nil as KVCache?, count: layers.count)
         }
-        // Create attention masks
-        var fullMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
-        var slidingWindowMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
-        if mask == nil {
-            let j = config.slidingWindowPattern
-            let globalCache: KVCache? =
-                (j > 0 && j <= (layerCache?.count ?? 0)) ? layerCache?[j - 1] : nil
-            fullMask = createAttentionMask(h: h, cache: globalCache)
-            let slidingCache: KVCache? = layerCache?.first ?? nil
-            slidingWindowMask = createAttentionMask(
-                h: h, cache: slidingCache, windowSize: config.slidingWindow)
-        }
-        for (i, layer) in layers.enumerated() {
-            let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)

-            let localMask: MLXFast.ScaledDotProductAttentionMaskMode
-            if let mask {
-                localMask = mask
-            } else if isGlobal {
-                localMask = fullMask
+        let globalMask = createAttentionMask(h: h, cache: cache?[config.slidingWindowPattern - 1])
+        let slidingWindowMask =
+            if config.slidingWindowPattern > 1 {
+                createAttentionMask(h: h, cache: cache?[0], windowSize: config.slidingWindow)
            } else {
-                localMask = slidingWindowMask
+                MLXFast.ScaledDotProductAttentionMaskMode.none
            }
-            h = layer(h, mask: localMask, cache: layerCache?[i])
+
+        for (i, layer) in layers.enumerated() {
+            let isGlobal = (i % config.slidingWindowPattern == config.slidingWindowPattern - 1)
+            let mask = isGlobal ? globalMask : slidingWindowMask
+            h = layer(h, mask: mask, cache: layerCache?[i])
         }
         return norm(h)
     }

Libraries/MLXLLM/Models/Mistral3Text.swift

Lines changed: 3 additions & 16 deletions
@@ -42,7 +42,7 @@ private class Attention: Module {
     @ModuleInfo(key: "v_proj") var wv: Linear
     @ModuleInfo(key: "o_proj") var wo: Linear

-    let rope: Module
+    let rope: OffsetLayer

     init(_ args: Mistral3TextConfiguration) {
         self.args = args
@@ -76,19 +76,6 @@ private class Attention: Module {
         super.init()
     }

-    private func applyRoPE(_ x: MLXArray, offset: Int) -> MLXArray {
-        if let ropeModule = rope as? RoPE {
-            return ropeModule(x, offset: offset)
-        } else if let llama3Rope = rope as? Llama3RoPE {
-            return llama3Rope(x, offset: offset)
-        } else if let yarnRope = rope as? YarnRoPE {
-            return yarnRope(x, offset: offset)
-        } else if let suScaledRope = rope as? SuScaledRoPE {
-            return suScaledRope(x, offset: offset)
-        }
-        return x
-    }
-
     func callAsFunction(
         _ x: MLXArray, attnScale: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode,
         cache: KVCache?
@@ -106,8 +93,8 @@ private class Attention: Module {

         // Apply RoPE
         let offset = cache?.offset ?? 0
-        queries = applyRoPE(queries, offset: offset)
-        keys = applyRoPE(keys, offset: offset)
+        queries = rope(queries, offset: offset)
+        keys = rope(keys, offset: offset)

         // Apply attention scaling
         queries = queries * attnScale
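
Here (and in Olmo3.swift below) the per-model applyRoPE helper, which dynamically cast a Module-typed rope to RoPE / Llama3RoPE / YarnRoPE / SuScaledRoPE, is replaced by a protocol-typed property. A minimal sketch of how such an abstraction can work, using a hypothetical PositionalLayer protocol as a stand-in (the real OffsetLayer and initializeRope live in MLXLMCommon and may be declared differently):

import MLX
import MLXNN

// Hypothetical stand-in for MLXLMCommon's OffsetLayer.
protocol PositionalLayer {
    func callAsFunction(_ x: MLXArray, offset: Int) -> MLXArray
}

// MLXNN's RoPE already exposes callAsFunction(_:offset:), so conformance is
// just a declaration; scaled variants could conform the same way.
extension RoPE: PositionalLayer {}

final class AttentionSketch {
    // Protocol-typed storage: call sites no longer need `rope as? RoPE`
    // casting chains, and a new RoPE variant only needs to adopt the protocol.
    let rope: any PositionalLayer

    init(headDim: Int, base: Float) {
        self.rope = RoPE(dimensions: headDim, traditional: false, base: base)
    }

    func rotate(_ queries: MLXArray, _ keys: MLXArray, offset: Int) -> (MLXArray, MLXArray) {
        (rope(queries, offset: offset), rope(keys, offset: offset))
    }
}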

Libraries/MLXLLM/Models/Olmo3.swift

Lines changed: 5 additions & 16 deletions
@@ -29,7 +29,7 @@ private class Attention: Module {
     @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
     @ModuleInfo(key: "k_norm") var kNorm: RMSNorm

-    let rope: Module
+    let rope: OffsetLayer

     init(_ args: Olmo3Configuration, layerIdx: Int) {
         self.args = args
@@ -65,17 +65,6 @@ private class Attention: Module {
         super.init()
     }

-    private func applyRoPE(_ x: MLXArray, offset: Int?) -> MLXArray {
-        if let llama3Rope = rope as? Llama3RoPE {
-            return llama3Rope(x, offset: offset ?? 0)
-        } else if let yarnRope = rope as? YarnRoPE {
-            return yarnRope(x, offset: offset ?? 0)
-        } else if let basicRope = rope as? RoPE {
-            return basicRope(x, offset: offset ?? 0)
-        }
-        return x
-    }
-
     func callAsFunction(
         _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
     ) -> MLXArray {
@@ -90,11 +79,11 @@ private class Attention: Module {
         values = values.reshaped(B, L, nKVHeads, -1).transposed(0, 2, 1, 3)

         if let cache {
-            queries = applyRoPE(queries, offset: cache.offset)
-            keys = applyRoPE(keys, offset: cache.offset)
+            queries = rope(queries, offset: cache.offset)
+            keys = rope(keys, offset: cache.offset)
         } else {
-            queries = applyRoPE(queries, offset: nil)
-            keys = applyRoPE(keys, offset: nil)
+            queries = rope(queries, offset: 0)
+            keys = rope(keys, offset: 0)
         }

         let output = attentionWithCacheUpdate(

Libraries/MLXLMCommon/AttentionUtils.swift

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ public func attentionWithCacheUpdate(
         )
     } else {
         let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
+        // TODO dkoski
+        // print("\(cachedKeys.shape) \(cachedValues.shape) \(queries.shape), \(mask.masks?[0].shape ?? [])")
         return MLXFast.scaledDotProductAttention(
             queries: queries,
             keys: cachedKeys,
