File tree Expand file tree Collapse file tree 1 file changed +29
-2
lines changed
bitsandbytes/backends/default Expand file tree Collapse file tree 1 file changed +29
-2
lines changed Original file line number Diff line number Diff line change 11from collections .abc import Sequence
2+ from functools import wraps
23from math import prod , sqrt
34from typing import Optional
45
89from ..utils import CODE
910
1011
12+ def _try_torch_compile (func = None , ** compile_kwargs ):
13+ """
14+ Wrapper around torch.compile that falls back to the original function if compilation fails.
15+ """
16+
17+ def decorator (fn ):
18+ try :
19+ compiled_fn = torch .compile (fn , ** compile_kwargs )
20+
21+ @wraps (fn )
22+ def wrapper (* args , ** kwargs ):
23+ try :
24+ return compiled_fn (* args , ** kwargs )
25+ except Exception :
26+ return fn (* args , ** kwargs )
27+
28+ return wrapper
29+ except Exception :
30+ return fn
31+
32+ if func is None :
33+ return decorator
34+ else :
35+ return decorator (func )
36+
37+
1138@register_kernel ("bitsandbytes::int8_mm_dequant" , "default" )
1239def _ (
1340 A : torch .Tensor ,
@@ -332,7 +359,7 @@ def _(
332359}
333360
334361
335- @torch . compile
362+ @_try_torch_compile
336363def _optimizer_precondition_32bit (
337364 g : torch .Tensor ,
338365 p : torch .Tensor ,
@@ -393,7 +420,7 @@ def _optimizer_precondition_32bit(
393420 unorm_vec .add_ (total_norm )
394421
395422
396- @torch . compile
423+ @_try_torch_compile
397424def _optimizer_update_32bit (
398425 g : torch .Tensor ,
399426 p : torch .Tensor ,
You can’t perform that action at this time.
0 commit comments