Skip to content

Commit ff5a6b4

Browse files
Fix: Python 3.14 / torch.compile compatibility
1 parent 364b00e commit ff5a6b4

File tree

1 file changed

+29
-2
lines changed
  • bitsandbytes/backends/default

1 file changed

+29
-2
lines changed

bitsandbytes/backends/default/ops.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
from functools import wraps
23
from math import prod, sqrt
34
from typing import Optional
45

@@ -8,6 +9,32 @@
89
from ..utils import CODE
910

1011

12+
def _try_torch_compile(func=None, **compile_kwargs):
13+
"""
14+
Wrapper around torch.compile that falls back to the original function if compilation fails.
15+
"""
16+
17+
def decorator(fn):
18+
try:
19+
compiled_fn = torch.compile(fn, **compile_kwargs)
20+
21+
@wraps(fn)
22+
def wrapper(*args, **kwargs):
23+
try:
24+
return compiled_fn(*args, **kwargs)
25+
except Exception:
26+
return fn(*args, **kwargs)
27+
28+
return wrapper
29+
except Exception:
30+
return fn
31+
32+
if func is None:
33+
return decorator
34+
else:
35+
return decorator(func)
36+
37+
1138
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
1239
def _(
1340
A: torch.Tensor,
@@ -332,7 +359,7 @@ def _(
332359
}
333360

334361

335-
@torch.compile
362+
@_try_torch_compile
336363
def _optimizer_precondition_32bit(
337364
g: torch.Tensor,
338365
p: torch.Tensor,
@@ -393,7 +420,7 @@ def _optimizer_precondition_32bit(
393420
unorm_vec.add_(total_norm)
394421

395422

396-
@torch.compile
423+
@_try_torch_compile
397424
def _optimizer_update_32bit(
398425
g: torch.Tensor,
399426
p: torch.Tensor,

0 commit comments

Comments
 (0)