@@ -20,58 +20,53 @@ private func create<C: Codable, M>(
2020/// Registry of model types, e.g. 'llama', to functions that can instantiate the model from configuration.
2121///
2222/// Typically called via ``LLMModelFactory/load(hub:configuration:progressHandler:)``.
23- public class LLMTypeRegistry : ModelTypeRegistry , @ unchecked Sendable {
23+ public enum LLMTypeRegistry {
2424
2525 /// Shared instance with default model types.
26- public static let shared : LLMTypeRegistry = . init( creators: all ( ) )
27-
28- /// All predefined model types.
29- private static func all( ) -> [ String : @Sendable ( URL) throws -> any LanguageModel ] {
30- [
31- " mistral " : create ( LlamaConfiguration . self, LlamaModel . init) ,
32- " llama " : create ( LlamaConfiguration . self, LlamaModel . init) ,
33- " phi " : create ( PhiConfiguration . self, PhiModel . init) ,
34- " phi3 " : create ( Phi3Configuration . self, Phi3Model . init) ,
35- " phimoe " : create ( PhiMoEConfiguration . self, PhiMoEModel . init) ,
36- " gemma " : create ( GemmaConfiguration . self, GemmaModel . init) ,
37- " gemma2 " : create ( Gemma2Configuration . self, Gemma2Model . init) ,
38- " gemma3 " : create ( Gemma3TextConfiguration . self, Gemma3TextModel . init) ,
39- " gemma3_text " : create ( Gemma3TextConfiguration . self, Gemma3TextModel . init) ,
40- " gemma3n " : create ( Gemma3nTextConfiguration . self, Gemma3nTextModel . init) ,
41- " qwen2 " : create ( Qwen2Configuration . self, Qwen2Model . init) ,
42- " qwen3 " : create ( Qwen3Configuration . self, Qwen3Model . init) ,
43- " qwen3_moe " : create ( Qwen3MoEConfiguration . self, Qwen3MoEModel . init) ,
44- " starcoder2 " : create ( Starcoder2Configuration . self, Starcoder2Model . init) ,
45- " cohere " : create ( CohereConfiguration . self, CohereModel . init) ,
46- " openelm " : create ( OpenElmConfiguration . self, OpenELMModel . init) ,
47- " internlm2 " : create ( InternLM2Configuration . self, InternLM2Model . init) ,
48- " deepseek_v3 " : create ( DeepseekV3Configuration . self, DeepseekV3Model . init) ,
49- " granite " : create ( GraniteConfiguration . self, GraniteModel . init) ,
50- " granitemoehybrid " : create (
51- GraniteMoeHybridConfiguration . self, GraniteMoeHybridModel . init) ,
52- " mimo " : create ( MiMoConfiguration . self, MiMoModel . init) ,
53- " glm4 " : create ( GLM4Configuration . self, GLM4Model . init) ,
54- " acereason " : create ( Qwen2Configuration . self, Qwen2Model . init) ,
55- " falcon_h1 " : create ( FalconH1Configuration . self, FalconH1Model . init) ,
56- " bitnet " : create ( BitnetConfiguration . self, BitnetModel . init) ,
57- " smollm3 " : create ( SmolLM3Configuration . self, SmolLM3Model . init) ,
58- " ernie4_5 " : create ( Ernie45Configuration . self, Ernie45Model . init) ,
59- " lfm2 " : create ( LFM2Configuration . self, LFM2Model . init) ,
60- " baichuan_m1 " : create ( BaichuanM1Configuration . self, BaichuanM1Model . init) ,
61- " exaone4 " : create ( Exaone4Configuration . self, Exaone4Model . init) ,
62- " gpt_oss " : create ( GPTOSSConfiguration . self, GPTOSSModel . init) ,
63- " lille-130m " : create ( Lille130mConfiguration . self, Lille130mModel . init) ,
64- " olmoe " : create ( OlmoEConfiguration . self, OlmoEModel . init) ,
65- " olmo2 " : create ( Olmo2Configuration . self, Olmo2Model . init) ,
66- " olmo3 " : create ( Olmo3Configuration . self, Olmo3Model . init) ,
67- " bailing_moe " : create ( BailingMoeConfiguration . self, BailingMoeModel . init) ,
68- " lfm2_moe " : create ( LFM2MoEConfiguration . self, LFM2MoEModel . init) ,
69- " nanochat " : create ( NanoChatConfiguration . self, NanoChatModel . init) ,
70- " afmoe " : create ( AfMoEConfiguration . self, AfMoEModel . init) ,
71- " jamba_3b " : create ( JambaConfiguration . self, JambaModel . init) ,
72- " mistral3 " : create ( Mistral3TextConfiguration . self, Mistral3TextModel . init) ,
73- ]
74- }
26+ public static let shared : ModelTypeRegistry = . init( creators: [
27+ " mistral " : create ( LlamaConfiguration . self, LlamaModel . init) ,
28+ " llama " : create ( LlamaConfiguration . self, LlamaModel . init) ,
29+ " phi " : create ( PhiConfiguration . self, PhiModel . init) ,
30+ " phi3 " : create ( Phi3Configuration . self, Phi3Model . init) ,
31+ " phimoe " : create ( PhiMoEConfiguration . self, PhiMoEModel . init) ,
32+ " gemma " : create ( GemmaConfiguration . self, GemmaModel . init) ,
33+ " gemma2 " : create ( Gemma2Configuration . self, Gemma2Model . init) ,
34+ " gemma3 " : create ( Gemma3TextConfiguration . self, Gemma3TextModel . init) ,
35+ " gemma3_text " : create ( Gemma3TextConfiguration . self, Gemma3TextModel . init) ,
36+ " gemma3n " : create ( Gemma3nTextConfiguration . self, Gemma3nTextModel . init) ,
37+ " qwen2 " : create ( Qwen2Configuration . self, Qwen2Model . init) ,
38+ " qwen3 " : create ( Qwen3Configuration . self, Qwen3Model . init) ,
39+ " qwen3_moe " : create ( Qwen3MoEConfiguration . self, Qwen3MoEModel . init) ,
40+ " starcoder2 " : create ( Starcoder2Configuration . self, Starcoder2Model . init) ,
41+ " cohere " : create ( CohereConfiguration . self, CohereModel . init) ,
42+ " openelm " : create ( OpenElmConfiguration . self, OpenELMModel . init) ,
43+ " internlm2 " : create ( InternLM2Configuration . self, InternLM2Model . init) ,
44+ " deepseek_v3 " : create ( DeepseekV3Configuration . self, DeepseekV3Model . init) ,
45+ " granite " : create ( GraniteConfiguration . self, GraniteModel . init) ,
46+ " granitemoehybrid " : create (
47+ GraniteMoeHybridConfiguration . self, GraniteMoeHybridModel . init) ,
48+ " mimo " : create ( MiMoConfiguration . self, MiMoModel . init) ,
49+ " glm4 " : create ( GLM4Configuration . self, GLM4Model . init) ,
50+ " acereason " : create ( Qwen2Configuration . self, Qwen2Model . init) ,
51+ " falcon_h1 " : create ( FalconH1Configuration . self, FalconH1Model . init) ,
52+ " bitnet " : create ( BitnetConfiguration . self, BitnetModel . init) ,
53+ " smollm3 " : create ( SmolLM3Configuration . self, SmolLM3Model . init) ,
54+ " ernie4_5 " : create ( Ernie45Configuration . self, Ernie45Model . init) ,
55+ " lfm2 " : create ( LFM2Configuration . self, LFM2Model . init) ,
56+ " baichuan_m1 " : create ( BaichuanM1Configuration . self, BaichuanM1Model . init) ,
57+ " exaone4 " : create ( Exaone4Configuration . self, Exaone4Model . init) ,
58+ " gpt_oss " : create ( GPTOSSConfiguration . self, GPTOSSModel . init) ,
59+ " lille-130m " : create ( Lille130mConfiguration . self, Lille130mModel . init) ,
60+ " olmoe " : create ( OlmoEConfiguration . self, OlmoEModel . init) ,
61+ " olmo2 " : create ( Olmo2Configuration . self, Olmo2Model . init) ,
62+ " olmo3 " : create ( Olmo3Configuration . self, Olmo3Model . init) ,
63+ " bailing_moe " : create ( BailingMoeConfiguration . self, BailingMoeModel . init) ,
64+ " lfm2_moe " : create ( LFM2MoEConfiguration . self, LFM2MoEModel . init) ,
65+ " nanochat " : create ( NanoChatConfiguration . self, NanoChatModel . init) ,
66+ " afmoe " : create ( AfMoEConfiguration . self, AfMoEModel . init) ,
67+ " jamba_3b " : create ( JambaConfiguration . self, JambaModel . init) ,
68+ " mistral3 " : create ( Mistral3TextConfiguration . self, Mistral3TextModel . init) ,
69+ ] )
7570}
7671
7772/// Registry of models and any overrides that go with them, e.g. prompt augmentation.
@@ -458,7 +453,7 @@ private struct LLMUserInputProcessor: UserInputProcessor {
458453/// let modelContainer = try await LLMModelFactory.shared.loadContainer(
459454/// configuration: LLMRegistry.llama3_8B_4bit)
460455/// ```
461- public class LLMModelFactory : ModelFactory {
456+ public final class LLMModelFactory : ModelFactory {
462457
463458 public init ( typeRegistry: ModelTypeRegistry , modelRegistry: AbstractModelRegistry ) {
464459 self . typeRegistry = typeRegistry
@@ -478,7 +473,7 @@ public class LLMModelFactory: ModelFactory {
478473 public func _load(
479474 hub: HubApi , configuration: ModelConfiguration ,
480475 progressHandler: @Sendable @escaping ( Progress ) -> Void
481- ) async throws -> sending ModelContext {
476+ ) async throws -> ModelContext {
482477 // download weights and config
483478 let modelDirectory = try await downloadModel (
484479 hub: hub, configuration: configuration, progressHandler: progressHandler)
@@ -497,7 +492,7 @@ public class LLMModelFactory: ModelFactory {
497492
498493 let model : LanguageModel
499494 do {
500- model = try typeRegistry. createModel (
495+ model = try await typeRegistry. createModel (
501496 configuration: configurationURL, modelType: baseConfig. modelType)
502497 } catch let error as DecodingError {
503498 throw ModelFactoryError . configurationDecodingError (
0 commit comments