
Commit d9f46e3

Fix many compiler warnings (#14)

* Fix compiler warnings
* Deprecate SuScaledRotaryEmbedding
* Fix defaultHubApi shared mutable state warning
* Fix non-final class warnings
* Fix ModelAdapterFactory warnings
* Resolve concurrency warnings in registries
* Make handler in Tool sendable
* Fix Message type ambiguity
* Replace `[String: Any]` with `[String: any Sendable]`
* Use real tests for streamlined API
* Refactor streamlined API
* Rename files with ChatSession for clarity
* Pin swift-transformers until next release
* Make additionalContext sendable
* Make tool schema sendable
* ChatSession: use AsyncStream-based generate instead of deprecated callback
* `sending` adjustments
* Docstring fix
* Mark non-thread-safe methods as deprecated, add thread-safe methods to ModelContainer
* Format
1 parent 74f85d9 commit d9f46e3
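
A recurring theme in these changes is replacing `[String: Any]` with `[String: any Sendable]` so dictionaries can cross concurrency boundaries without warnings under strict checking. A minimal sketch of the pattern (the `GenerationOptions` name below is hypothetical, not taken from this commit):

// `[String: Any]` cannot be sent across actors or Tasks under strict concurrency;
// constraining values to `any Sendable` lets the compiler prove the dictionary is safe.
struct GenerationOptions: Sendable {
    var additionalContext: [String: any Sendable] = [:]
}

let options = GenerationOptions(additionalContext: ["temperature": 0.7, "topK": 40])
Task {
    // Safe to capture: every stored value conforms to Sendable.
    print(options.additionalContext)
}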


50 files changed: +806 −836 lines

Libraries/Embedders/Load.swift
Lines changed: 0 additions & 7 deletions

@@ -91,13 +91,6 @@ func loadSynchronous(modelDirectory: URL) throws -> EmbeddingModel {
         }
     }

-    if let quantization = baseConfig.quantization {
-        quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) {
-            path, module in
-            weights["\(path).scales"] != nil
-        }
-    }
-
     // apply the loaded weights
     let parameters = ModuleParameters.unflattened(weights)
     try model.update(parameters: parameters, verify: [.all])

Libraries/Embedders/Qwen3.swift
Lines changed: 3 additions & 3 deletions

@@ -60,7 +60,7 @@ private class Attention: Module {
     }

     public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
+        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
     ) -> MLXArray {
         let (B, L) = (x.dim(0), x.dim(1))

@@ -125,7 +125,7 @@ private class TransformerBlock: Module {
     }

     public func callAsFunction(
-        _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache?
+        _ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
     ) -> MLXArray {
         var r = attention(inputLayerNorm(x), mask: mask, cache: cache)
         let h = x + r
@@ -157,7 +157,7 @@ private class Qwen3ModelInner: Module {
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray {
         var h = embedTokens(inputs)

-        let mask: MLXArray? = createAttentionMask(h: h, cache: cache)
+        let mask = createAttentionMask(h: h, cache: cache?.first)

         for (i, layer) in layers.enumerated() {
             h = layer(h, mask: mask, cache: cache?[i])
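
In Qwen3 (and the other models touched below), attention layers now take an `MLXFast.ScaledDotProductAttentionMaskMode` instead of an optional `MLXArray`, and `createAttentionMask(h:cache:)` returns that mode directly from the hidden states and a single layer's cache. A hedged sketch of working with the returned mode, using only the calls that appear in this diff (the dummy activations and the default branch in the switch are illustration only):

import MLX
import MLXFast
import MLXLMCommon

// Dummy (batch, sequence, hidden) activations, just to have something to pass in.
let h = MLXArray.zeros([1, 8, 16])
let mask = createAttentionMask(h: h, cache: nil)

switch mask {
case .none:
    print("no mask: full attention")
case .array(let m):
    print("explicit mask with shape \(m.shape)")
default:
    print("another mode (e.g. causal) handled by the attention kernel")
}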

Libraries/MLXLLM/LLMModel.swift
Lines changed: 2 additions & 3 deletions

@@ -24,12 +24,11 @@ extension LLMModel {
     {
         let prefillStepSize = windowSize ?? 512
         var y = input.text
-        var state: LMOutput.State? = nil

-        // prepare the prompt in chunks if larger than the prefill size
+        // Prepare the prompt in chunks if larger than the prefill size
         while y.tokens.size > prefillStepSize {
             let input = y[.newAxis, ..<prefillStepSize]
-            let result = self(input, cache: cache.isEmpty ? nil : cache, state: state)
+            _ = self(input, cache: cache.isEmpty ? nil : cache, state: nil)
             eval(cache)
             y = y[prefillStepSize...]
         }
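
The prefill loop above now discards the per-chunk result (`_ =`) and drops the unused `state` variable: the prompt is fed to the model in windows of at most `prefillStepSize` tokens, each window updating the KV cache, and only the final partial window flows into normal decoding. A toy sketch of that chunking control flow in plain Swift (`process` stands in for the model call):

func prefill(_ tokens: [Int], stepSize: Int, process: ([Int]) -> Void) -> [Int] {
    var remaining = tokens[...]
    // Consume full windows; in the real code each call also updates the KV cache.
    while remaining.count > stepSize {
        process(Array(remaining.prefix(stepSize)))
        remaining = remaining.dropFirst(stepSize)
    }
    // Whatever is left (at most stepSize tokens) goes through the normal decode path.
    return Array(remaining)
}

let tail = prefill(Array(1...1200), stepSize: 512) { chunk in
    print("prefilled \(chunk.count) tokens")  // prints 512 twice
}
print("remaining \(tail.count) tokens")  // 176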

Libraries/MLXLLM/LLMModelFactory.swift
Lines changed: 48 additions & 53 deletions

@@ -20,58 +20,53 @@ private func create<C: Codable, M>(
 /// Registry of model type, e.g 'llama', to functions that can instantiate the model from configuration.
 ///
 /// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
-public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
+public enum LLMTypeRegistry {

     /// Shared instance with default model types.
-    public static let shared: LLMTypeRegistry = .init(creators: all())
-
-    /// All predefined model types.
-    private static func all() -> [String: @Sendable (URL) throws -> any LanguageModel] {
-        [
-            "mistral": create(LlamaConfiguration.self, LlamaModel.init),
-            "llama": create(LlamaConfiguration.self, LlamaModel.init),
-            "phi": create(PhiConfiguration.self, PhiModel.init),
-            "phi3": create(Phi3Configuration.self, Phi3Model.init),
-            "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
-            "gemma": create(GemmaConfiguration.self, GemmaModel.init),
-            "gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
-            "gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
-            "gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
-            "gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init),
-            "qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
-            "qwen3": create(Qwen3Configuration.self, Qwen3Model.init),
-            "qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init),
-            "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
-            "cohere": create(CohereConfiguration.self, CohereModel.init),
-            "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
-            "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
-            "deepseek_v3": create(DeepseekV3Configuration.self, DeepseekV3Model.init),
-            "granite": create(GraniteConfiguration.self, GraniteModel.init),
-            "granitemoehybrid": create(
-                GraniteMoeHybridConfiguration.self, GraniteMoeHybridModel.init),
-            "mimo": create(MiMoConfiguration.self, MiMoModel.init),
-            "glm4": create(GLM4Configuration.self, GLM4Model.init),
-            "acereason": create(Qwen2Configuration.self, Qwen2Model.init),
-            "falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init),
-            "bitnet": create(BitnetConfiguration.self, BitnetModel.init),
-            "smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
-            "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
-            "lfm2": create(LFM2Configuration.self, LFM2Model.init),
-            "baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
-            "exaone4": create(Exaone4Configuration.self, Exaone4Model.init),
-            "gpt_oss": create(GPTOSSConfiguration.self, GPTOSSModel.init),
-            "lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init),
-            "olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
-            "olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
-            "olmo3": create(Olmo3Configuration.self, Olmo3Model.init),
-            "bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
-            "lfm2_moe": create(LFM2MoEConfiguration.self, LFM2MoEModel.init),
-            "nanochat": create(NanoChatConfiguration.self, NanoChatModel.init),
-            "afmoe": create(AfMoEConfiguration.self, AfMoEModel.init),
-            "jamba_3b": create(JambaConfiguration.self, JambaModel.init),
-            "mistral3": create(Mistral3TextConfiguration.self, Mistral3TextModel.init),
-        ]
-    }
+    public static let shared: ModelTypeRegistry = .init(creators: [
+        "mistral": create(LlamaConfiguration.self, LlamaModel.init),
+        "llama": create(LlamaConfiguration.self, LlamaModel.init),
+        "phi": create(PhiConfiguration.self, PhiModel.init),
+        "phi3": create(Phi3Configuration.self, Phi3Model.init),
+        "phimoe": create(PhiMoEConfiguration.self, PhiMoEModel.init),
+        "gemma": create(GemmaConfiguration.self, GemmaModel.init),
+        "gemma2": create(Gemma2Configuration.self, Gemma2Model.init),
+        "gemma3": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
+        "gemma3_text": create(Gemma3TextConfiguration.self, Gemma3TextModel.init),
+        "gemma3n": create(Gemma3nTextConfiguration.self, Gemma3nTextModel.init),
+        "qwen2": create(Qwen2Configuration.self, Qwen2Model.init),
+        "qwen3": create(Qwen3Configuration.self, Qwen3Model.init),
+        "qwen3_moe": create(Qwen3MoEConfiguration.self, Qwen3MoEModel.init),
+        "starcoder2": create(Starcoder2Configuration.self, Starcoder2Model.init),
+        "cohere": create(CohereConfiguration.self, CohereModel.init),
+        "openelm": create(OpenElmConfiguration.self, OpenELMModel.init),
+        "internlm2": create(InternLM2Configuration.self, InternLM2Model.init),
+        "deepseek_v3": create(DeepseekV3Configuration.self, DeepseekV3Model.init),
+        "granite": create(GraniteConfiguration.self, GraniteModel.init),
+        "granitemoehybrid": create(
+            GraniteMoeHybridConfiguration.self, GraniteMoeHybridModel.init),
+        "mimo": create(MiMoConfiguration.self, MiMoModel.init),
+        "glm4": create(GLM4Configuration.self, GLM4Model.init),
+        "acereason": create(Qwen2Configuration.self, Qwen2Model.init),
+        "falcon_h1": create(FalconH1Configuration.self, FalconH1Model.init),
+        "bitnet": create(BitnetConfiguration.self, BitnetModel.init),
+        "smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
+        "ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
+        "lfm2": create(LFM2Configuration.self, LFM2Model.init),
+        "baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
+        "exaone4": create(Exaone4Configuration.self, Exaone4Model.init),
+        "gpt_oss": create(GPTOSSConfiguration.self, GPTOSSModel.init),
+        "lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init),
+        "olmoe": create(OlmoEConfiguration.self, OlmoEModel.init),
+        "olmo2": create(Olmo2Configuration.self, Olmo2Model.init),
+        "olmo3": create(Olmo3Configuration.self, Olmo3Model.init),
+        "bailing_moe": create(BailingMoeConfiguration.self, BailingMoeModel.init),
+        "lfm2_moe": create(LFM2MoEConfiguration.self, LFM2MoEModel.init),
+        "nanochat": create(NanoChatConfiguration.self, NanoChatModel.init),
+        "afmoe": create(AfMoEConfiguration.self, AfMoEModel.init),
+        "jamba_3b": create(JambaConfiguration.self, JambaModel.init),
+        "mistral3": create(Mistral3TextConfiguration.self, Mistral3TextModel.init),
+    ])
 }

 /// Registry of models and any overrides that go with them, e.g. prompt augmentation.
@@ -458,7 +453,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
 /// let modelContainer = try await LLMModelFactory.shared.loadContainer(
 ///     configuration: LLMRegistry.llama3_8B_4bit)
 /// ```
-public class LLMModelFactory: ModelFactory {
+public final class LLMModelFactory: ModelFactory {

     public init(typeRegistry: ModelTypeRegistry, modelRegistry: AbstractModelRegistry) {
         self.typeRegistry = typeRegistry
@@ -478,7 +473,7 @@ public class LLMModelFactory: ModelFactory {
     public func _load(
         hub: HubApi, configuration: ModelConfiguration,
         progressHandler: @Sendable @escaping (Progress) -> Void
-    ) async throws -> sending ModelContext {
+    ) async throws -> ModelContext {
         // download weights and config
         let modelDirectory = try await downloadModel(
             hub: hub, configuration: configuration, progressHandler: progressHandler)
@@ -497,7 +492,7 @@ public class LLMModelFactory: ModelFactory {

         let model: LanguageModel
         do {
-            model = try typeRegistry.createModel(
+            model = try await typeRegistry.createModel(
                 configuration: configurationURL, modelType: baseConfig.modelType)
         } catch let error as DecodingError {
             throw ModelFactoryError.configurationDecodingError(
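
The registry change turns `LLMTypeRegistry` into a namespace whose `shared` value is a plain `ModelTypeRegistry`, `LLMModelFactory` becomes `final`, and `createModel` is now awaited. Typical call sites are unchanged; the usage from the factory's own doc comment (kept in the diff above) still applies. A hedged sketch, assuming `loadContainer` returns a `ModelContainer` as that comment implies:

import MLXLLM
import MLXLMCommon

func loadLlama() async throws -> ModelContainer {
    // Mirrors the doc comment on LLMModelFactory; registry lookup and the
    // async createModel call happen inside loadContainer.
    try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.llama3_8B_4bit)
}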

Libraries/MLXLLM/Models/AfMoE.swift
Lines changed: 3 additions & 23 deletions

@@ -498,14 +498,13 @@ private class AfMoEModelInner: Module {
         }

         // Create attention masks
-        let faCache: [KVCache]? = layerCache.map { [$0[faIdx]] }
-        let faMask = createAttentionMask(h: h, cache: faCache)
+        let faMask = createAttentionMask(h: h, cache: layerCache?[faIdx])

         var swaMask: MLXFast.ScaledDotProductAttentionMaskMode = .none
         if let swaIdx = swaIdx, let layerCache = layerCache {
-            let swaCache = [layerCache[swaIdx]]
             // Create mask with sliding window
-            swaMask = createSlidingWindowMask(h: h, cache: swaCache, windowSize: slidingWindow)
+            swaMask = createAttentionMask(
+                h: h, cache: layerCache[swaIdx], windowSize: slidingWindow)
         }

         for (i, layer) in layers.enumerated() {
@@ -515,25 +514,6 @@ private class AfMoEModelInner: Module {

         return norm(h)
     }
-
-    // Helper to create sliding window mask
-    private func createSlidingWindowMask(
-        h: MLXArray, cache: [KVCache]?, windowSize: Int
-    ) -> MLXFast.ScaledDotProductAttentionMaskMode {
-        let t = h.dim(1)
-        if t > 1 {
-            var offset = 0
-            if let c = cache?.first {
-                offset = c.offset
-                if let maxSize = c.maxSize {
-                    offset = min(maxSize, offset)
-                }
-            }
-            let mask = createCausalMask(n: t, offset: offset, windowSize: windowSize)
-            return .array(mask)
-        }
-        return .none
-    }
 }

 // MARK: - AfMoE Model (Public)
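
The removed `createSlidingWindowMask` helper built the windowed causal mask by hand from the cache offset; the replacement folds that into `createAttentionMask(h:cache:windowSize:)`. For intuition only, a toy Bool-matrix sketch of what such a mask permits (this is not the MLX implementation, and the exact window boundary there may differ by one):

// Row i is query position offset + i; it may attend to keys in (q - windowSize, q].
func slidingWindowCausalMask(n: Int, offset: Int, windowSize: Int) -> [[Bool]] {
    (0..<n).map { i in
        let q = offset + i
        return (0..<(offset + n)).map { k in k <= q && k > q - windowSize }
    }
}

// 3 new queries after 2 cached tokens, window of 4:
for row in slidingWindowCausalMask(n: 3, offset: 2, windowSize: 4) {
    print(row.map { $0 ? "1" : "." }.joined())
}
// 111..
// 1111.
// .1111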

Libraries/MLXLLM/Models/BailingMoe.swift
Lines changed: 3 additions & 3 deletions

@@ -212,7 +212,7 @@ private class BailingMoeGate: Module, UnaryLayer {
     }

     func groupSelect(_ x: MLXArray) -> (inds: MLXArray, scores: MLXArray) {
-        let (bsz, seqLen, h) = (x.dim(0), x.dim(1), x.dim(2))
+        let (bsz, seqLen, _) = (x.dim(0), x.dim(1), x.dim(2))

         let logits = gate(x)
         var scores = sigmoid(logits.asType(.float32))
@@ -221,14 +221,14 @@ private class BailingMoeGate: Module, UnaryLayer {

         let topKGroup = top(groupScores, k: 2, axis: -1).sum(axis: -1, keepDims: true)
         var k = nGroup - topkGroup
-        var groupIdx = argPartition(topKGroup, kth: k - 1, axis: -2)[.ellipsis, ..<k, 0...]
+        let groupIdx = argPartition(topKGroup, kth: k - 1, axis: -2)[.ellipsis, ..<k, 0...]
         scores = putAlong(groupScores, groupIdx, values: MLXArray(0.0), axis: -2)
         scores = flattened(scores, start: -2, end: -1)

         k = topK
         let inds = argPartition(-scores, kth: k - 1, axis: -1)[.ellipsis, ..<k]
         scores = takeAlong(scores, inds, axis: -1)
-        if topK ?? 1 > 1, normTopkProb {
+        if topK > 1, normTopkProb {
             let denominator = scores.sum(axis: -1, keepDims: true) + 1e-20
             scores = scores / denominator
         }

Libraries/MLXLLM/Models/DeepseekV3.swift
Lines changed: 1 addition & 2 deletions

@@ -3,7 +3,6 @@
 import Foundation
 import MLX
 import MLXFast
-import MLXLLM
 import MLXLMCommon
 import MLXNN

@@ -354,7 +353,7 @@ private class MoEGate: Module {
     }

     func callAsFunction(_ x: MLXArray) -> (MLXArray, MLXArray) {
-        let (bsz, seqLen, h) = (x.dim(0), x.dim(1), x.dim(2))
+        let (bsz, seqLen, _) = (x.dim(0), x.dim(1), x.dim(2))

         let hiddenStates = x.matmul(weight.T)
         var scores = sigmoid(hiddenStates)

Libraries/MLXLLM/Models/Gemma3Text.swift
Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@
 import Foundation
 import MLX
 import MLXFast
-import MLXLLM
 import MLXLMCommon
 import MLXNN


Libraries/MLXLLM/Models/Granite.swift
Lines changed: 1 addition & 2 deletions

@@ -121,7 +121,6 @@ private class TransformerBlock: Module {
     let residualMultiplier: Float

     public init(_ args: GraniteConfiguration) {
-        let attentionHeads = args.attentionHeads
         let hiddenSize = args.hiddenSize

         self._attention.wrappedValue = Attention(args)
@@ -271,7 +270,7 @@ public struct GraniteConfiguration: Codable, Sendable {
         self.maxPositionEmbeddings = try container.decode(Int.self, forKey: .maxPositionEmbeddings)
         self.kvHeads = try container.decode(Int.self, forKey: .kvHeads)
         self.attentionBias = try container.decode(Bool.self, forKey: .attentionBias)
-        self.mlpBias = try container.decode(Bool.self, forKey: .mlpBias) ?? false
+        self.mlpBias = try container.decode(Bool.self, forKey: .mlpBias)
         self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10000000.0
         self.ropeScaling = try container.decodeIfPresent(
             [String: StringOrNumber].self, forKey: .ropeScaling)
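
The Granite fix drops a pointless `?? false`: `decode(Bool.self, forKey:)` already returns a non-optional `Bool`, so the nil-coalescing operator only generated a warning. When a key may genuinely be absent, `decodeIfPresent` plus `??` is the right pattern, as the `ropeTheta` line just below it shows. A small standalone sketch of the two cases (the `ToyConfiguration` type is illustrative, not from this commit):

import Foundation

struct ToyConfiguration: Codable {
    var mlpBias: Bool
    var ropeTheta: Float

    enum CodingKeys: String, CodingKey {
        case mlpBias = "mlp_bias"
        case ropeTheta = "rope_theta"
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // Required key: decode returns a non-optional Bool, so `?? false` would be dead code.
        self.mlpBias = try container.decode(Bool.self, forKey: .mlpBias)
        // Optional key: decodeIfPresent returns Float?, so `??` supplies the default.
        self.ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) ?? 10_000_000.0
    }
}

let json = #"{"mlp_bias": false}"#.data(using: .utf8)!
let config = try! JSONDecoder().decode(ToyConfiguration.self, from: json)
print(config.mlpBias, config.ropeTheta)  // false 1e+07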

Libraries/MLXLLM/Models/GraniteMoeHybrid.swift
Lines changed: 5 additions & 5 deletions

@@ -119,7 +119,7 @@ private class GraniteMoeHybridMamba2Mixer: Module {
             }
         }

-        var padded = concatenated([convState!, input], axis: 1)
+        let padded = concatenated([convState!, input], axis: 1)

         if let cache {
             let end = padded.dim(1)
@@ -136,12 +136,12 @@ private class GraniteMoeHybridMamba2Mixer: Module {
         mask: MLXArray?,
         cache: MambaCache?
     ) -> MLXArray {
-        var projected = inProj(hiddenStates)
+        let projected = inProj(hiddenStates)
         let splits = split(
             projected, indices: [intermediateSize, intermediateSize + convDim], axis: -1)
-        var gate = splits[0]
+        let gate = splits[0]
         var convInput = splits[1]
-        var dt = splits[2]
+        let dt = splits[2]

         if let mask {
             let expandedMask = expandedDimensions(mask, axis: -1)
@@ -551,7 +551,7 @@ public class GraniteMoeHybridModel: Module, LLMModel, KVCacheDimensionProvider {
         for layerIndex in 0 ..< configuration.hiddenLayers {
             let prefix = "model.layers.\(layerIndex).block_sparse_moe"
             guard
-                var inputWeight = sanitized.removeValue(forKey: "\(prefix).input_linear.weight")
+                let inputWeight = sanitized.removeValue(forKey: "\(prefix).input_linear.weight")
             else { continue }

             let expertHidden = inputWeight.dim(1)
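
Several of the diffs above (BailingMoe, DeepseekV3, GraniteMoeHybrid) silence the same two Swift warnings: "variable was never mutated; consider changing to 'let'" and "immutable value was never used". A minimal standalone illustration of both fixes:

func selectedCount(_ dims: (Int, Int, Int)) -> Int {
    // Unused tuple element: bind it to `_` instead of leaving a dead name around.
    let (batch, seqLen, _) = dims
    // Never-mutated binding: declare it with `let`, not `var`.
    let total = batch * seqLen
    return total
}

print(selectedCount((2, 8, 512)))  // 16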
