5 changes: 5 additions & 0 deletions fastdeploy/model_executor/layers/attention/attention.py
@@ -230,6 +230,11 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).cast(paddle.get_default_dtype())

if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
Copilot AI commented on Dec 9, 2025

Type conversion inconsistency: The weight_loader casts to paddle.get_default_dtype() at line 232, but q_norm and k_norm weights are defined with dtype "float32" (lines 203, 211) and explicitly cast to float32 in load_state_dict (lines 224-225).

Consider casting to float32 for consistency:

if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
    loaded_weight = get_tensor(loaded_weight).astype("float32")
    param.copy_(loaded_weight, False)
    return
Suggested change:
 if self.use_qk_norm and ("q_norm" in param.name or "k_norm" in param.name):
+    loaded_weight = loaded_weight.astype("float32")
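
For context, a hypothetical repro of the precision concern (a sketch assuming a Paddle build where bfloat16 is the default dtype while the norm weights stay in float32):

import paddle

paddle.set_default_dtype("bfloat16")
# 1 + 2^-10 is exactly representable in float32 but not in bfloat16.
w = paddle.to_tensor([1.0009765625], dtype="float32")
# Routing through the default dtype first rounds away the low bits:
via_default = w.cast(paddle.get_default_dtype()).cast("float32")
print(float(w[0]), float(via_default[0]))  # 1.0009765625 vs 1.0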

param.copy_(loaded_weight, False)
return

if self.quant_method.cache_quant_config.has_zero_point: # cache_int4_zp
loaded_weight = 1.0 / loaded_weight
else:
52 changes: 52 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
@@ -25,6 +25,7 @@
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
fd_cast,
h2d_copy,
process_weight_transpose,
set_weight_attrs,
@@ -878,6 +879,57 @@ def __init__(
if self.with_bias and self.tp_size > 1 and self.reduce_results:
set_weight_attrs(self.bias, {"tp_row_bias": True})

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
# In some senerio such as tsp, weight and bias of this layer will not be split in specific module.
Copilot AI commented on Dec 9, 2025

Typo in comment: "senerio" should be "scenario".

Suggested change:
-# In some senerio such as tsp, weight and bias of this layer will not be split in specific module.
+# In some scenario such as tsp, weight and bias of this layer will not be split in specific module.

# For example, weight and bias of this layer in shared_experts will not split, but will be split in o_proj.
# So, we add a white list to avoid split weight and bias in these layers.
layer_white_list = ["shared_experts"]
layer_in_white_list = any(key in self.prefix for key in layer_white_list)

output_dim = getattr(param, "output_dim", None)
weight_need_transpose = getattr(param, "weight_need_transpose", False)
if weight_need_transpose:
loaded_weight = loaded_weight.transpose([1, 0])
# Tensor parallelism splits the weight along the output_dim
if (
output_dim is not None
and self.fd_config is not None
and self.fd_config.parallel_config.tensor_parallel_size > 1
):
dim = -1 if output_dim else 0
if isinstance(loaded_weight, paddle.Tensor):
size = loaded_weight.shape[dim]
else:
size = loaded_weight.get_shape()[dim]
block_size = size // self.fd_config.parallel_config.tensor_parallel_size
shard_offset = self.fd_config.parallel_config.tensor_parallel_rank * block_size
shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size

# when use_sequence_parallel_moe, we don't split.
if layer_in_white_list:
pass
else:
loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
Comment on lines +906 to +912
Copilot AI commented on Dec 9, 2025

The variable name shard_size is misleading. It's actually used as the end index (absolute position), not a size. Consider renaming to shard_end for clarity, consistent with how slice_fn uses start and end parameters.

Suggested change:
-            shard_size = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size
+            shard_end = (self.fd_config.parallel_config.tensor_parallel_rank + 1) * block_size

             # when use_sequence_parallel_moe, we don't split.
             if layer_in_white_list:
                 pass
             else:
-                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
+                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_end)
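
To make the end-index semantics concrete, a minimal sketch with hypothetical sizes (size=12, tp_size=4); each rank reads the half-open interval [rank * block_size, (rank + 1) * block_size):

size, tp_size = 12, 4
block_size = size // tp_size  # 3
for rank in range(tp_size):
    shard_offset = rank * block_size
    shard_end = (rank + 1) * block_size  # absolute end index, not a length
    print(rank, (shard_offset, shard_end))
# 0 (0, 3)   1 (3, 6)   2 (6, 9)   3 (9, 12)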

Copilot AI commented on Dec 9, 2025

For consistency with the rest of the codebase, use named parameters when calling slice_fn. Change to:

loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

This matches the pattern used in lines 543, 696, and other weight_loader implementations.

Suggested change:
-                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
+                loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)


tp_row_bias = getattr(param, "tp_row_bias", None)
if layer_in_white_list:
pass
else:
if tp_row_bias:
loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size
Comment on lines +909 to +919
Copilot AI commented on Dec 9, 2025

[nitpick] The empty pass statements in the whitelist check reduce code readability. Consider refactoring to a more explicit pattern:

if not layer_in_white_list:
    loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

This eliminates the unnecessary pass statements and makes the control flow clearer.

Suggested change:
-            if layer_in_white_list:
-                pass
-            else:
-                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
+            if not layer_in_white_list:
+                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)

         tp_row_bias = getattr(param, "tp_row_bias", None)
-        if layer_in_white_list:
-            pass
-        else:
-            if tp_row_bias:
-                loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size
+        if not layer_in_white_list and tp_row_bias:
+            loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size

Comment on lines +909 to +919
Copilot AI commented on Dec 9, 2025

[nitpick] Similar to above, the empty pass statements for the whitelist check reduce readability. Consider refactoring to:

if tp_row_bias and not layer_in_white_list:
    loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size
Suggested change:
-            if layer_in_white_list:
-                pass
-            else:
-                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)
+            if not layer_in_white_list:
+                loaded_weight = slice_fn(loaded_weight, output_dim, shard_offset, shard_size)

         tp_row_bias = getattr(param, "tp_row_bias", None)
-        if layer_in_white_list:
-            pass
-        else:
-            if tp_row_bias:
-                loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size
+        if tp_row_bias and not layer_in_white_list:
+            loaded_weight = loaded_weight / self.fd_config.parallel_config.tensor_parallel_size
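
As an aside on the quoted code: the division by tensor_parallel_size is presumably there because a row-parallel matmul finishes with an all-reduce across ranks, so each rank should contribute only bias / tp_size for the bias to end up added exactly once. A NumPy sketch of that identity (hypothetical shapes, not the FastDeploy implementation):

import numpy as np

tp_size = 4
x = np.random.rand(2, 8)
W = np.random.rand(8, 3)
bias = np.random.rand(3)

# Row parallelism splits the contraction dimension across ranks.
x_shards = np.split(x, tp_size, axis=1)
W_shards = np.split(W, tp_size, axis=0)
partials = [xs @ Ws + bias / tp_size for xs, Ws in zip(x_shards, W_shards)]

# The all-reduce (modeled here as a plain sum) recovers x @ W + bias
# with the bias counted exactly once.
assert np.allclose(sum(partials), x @ W + bias)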


# mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
loaded_weight = fd_cast(loaded_weight, param)

if param.shape != loaded_weight.shape:
# for e_score_correction_bias
loaded_weight = loaded_weight.reshape(param.shape)
assert param.shape == loaded_weight.shape, (
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
)
loaded_weight = get_tensor(loaded_weight)
param.copy_(loaded_weight, False)

def all2all_transpose(self, x: paddle.Tensor) -> paddle.Tensor:
token_num = x.shape[0]
token_num_pad = (token_num + self.tp_size - 1) // self.tp_size * self.tp_size
3 changes: 3 additions & 0 deletions fastdeploy/model_executor/layers/lm_head.py
@@ -102,6 +102,9 @@ def __init__(
},
)
set_weight_attrs(self.linear.weight, {"output_dim": True})
if with_bias:
set_weight_attrs(self.linear.bias, {"output_dim": True})

else:
self.linear = RowParallelLinear(
embedding_dim,
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/mtp_linear.py
@@ -86,6 +86,8 @@ def __init__(
)
if self.tp_size > 1:
set_weight_attrs(self.linear.weight, {"output_dim": True})
set_weight_attrs(self.linear.bias, {"output_dim": True})
Copilot AI commented on Dec 9, 2025

Potential bug: Setting attributes on self.linear.bias without checking if bias exists. When with_bias=False, self.linear.bias may not exist, which would cause an AttributeError.

Add a conditional check:

if self.tp_size > 1:
    set_weight_attrs(self.linear.weight, {"output_dim": True})
    if self.bias_key is not None:
        set_weight_attrs(self.linear.bias, {"output_dim": True})
Suggested change:
-            set_weight_attrs(self.linear.bias, {"output_dim": True})
+            if self.bias_key is not None:
+                set_weight_attrs(self.linear.bias, {"output_dim": True})


else:
self.linear = RowParallelLinear(
embedding_dim,
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
@@ -130,6 +130,10 @@ def init_weight(self):
dtype=self._norm_weight_dtype,
)

def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
loaded_weight = get_tensor(loaded_weight).astype(self._norm_weight_dtype)
param.copy_(loaded_weight, False)

def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
"""
Load the checkpoint state dictionary into the layer.