Skip to content

Commit 71872d6

Browse files
committed
fix dynamic load bug
1 parent b435639 commit 71872d6

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

fastdeploy/model_executor/layers/quantization/kv_cache.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,11 @@ def process_weights_after_loading(self, layer: nn.Layer):
263263
"""
264264
use for loader v1
265265
"""
266-
if layer.cache_k_scale._is_initialized():
267-
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale)
268-
if layer.cache_v_scale._is_initialized():
269-
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
266+
if "block_wise" not in layer.cache_quant_type_str:
267+
if layer.cache_k_scale._is_initialized():
268+
layer.cache_k_out_scale.set_value(1 / layer.cache_k_scale)
269+
if layer.cache_v_scale._is_initialized():
270+
layer.cache_v_out_scale.set_value(1 / layer.cache_v_scale)
270271

271272
def apply(self, layer):
272273
"""

fastdeploy/model_executor/model_loader/default_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
9999
def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
100100
"""use for rl model load"""
101101
# (TODO:gaoziyuan) optimze
102+
assert fd_config.load_config.load_strategy == "normal", fd_config.load_config.load_strategy
102103
original_architectures = fd_config.model_config.architectures[0]
103104
logger.info(f"Starting to load model {original_architectures}.")
104105

@@ -110,16 +111,15 @@ def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
110111
model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
111112

112113
model_architectures += "RL"
113-
context = paddle.LazyGuard()
114+
context = contextlib.nullcontext()
114115

115116
with context:
116117
model_cls = ModelRegistry.get_class(model_architectures)
117118
model = model_cls(fd_config)
118119

119120
model.eval()
120121

121-
if fd_config.load_config.load_strategy == "normal":
122-
# normal strategy need load weight and architectures need without "RL"
123-
self.load_weights(model, fd_config, original_architectures)
122+
# normal strategy need load weight and architectures need without "RL"
123+
self.load_weights(model, fd_config, original_architectures)
124124
# RL model not need set_state_dict
125125
return model

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
102102
def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
103103
"""use for rl model load"""
104104
# (TODO:gaoziyuan) optimze
105+
assert fd_config.load_config.load_strategy == "normal", fd_config.load_config.load_strategy
105106
original_architectures = fd_config.model_config.architectures[0]
106107

107108
import fastdeploy.rl # noqa
@@ -120,8 +121,7 @@ def load_rl_mock_model(self, fd_config: FDConfig) -> nn.Layer:
120121

121122
model.eval()
122123

123-
if fd_config.load_config.load_strategy == "normal":
124-
# normal strategy need load weight and architectures need without "RL"
125-
self.load_weights(model, fd_config, original_architectures)
124+
# normal strategy need load weight and architectures need without "RL"
125+
self.load_weights(model, fd_config, original_architectures)
126126
# RL model not need set_state_dict
127127
return model

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""
1616

17+
import gc
1718
import os
1819
import time
1920
from multiprocessing.shared_memory import SharedMemory
@@ -107,8 +108,12 @@ def _normal_load_weight(self):
107108
from fastdeploy.model_executor.model_loader import get_model_loader
108109

109110
model_loader = get_model_loader(load_config=self.fd_config.load_config)
110-
state_dict = model_loader.load_rl_mock_model(fd_config=self.fd_config).state_dict()
111+
model = model_loader.load_rl_mock_model(fd_config=self.fd_config)
112+
state_dict = model.state_dict()
111113
self._update_model_from_state(state_dict, "raw")
114+
del model
115+
del state_dict
116+
gc.collect()
112117

113118
def _update_ipc_snapshot(self):
114119
"""Update using IPC snapshot strategy for elastic recovery."""

0 commit comments

Comments
 (0)