diff --git a/src/code/agent/constants.py b/src/code/agent/constants.py index 60964992..8a53f0b8 100644 --- a/src/code/agent/constants.py +++ b/src/code/agent/constants.py @@ -75,7 +75,9 @@ OSS_OUTPUT_DOMAIN = os.getenv("OSS_OUTPUT_DOMAIN", "") OSS_EXPIRES_IN_SECOND = os.getenv("OSS_EXPIRES_IN_SECOND", "") + PREWARM_PROMPT = os.getenv("PREWARM_PROMPT", "") +PREWARM_TIMEOUT = int(os.getenv("PREWARM_TIMEOUT", "600")) # 单位:秒,默认10分钟 class ERROR_CODE(Enum): UNCLASSIFY = "UNCLASSIFY" diff --git a/src/code/agent/routes/routes.py b/src/code/agent/routes/routes.py index b590a8dd..d3a8e1e9 100644 --- a/src/code/agent/routes/routes.py +++ b/src/code/agent/routes/routes.py @@ -1,6 +1,7 @@ import json import logging import threading +import time import traceback import requests @@ -54,18 +55,32 @@ def initialize(): constants.PREWARM_PROMPT and constants.BACKEND_TYPE == constants.TYPE_COMFYUI ): - try: - print("prewarm models") - prompt = json.loads(constants.PREWARM_PROMPT) - api = ServerlessApiService() - api.run(prompt) - api.api_clear_history() - print("prewarm models done") - except Exception as e: - print(f"prewarm models got exception:\n{e}") + prewarm_done = threading.Event() + prewarm_result = {"exception": None} + + t = threading.Thread(target=self.prewarm_func, args=(prewarm_done, prewarm_result)) + t.start() + t.join(timeout=constants.PREWARM_TIMEOUT) + if not prewarm_done.is_set(): + print(f"prewarm timeout after {constants.PREWARM_TIMEOUT} seconds, skip prewarm and continue.") + # 如果有异常也继续往下执行 print("FC Initialize End RequestId: " + request_id) - return "Function is initialized, request_id: " + request_id + "\n" + + def prewarm_func(self, prewarm_done, prewarm_result): + try: + print("prewarm models") + prompt = json.loads(constants.PREWARM_PROMPT) + api = ServerlessApiService() + api.run(prompt) + api.api_clear_history() + print("prewarm models done") + except Exception as e: + prewarm_result["exception"] = e + print(f"prewarm models got exception:\n{e}") + finally: + prewarm_done.set() + @self.app.route("/pre-stop", methods=["GET"]) def pre_stop():