@@ -23,12 +23,40 @@ public struct Gemma3TextConfiguration: Codable {
2323 let rmsNormEps : Float
2424 let vocabularySize : Int
2525 let kvHeads : Int
26- let ropeGlobalBaseFreq : Float
26+ let ropeTheta : Float
2727 let ropeLocalBaseFreq : Float
2828 let ropeTraditional : Bool
2929 let queryPreAttnScalar : Float
3030 let slidingWindow : Int
3131 let slidingWindowPattern : Int
32+ let maxPositionEmbeddings : Int
33+ let ropeScaling : [ String : StringOrNumber ] ?
34+
/// Creates a configuration by assigning each argument to the stored property of the same name.
///
/// No validation or derivation is performed here. Unlike the decoding initializer of this
/// type, which substitutes fallback values for most missing keys, every argument is required
/// except `ropeScaling`, which defaults to `nil`.
///
/// Only parameters whose use is visible elsewhere in this file are described below; the rest
/// are stored as given.
///
/// - Parameters:
///   - ropeTheta: RoPE base passed to `initializeRope` by `Attention` for layers that are
///     not sliding-window layers.
///   - ropeLocalBaseFreq: RoPE base passed to `initializeRope` by `Attention` for
///     sliding-window layers.
///   - ropeTraditional: Stored as given.
///     NOTE(review): the `Attention` initializer passes the literal `traditional: false` to
///     `initializeRope` in both branches instead of reading this value; confirm that is
///     intentional.
///   - slidingWindow: Passed as `windowSize` when the sliding-window attention mask is created.
///   - slidingWindowPattern: Used as a modulus to decide which layers are global
///     (`i % pattern == pattern - 1`) and which are sliding, so it must be non-zero.
///   - maxPositionEmbeddings: Forwarded to `initializeRope` for non-sliding layers.
///   - ropeScaling: Forwarded to `initializeRope` as `scalingConfig` for non-sliding layers;
///     sliding layers always pass `nil`.
35+ public init (
36+ modelType: String , hiddenSize: Int , hiddenLayers: Int , intermediateSize: Int ,
37+ attentionHeads: Int , headDim: Int , rmsNormEps: Float , vocabularySize: Int , kvHeads: Int ,
38+ ropeTheta: Float , ropeLocalBaseFreq: Float , ropeTraditional: Bool ,
39+ queryPreAttnScalar: Float , slidingWindow: Int , slidingWindowPattern: Int ,
40+ maxPositionEmbeddings: Int , ropeScaling: [ String : StringOrNumber ] ? = nil
41+ ) {
// Plain one-to-one copies; `self.` is required because each parameter shadows its property.
42+ self . modelType = modelType
43+ self . hiddenSize = hiddenSize
44+ self . hiddenLayers = hiddenLayers
45+ self . intermediateSize = intermediateSize
46+ self . attentionHeads = attentionHeads
47+ self . headDim = headDim
48+ self . rmsNormEps = rmsNormEps
49+ self . vocabularySize = vocabularySize
50+ self . kvHeads = kvHeads
51+ self . ropeTheta = ropeTheta
52+ self . ropeLocalBaseFreq = ropeLocalBaseFreq
53+ self . ropeTraditional = ropeTraditional
54+ self . queryPreAttnScalar = queryPreAttnScalar
55+ self . slidingWindow = slidingWindow
56+ self . slidingWindowPattern = slidingWindowPattern
57+ self . maxPositionEmbeddings = maxPositionEmbeddings
58+ self . ropeScaling = ropeScaling
59+ }
3260
3361 enum CodingKeys : String , CodingKey {
3462 case modelType = " model_type "
@@ -40,12 +68,14 @@ public struct Gemma3TextConfiguration: Codable {
4068 case rmsNormEps = " rms_norm_eps "
4169 case vocabularySize = " vocab_size "
4270 case kvHeads = " num_key_value_heads "
43- case ropeGlobalBaseFreq = " rope_global_base_freq "
71+ case ropeTheta = " rope_theta "
4472 case ropeLocalBaseFreq = " rope_local_base_freq "
4573 case ropeTraditional = " rope_traditional "
4674 case queryPreAttnScalar = " query_pre_attn_scalar "
4775 case slidingWindow = " sliding_window "
4876 case slidingWindowPattern = " sliding_window_pattern "
77+ case maxPositionEmbeddings = " max_position_embeddings "
78+ case ropeScaling = " rope_scaling "
4979 }
5080
5181 enum VLMCodingKeys : String , CodingKey {
@@ -65,16 +95,17 @@ public struct Gemma3TextConfiguration: Codable {
6595 }
6696
6797 modelType = try container. decode ( String . self, forKey: . modelType)
68- hiddenSize = try container. decode ( Int . self, forKey: . hiddenSize)
69- hiddenLayers = try container. decode ( Int . self, forKey: . hiddenLayers)
70- intermediateSize = try container. decode ( Int . self, forKey: . intermediateSize)
98+ hiddenSize = try container. decodeIfPresent ( Int . self, forKey: . hiddenSize) ?? 1152
99+ hiddenLayers = try container. decodeIfPresent ( Int . self, forKey: . hiddenLayers) ?? 26
100+ intermediateSize =
101+ try container. decodeIfPresent ( Int . self, forKey: . intermediateSize) ?? 6912
71102 attentionHeads = try container. decodeIfPresent ( Int . self, forKey: . attentionHeads) ?? 4
72103 headDim = try container. decodeIfPresent ( Int . self, forKey: . headDim) ?? 256
73104 rmsNormEps = try container. decodeIfPresent ( Float . self, forKey: . rmsNormEps) ?? 1.0e-6
74105 vocabularySize = try container. decodeIfPresent ( Int . self, forKey: . vocabularySize) ?? 262144
75106 kvHeads = try container. decodeIfPresent ( Int . self, forKey: . kvHeads) ?? 1
76- ropeGlobalBaseFreq =
77- try container. decodeIfPresent ( Float . self, forKey: . ropeGlobalBaseFreq ) ?? 1_000_000.0
107+ ropeTheta =
108+ try container. decodeIfPresent ( Float . self, forKey: . ropeTheta ) ?? 1_000_000.0
78109 ropeLocalBaseFreq =
79110 try container. decodeIfPresent ( Float . self, forKey: . ropeLocalBaseFreq) ?? 10_000.0
80111 ropeTraditional =
@@ -84,6 +115,10 @@ public struct Gemma3TextConfiguration: Codable {
84115 slidingWindow = try container. decodeIfPresent ( Int . self, forKey: . slidingWindow) ?? 512
85116 slidingWindowPattern =
86117 try container. decodeIfPresent ( Int . self, forKey: . slidingWindowPattern) ?? 6
118+ maxPositionEmbeddings =
119+ try container. decodeIfPresent ( Int . self, forKey: . maxPositionEmbeddings) ?? 32768
120+ ropeScaling =
121+ try container. decodeIfPresent ( [ String : StringOrNumber ] . self, forKey: . ropeScaling)
87122 }
88123}
89124
@@ -106,7 +141,7 @@ private class Attention: Module {
106141 @ModuleInfo ( key: " q_norm " ) var queryNorm : Gemma . RMSNorm
107142 @ModuleInfo ( key: " k_norm " ) var keyNorm : Gemma . RMSNorm
108143
109- @ModuleInfo var rope : RoPE
144+ @ModuleInfo var rope : OffsetLayer
110145
111146 init ( _ config: Gemma3TextConfiguration , layerIdx: Int ) {
112147 let dim = config. hiddenSize
@@ -131,12 +166,16 @@ private class Attention: Module {
131166
132167 self . isSliding = ( layerIdx + 1 ) % config. slidingWindowPattern != 0
133168
134- let baseFreq = isSliding ? config. ropeLocalBaseFreq : config. ropeGlobalBaseFreq
135- self . _rope. wrappedValue = RoPE (
136- dimensions: headDim,
137- traditional: config. ropeTraditional,
138- base: baseFreq
139- )
169+ if isSliding {
170+ self . rope = initializeRope (
171+ dims: headDim, base: config. ropeLocalBaseFreq, traditional: false ,
172+ scalingConfig: nil , maxPositionEmbeddings: nil )
173+ } else {
174+ self . rope = initializeRope (
175+ dims: headDim, base: config. ropeTheta, traditional: false ,
176+ scalingConfig: config. ropeScaling,
177+ maxPositionEmbeddings: config. maxPositionEmbeddings)
178+ }
140179
141180 super. init ( )
142181 }
@@ -163,18 +202,8 @@ private class Attention: Module {
163202 queries = rope ( queries, offset: cache. offset)
164203 keys = rope ( keys, offset: cache. offset)
165204 } else {
166- queries = rope ( queries)
167- keys = rope ( keys)
168- }
169-
170- // Sliding window masking
171- var finalMask = mask
172- if case . array( let maskArray) = mask {
173- let keySeqLen = keys. shape [ 2 ]
174- if maskArray. shape. last! != keySeqLen {
175- let slicedMask = maskArray [ . ellipsis, ( - keySeqLen) ... ]
176- finalMask = . array( slicedMask)
177- }
205+ queries = rope ( queries, offset: 0 )
206+ keys = rope ( keys, offset: 0 )
178207 }
179208
180209 let output = attentionWithCacheUpdate (
@@ -183,7 +212,7 @@ private class Attention: Module {
183212 values: values,
184213 cache: cache,
185214 scale: scale,
186- mask: finalMask
215+ mask: mask
187216 )
188217 . transposed ( 0 , 2 , 1 , 3 )
189218 . reshaped ( B, L, - 1 )
@@ -295,30 +324,19 @@ private class Gemma3Model: Module {
295324 if layerCache == nil {
296325 layerCache = Array ( repeating: nil as KVCache ? , count: layers. count)
297326 }
298- // Create attention masks
299- var fullMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
300- var slidingWindowMask : MLXFast . ScaledDotProductAttentionMaskMode = . none
301- if mask == nil {
302- let j = config. slidingWindowPattern
303- let globalCache : KVCache ? =
304- ( j > 0 && j <= ( layerCache? . count ?? 0 ) ) ? layerCache ? [ j - 1 ] : nil
305- fullMask = createAttentionMask ( h: h, cache: globalCache)
306- let slidingCache : KVCache ? = layerCache? . first ?? nil
307- slidingWindowMask = createAttentionMask (
308- h: h, cache: slidingCache, windowSize: config. slidingWindow)
309- }
310- for (i, layer) in layers. enumerated ( ) {
311- let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
312327
313- let localMask : MLXFast . ScaledDotProductAttentionMaskMode
314- if let mask {
315- localMask = mask
316- } else if isGlobal {
317- localMask = fullMask
328+ let globalMask = createAttentionMask ( h: h, cache: cache ? [ config. slidingWindowPattern - 1 ] )
329+ let slidingWindowMask =
330+ if config. slidingWindowPattern > 1 {
331+ createAttentionMask ( h: h, cache: cache ? [ 0 ] , windowSize: config. slidingWindow)
318332 } else {
319- localMask = slidingWindowMask
333+ MLXFast . ScaledDotProductAttentionMaskMode . none
320334 }
321- h = layer ( h, mask: localMask, cache: layerCache ? [ i] )
335+
336+ for (i, layer) in layers. enumerated ( ) {
337+ let isGlobal = ( i % config. slidingWindowPattern == config. slidingWindowPattern - 1 )
338+ let mask = isGlobal ? globalMask : slidingWindowMask
339+ h = layer ( h, mask: mask, cache: layerCache ? [ i] )
322340 }
323341 return norm ( h)
324342 }
0 commit comments