
fix(rpc): Improve input validation and error handling #13069


Merged
merged 8 commits on Apr 28, 2025
78 changes: 68 additions & 10 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -982,8 +982,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
}

ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
// Validate tensor type before using it
if (tensor->type >= GGML_TYPE_COUNT) {
GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
return nullptr;
}

ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);

// ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
Member

If ggml_new_tensor fails it will crash, it will not return NULL. The check is still good for future-proofing, but the comment is misleading.

if (result == nullptr) {
GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
return nullptr;
}

for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
result->nb[i] = tensor->nb[i];
}
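
The new type check matters because the wire value is an arbitrary uint32_t that gets cast to ggml_type and later used to look up per-type metadata; anything at or above GGML_TYPE_COUNT has to be rejected before that cast. A standalone illustration of the idea with a hypothetical enum and size table (not the real ggml internals):

```cpp
// Standalone illustration (hypothetical names, not the ggml API) of why the
// range check comes first: an unchecked enum value read off the wire would
// otherwise be used as an index into a per-type table.
#include <cstdint>
#include <cstdio>

enum demo_type : uint32_t { DEMO_F32, DEMO_F16, DEMO_Q8, DEMO_TYPE_COUNT };

static const size_t demo_type_size[DEMO_TYPE_COUNT] = { 4, 2, 1 };

// Returns 0 for an invalid type instead of reading past the table.
static size_t safe_type_size(uint32_t wire_type) {
    if (wire_type >= DEMO_TYPE_COUNT) {
        std::fprintf(stderr, "invalid tensor type received: %u\n", wire_type);
        return 0;
    }
    return demo_type_size[wire_type];
}

int main() {
    std::printf("%zu %zu\n", safe_type_size(DEMO_F16), safe_type_size(1234)); // prints "2 0"
    return 0;
}
```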
@@ -1043,7 +1056,9 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, in_tensor->data, offset, size, p0, p1);
return false;
}
}
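
The same predicate is reused in set_tensor_hash and get_tensor below, so it is worth seeing in isolation: the requested region [data + offset, data + offset + size) has to fall inside the buffer [p0, p1), and the size is compared against the remaining space so the test never has to form data + offset + size. A self-contained sketch, using a hypothetical region_in_bounds helper rather than code from the patch:

```cpp
// Self-contained sketch (hypothetical helper, not code from the patch) of the
// bounds predicate: the region [data + offset, data + offset + size) must lie
// inside the buffer [p0, p1). The size is checked against the remaining space
// p1 - (data + offset), so the test never computes data + offset + size,
// which could wrap around for a hostile size value.
#include <cstdint>
#include <cstdio>

static bool region_in_bounds(uint64_t data, uint64_t offset, uint64_t size,
                             uint64_t p0, uint64_t p1) {
    const uint64_t start = data + offset;
    if (start < p0 || start >= p1) {
        return false;                 // region does not even start inside the buffer
    }
    return size <= p1 - start;        // fits in the space remaining after start
}

int main() {
    const uint64_t p0 = 0x1000, p1 = 0x2000;                                // a 4 KiB buffer
    std::printf("%d\n", region_in_bounds(0x1000, 0x800, 0x800, p0, p1));    // 1: exactly fills the tail
    std::printf("%d\n", region_in_bounds(0x1000, 0x800, 0x801, p0, p1));    // 0: one byte past the end
    std::printf("%d\n", region_in_bounds(0x0800, 0x000, 0x010, p0, p1));    // 0: starts before the buffer
    return 0;
}
```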

@@ -1118,7 +1133,9 @@ bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);

if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, in_tensor->data, offset, size, *hash, p0, p1);
return false;
}
}
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
@@ -1183,7 +1200,9 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
if (request.tensor.data + request.offset < p0 ||
request.tensor.data + request.offset >= p1 ||
request.size > (p1 - request.tensor.data - request.offset)) {
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, request.tensor.data, request.offset, request.size, p0, p1);
return false;
}
}

@@ -1237,22 +1256,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
struct ggml_context * ctx,
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
if (id == 0) {
return nullptr;
}
if (tensor_map.find(id) != tensor_map.end()) {
return tensor_map[id];
}
const rpc_tensor * tensor = tensor_ptrs.at(id);
// Safely find the tensor pointer
auto it_ptr = tensor_ptrs.find(id);
if (it_ptr == tensor_ptrs.end()) {
return nullptr;
}
const rpc_tensor * tensor = it_ptr->second;

struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
if (result == nullptr) {
return nullptr;
}
tensor_map[id] = result;
for (int i = 0; i < GGML_MAX_SRC; i++) {
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
// Check if the source ID is 0 before calling create_node recursively
if (tensor->src[i] == 0) {
result->src[i] = nullptr;
} else {
result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
// If the recursive call failed for a non-zero ID, propagate the error
if (result->src[i] == nullptr) {
GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
__func__, i, tensor->src[i], id);
// Must return nullptr to signal failure up the call stack
return nullptr;
}
}
}

// Handle view_src similarly
if (tensor->view_src == 0) {
result->view_src = nullptr;
} else {
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
// If the recursive call failed for a non-zero ID, propagate the error
if (result->view_src == nullptr) {
GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
__func__, tensor->view_src, id);
// Must return nullptr to signal failure up the call stack
return nullptr;
}
}
result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
result->view_offs = tensor->view_offs;
return result;
}
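
The control flow above is easier to see stripped of the ggml types: create_node treats id 0 as "no tensor", memoizes finished nodes in tensor_map so shared sources are built once and cycles terminate, and converts every failed lookup or failed child into a nullptr for the caller to check. A simplified, self-contained sketch of that pattern with hypothetical node structs (not the ggml ones):

```cpp
#include <cstdint>
#include <cstdio>
#include <deque>
#include <unordered_map>
#include <vector>

// Wire-format node: an id plus the ids of its sources (0 = no source).
struct wire_node  { uint64_t id; std::vector<uint64_t> src; };
// In-memory node built from the wire format.
struct built_node { uint64_t id; std::vector<built_node*> src; };

static built_node * build(uint64_t id,
                          const std::unordered_map<uint64_t, const wire_node*> & wire,
                          std::unordered_map<uint64_t, built_node*> & cache,
                          std::deque<built_node> & storage) {
    if (id == 0) {
        return nullptr;                       // "no tensor" sentinel, not an error
    }
    auto cached = cache.find(id);
    if (cached != cache.end()) {
        return cached->second;                // shared source: reuse the existing node
    }
    auto w = wire.find(id);
    if (w == wire.end()) {
        return nullptr;                       // unknown id: fail instead of throwing
    }
    storage.emplace_back(built_node{id, {}});
    built_node * node = &storage.back();
    cache[id] = node;                         // cache before recursing so cycles terminate
    for (uint64_t src_id : w->second->src) {
        built_node * child = build(src_id, wire, cache, storage);
        if (src_id != 0 && child == nullptr) {
            return nullptr;                   // real failure in a child: propagate it
        }
        node->src.push_back(child);
    }
    return node;
}

int main() {
    wire_node a{1, {2, 0}};                   // node 1 has one real source (2) and one empty slot
    wire_node b{2, {}};
    std::unordered_map<uint64_t, const wire_node*> wire = {{1, &a}, {2, &b}};
    std::unordered_map<uint64_t, built_node*>      cache;
    std::deque<built_node>                         storage;
    built_node * root = build(1, wire, cache, storage);
    std::printf("root %s with %zu source slots\n", root ? "built" : "missing",
                root ? root->src.size() : (size_t) 0);
    return 0;
}
```

Caching the node before recursing mirrors the real code, where tensor_map[id] = result precedes the source loop; that is what keeps shared subgraphs from being deserialized more than once.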
@@ -1278,6 +1325,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);

size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ NULL,
@@ -1297,6 +1345,14 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
int64_t id;
memcpy(&id, &nodes[i], sizeof(id));
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);

// Check if create_node failed for a *non-zero* ID.
// If id was 0, create_node returning nullptr is expected.
// If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
if (graph->nodes[i] == nullptr && id != 0) {
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
return false;
}
}
ggml_status status = ggml_backend_graph_compute(backend, graph);
response.result = status;
@@ -1361,7 +1417,9 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
return;
}
rpc_msg_get_alloc_size_rsp response;
server.get_alloc_size(request, response);
if (!server.get_alloc_size(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}