Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 73 additions & 68 deletions custom_ops/gpu_ops/get_output_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"
#include "msg_utils.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
Expand All @@ -28,96 +28,101 @@
void GetOutputKVSignal(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024 + rank_id;
static struct msgdatakv msg_rcv;
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
int msg_queue_id = 1024;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only this spot is a functional change; everything else in this file is formatting-only.

if (const char* msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
msg_queue_id += rank_id;
static struct msgdatakv msg_rcv;
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);

int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -1;
out_data[1] = -1;
return;
}
int encoder_count = msg_rcv.mtext[0];

for (int i = 0; i < encoder_count * 3 + 2; i++) {
out_data[i] = msg_rcv.mtext[i];
}
int* out_data = const_cast<int*>(x.data<int>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ * 3 + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -1;
out_data[1] = -1;
return;
}
int encoder_count = msg_rcv.mtext[0];

for (int i = 0; i < encoder_count * 3 + 2; i++) {
out_data[i] = msg_rcv.mtext[i];
}
return;
}

void GetOutputEp(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
static struct msgdata msg_rcv;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
static struct msgdata msg_rcv;
if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p);
int inference_msg_queue_id_from_env =
std::stoi(inference_msg_queue_id_env_str);
#ifdef GET_OUTPUT_DEBUG
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
<< inference_msg_queue_id_from_env << std::endl;
#endif
msg_queue_id = inference_msg_queue_id_from_env;
}
msg_queue_id = inference_msg_queue_id_from_env;
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "msg_queue_id is: "
<< msg_queue_id << std::endl;
std::cout << "msg_queue_id is: " << msg_queue_id << std::endl;
#endif
// static key_t key = ftok("/dev/shm", msg_queue_id);
// static int msgid = msgget(key, IPC_CREAT | 0666);
// static key_t key = ftok("/dev/shm", msg_queue_id);
// static int msgid = msgget(key, IPC_CREAT | 0666);

key_t key = ftok("/dev/shm", msg_queue_id);
int msgid = msgget(key, IPC_CREAT | 0666);
key_t key = ftok("/dev/shm", msg_queue_id);
int msgid = msgget(key, IPC_CREAT | 0666);

#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
std::cout << "get_output_key: " << key << std::endl;
std::cout << "get_output msgid: " << msgid << std::endl;
#endif

int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];
int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
}
if (ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];

for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
for (int64_t i = 0; i < bsz + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
#ifdef GET_OUTPUT_DEBUG
std::cout << "get_output finished: " << msgid << std::endl;
std::cout << "get_output finished: " << msgid << std::endl;
#endif

return;
return;
}

void GetOutputEPStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) {
GetOutputEp(x, rank_id, wait_flag, 1);
void GetOutputEPStatic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
GetOutputEp(x, rank_id, wait_flag, 1);
}

void GetOutputEPDynamic(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
GetOutputEp(x, rank_id, wait_flag, msg_queue_id);
int64_t rank_id,
bool wait_flag,
int msg_queue_id) {
GetOutputEp(x, rank_id, wait_flag, msg_queue_id);
}

PD_BUILD_STATIC_OP(get_output_ep)
Expand Down
10 changes: 8 additions & 2 deletions custom_ops/gpu_ops/remote_cache_kv_ipc.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,14 @@ struct RemoteCacheKvIpc {

if (!inited) {
// just init once
const int msg_id = 1024 + rank;
key_t key = ftok("/opt/", msg_id);
int msg_queue_id = 1024;
if (const char* msg_que_str_tmp =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
msg_queue_id += rank;
key_t key = ftok("/opt/", msg_queue_id);
msgid = msgget(key, IPC_CREAT | 0666);
inited = true;
}
Expand Down
7 changes: 6 additions & 1 deletion custom_ops/xpu_ops/src/ops/get_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
void GetOutputKVSignal(const paddle::Tensor &x,
int64_t rank_id,
bool wait_flag) {
int msg_queue_id = 1024 + rank_id;
int msg_queue_id = 1024;
if (const char *msg_que_str_tmp = std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
msg_queue_id += rank_id;
static struct msgdatakv msg_rcv;
static key_t key = ftok("/opt/", msg_queue_id);
static int msgid = msgget(key, IPC_CREAT | 0666);
Expand Down
10 changes: 8 additions & 2 deletions custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,14 @@ struct RemoteCacheKvIpc {

if (!inited) {
// just init once
const int msg_id = 1024 + rank;
key_t key = ftok("/opt/", msg_id);
int msg_queue_id = 1024;
if (const char* msg_que_str_tmp =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string msg_que_str(msg_que_str_tmp);
msg_queue_id = std::stoi(msg_que_str);
}
msg_queue_id += rank;
key_t key = ftok("/opt/", msg_queue_id);
msgid = msgget(key, IPC_CREAT | 0666);
inited = true;
}
Expand Down
Loading