diff --git a/custom_ops/gpu_ops/get_output_ep.cc b/custom_ops/gpu_ops/get_output_ep.cc index 68730615f23..7cb78ed49f5 100644 --- a/custom_ops/gpu_ops/get_output_ep.cc +++ b/custom_ops/gpu_ops/get_output_ep.cc @@ -17,8 +17,8 @@ #include #include #include -#include "paddle/extension.h" #include "msg_utils.h" +#include "paddle/extension.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -28,96 +28,101 @@ void GetOutputKVSignal(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) { - int msg_queue_id = 1024 + rank_id; - static struct msgdatakv msg_rcv; - static key_t key = ftok("/opt/", msg_queue_id); - static int msgid = msgget(key, IPC_CREAT | 0666); + int msg_queue_id = 1024; + if (const char* msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string msg_que_str(msg_que_str_tmp); + msg_queue_id = std::stoi(msg_que_str); + } + msg_queue_id += rank_id; + static struct msgdatakv msg_rcv; + static key_t key = ftok("/opt/", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); - int* out_data = const_cast(x.data()); - int ret = -1; - if (!wait_flag) { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT); - } else { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0); - } - if (ret == -1) { - out_data[0] = -1; - out_data[1] = -1; - return; - } - int encoder_count = msg_rcv.mtext[0]; - - for (int i = 0; i < encoder_count * 3 + 2; i++) { - out_data[i] = msg_rcv.mtext[i]; - } + int* out_data = const_cast(x.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0); + } + if (ret == -1) { + out_data[0] = -1; + out_data[1] = -1; return; + } + int encoder_count = msg_rcv.mtext[0]; + + for (int i = 0; i < encoder_count * 3 + 2; i++) { + out_data[i] = msg_rcv.mtext[i]; + } + return; } void GetOutputEp(const paddle::Tensor& x, - int64_t rank_id, - bool wait_flag, - int msg_queue_id) { - static struct msgdata msg_rcv; - if (const char* inference_msg_queue_id_env_p = - std::getenv("INFERENCE_MSG_QUEUE_ID")) { - std::string inference_msg_queue_id_env_str( - inference_msg_queue_id_env_p); - int inference_msg_queue_id_from_env = - std::stoi(inference_msg_queue_id_env_str); + int64_t rank_id, + bool wait_flag, + int msg_queue_id) { + static struct msgdata msg_rcv; + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); #ifdef GET_OUTPUT_DEBUG - std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " - << inference_msg_queue_id_from_env << std::endl; + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; #endif - msg_queue_id = inference_msg_queue_id_from_env; - } + msg_queue_id = inference_msg_queue_id_from_env; + } #ifdef GET_OUTPUT_DEBUG - std::cout << "msg_queue_id is: " - << msg_queue_id << std::endl; + std::cout << "msg_queue_id is: " << msg_queue_id << std::endl; #endif - // static key_t key = ftok("/dev/shm", msg_queue_id); - // static int msgid = msgget(key, IPC_CREAT | 0666); + // static key_t key = ftok("/dev/shm", msg_queue_id); + // static int msgid = msgget(key, IPC_CREAT | 0666); - key_t key = ftok("/dev/shm", msg_queue_id); - int msgid = msgget(key, IPC_CREAT | 0666); + key_t key = ftok("/dev/shm", msg_queue_id); + int msgid = msgget(key, IPC_CREAT | 0666); #ifdef GET_OUTPUT_DEBUG - std::cout << "get_output_key: " << key << std::endl; - std::cout << "get_output msgid: " << msgid << std::endl; + std::cout << "get_output_key: " << key << std::endl; + std::cout << "get_output msgid: " << msgid << std::endl; #endif - int64_t* out_data = const_cast(x.data()); - int ret = -1; - if (!wait_flag) { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); - } else { - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); - } - if (ret == -1) { - out_data[0] = -2; - out_data[1] = 0; - return; - } - int bsz = msg_rcv.mtext[1]; + int64_t* out_data = const_cast(x.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); + } + if (ret == -1) { + out_data[0] = -2; + out_data[1] = 0; + return; + } + int bsz = msg_rcv.mtext[1]; - for (int64_t i = 0; i < bsz + 2; i++) { - out_data[i] = (int64_t)msg_rcv.mtext[i]; - } + for (int64_t i = 0; i < bsz + 2; i++) { + out_data[i] = (int64_t)msg_rcv.mtext[i]; + } #ifdef GET_OUTPUT_DEBUG - std::cout << "get_output finished: " << msgid << std::endl; + std::cout << "get_output finished: " << msgid << std::endl; #endif - return; + return; } -void GetOutputEPStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) { - GetOutputEp(x, rank_id, wait_flag, 1); +void GetOutputEPStatic(const paddle::Tensor& x, + int64_t rank_id, + bool wait_flag) { + GetOutputEp(x, rank_id, wait_flag, 1); } void GetOutputEPDynamic(const paddle::Tensor& x, - int64_t rank_id, - bool wait_flag, - int msg_queue_id) { - GetOutputEp(x, rank_id, wait_flag, msg_queue_id); + int64_t rank_id, + bool wait_flag, + int msg_queue_id) { + GetOutputEp(x, rank_id, wait_flag, msg_queue_id); } PD_BUILD_STATIC_OP(get_output_ep) diff --git a/custom_ops/gpu_ops/remote_cache_kv_ipc.h b/custom_ops/gpu_ops/remote_cache_kv_ipc.h index 759e1d65012..2ed5f466d33 100644 --- a/custom_ops/gpu_ops/remote_cache_kv_ipc.h +++ b/custom_ops/gpu_ops/remote_cache_kv_ipc.h @@ -73,8 +73,14 @@ struct RemoteCacheKvIpc { if (!inited) { // just init once - const int msg_id = 1024 + rank; - key_t key = ftok("/opt/", msg_id); + int msg_queue_id = 1024; + if (const char* msg_que_str_tmp = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string msg_que_str(msg_que_str_tmp); + msg_queue_id = std::stoi(msg_que_str); + } + msg_queue_id += rank; + key_t key = ftok("/opt/", msg_queue_id); msgid = msgget(key, IPC_CREAT | 0666); inited = true; } diff --git a/custom_ops/xpu_ops/src/ops/get_output.cc b/custom_ops/xpu_ops/src/ops/get_output.cc index a1150e0087b..e2cf48aab42 100644 --- a/custom_ops/xpu_ops/src/ops/get_output.cc +++ b/custom_ops/xpu_ops/src/ops/get_output.cc @@ -23,7 +23,12 @@ void GetOutputKVSignal(const paddle::Tensor &x, int64_t rank_id, bool wait_flag) { - int msg_queue_id = 1024 + rank_id; + int msg_queue_id = 1024; + if (const char *msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string msg_que_str(msg_que_str_tmp); + msg_queue_id = std::stoi(msg_que_str); + } + msg_queue_id += rank_id; static struct msgdatakv msg_rcv; static key_t key = ftok("/opt/", msg_queue_id); static int msgid = msgget(key, IPC_CREAT | 0666); diff --git a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h index ff384f02e86..ffbc85c7cbb 100644 --- a/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h +++ b/custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h @@ -64,8 +64,14 @@ struct RemoteCacheKvIpc { if (!inited) { // just init once - const int msg_id = 1024 + rank; - key_t key = ftok("/opt/", msg_id); + int msg_queue_id = 1024; + if (const char* msg_que_str_tmp = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string msg_que_str(msg_que_str_tmp); + msg_queue_id = std::stoi(msg_que_str); + } + msg_queue_id += rank; + key_t key = ftok("/opt/", msg_queue_id); msgid = msgget(key, IPC_CREAT | 0666); inited = true; }