Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use ('/completion' endpoint and 'cache_prompt=false') #13484


Closed
broadbit-hu opened this issue May 12, 2025 · 11 comments · Fixed by #13533


broadbit-hu commented May 12, 2025

Name and Version

llama.cpp version: b5359 (compiled with -DGGML_RPC=ON)

Model: Mistral-Nemo-12B-Instruct-2407-Q8_0.gguf

Command line arguments:

--flash-attn --temp 0 --seed 1 -c 22000 -ngl 99 --mlock --chat-template mistral-v3-tekken

Error: GGML_ASSERT(n <= tokens.size()) failed

  • Memory critical error by agent node-0 (Agent handle: 0x59fc5fabc930) on address 0x7cbd6cc00000. Reason: Memory in use.
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:        ROCm0 model buffer size = 11731.58 MiB
load_tensors:   CPU_Mapped model buffer size =   680.00 MiB
...........................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 22000
llama_context: n_ctx_per_seq = 22000
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 1000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (22000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context:  ROCm_Host  output buffer size =     0.50 MiB
llama_kv_cache_unified: kv_size = 22016, type_k = 'f16', type_v = 'f16', n_layer = 40, can_shift = 1, padding = 256
llama_kv_cache_unified:      ROCm0 KV buffer size =  3440.00 MiB
llama_kv_cache_unified: KV self size  = 3440.00 MiB, K (f16): 1720.00 MiB, V (f16): 1720.00 MiB
llama_context:      ROCm0 compute buffer size =   266.00 MiB
llama_context:  ROCm_Host compute buffer size =    53.01 MiB
llama_context: graph nodes  = 1207
llama_context: graph splits = 2
common_init_from_params: setting dry_penalty_last_n to ctx_size = 22016
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Failed to infer a tool call example (possible template bug)
srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 22016
main: model loaded
main: chat template, chat_template: mistral-v3-tekken, example_format: '[INST]You are a helpful assistant

Hello[/INST]Hi there</s>[INST]How are you?[/INST]'
main: server is listening on http://0.0.0.0:18080 - starting the main loop
srv  update_slots: all slots are idle
srv  log_server_r: request: GET /props 192.168.253.130 200
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 8241
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.248514
/opt/text/llama.cpp/tools/server/utils.hpp:1157: GGML_ASSERT(n <= tokens.size()) failedslot update_slots: id  0 | task 0 | kv cache rm [2048, end)

Memory critical error by agent node-0 (Agent handle: 0x59fc5fabc930) on address 0x7cbd6cc00000. Reason: Memory in use. 
Aborted (core dumped)

Last working version: b5329

load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 40 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 41/41 layers to GPU
load_tensors:        ROCm0 model buffer size = 11731.58 MiB
load_tensors:   CPU_Mapped model buffer size =   680.00 MiB
...........................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 22000
llama_context: n_ctx_per_seq = 22000
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 1000000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (22000) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
llama_context:  ROCm_Host  output buffer size =     0.50 MiB
llama_kv_cache_unified: kv_size = 22016, type_k = 'f16', type_v = 'f16', n_layer = 40, can_shift = 1, padding = 256
llama_kv_cache_unified:      ROCm0 KV buffer size =  3440.00 MiB
llama_kv_cache_unified: KV self size  = 3440.00 MiB, K (f16): 1720.00 MiB, V (f16): 1720.00 MiB
llama_context:      ROCm0 compute buffer size =   266.00 MiB
llama_context:  ROCm_Host compute buffer size =    53.01 MiB
llama_context: graph nodes  = 1207
llama_context: graph splits = 2
common_init_from_params: setting dry_penalty_last_n to ctx_size = 22016
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Failed to infer a tool call example (possible template bug)
srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 22016
main: model loaded
main: chat template, chat_template: mistral-v3-tekken, example_format: '[INST]You are a helpful assistant

Hello[/INST]Hi there</s>[INST]How are you?[/INST]'
main: server is listening on http://0.0.0.0:18080 - starting the main loop
srv  update_slots: all slots are idle
srv  log_server_r: request: GET /props 192.168.253.130 200
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 8241
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.248514
slot update_slots: id  0 | task 0 | kv cache rm [2048, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 4096, n_tokens = 2048, progress = 0.497027
slot update_slots: id  0 | task 0 | kv cache rm [4096, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 6144, n_tokens = 2048, progress = 0.745541
slot update_slots: id  0 | task 0 | kv cache rm [6144, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 8192, n_tokens = 2048, progress = 0.994054
slot update_slots: id  0 | task 0 | kv cache rm [8192, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 8241, n_tokens = 49, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 8241, n_tokens = 49
slot      release: id  0 | task 0 | stop processing: n_past = 8638, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =    5977.53 ms /  8241 tokens (    0.73 ms per token,  1378.66 tokens per second)
       eval time =   12231.95 ms /   398 tokens (   30.73 ms per token,    32.54 tokens per second)
      total time =   18209.48 ms /  8639 tokens
srv  update_slots: all slots are idle

Operating systems

Linux

Which llama.cpp modules do you know to be affected?

llama-server

Command line

llama-server -m Mistral-Nemo-12B-Instruct-2407-Q8_0.gguf --flash-attn --temp 0 --seed 1 -c 22000 -ngl 99 --mlock --chat-template mistral-v3-tekken

Problem description & steps to reproduce

Error GGML_ASSERT(n <= tokens.size()) failed in slot update_slots when the input prompt is long (8241 tokens with a 22000-token context size)

First Bad Commit

33eff40

Relevant log output

(Identical to the b5359 error log shown above.)

broadbit-hu commented May 12, 2025

It works with smaller prompts, e.g. this one with 1386 tokens (possibly anything under 2048 works):

slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 1386
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 1386, n_tokens = 1386, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 1386, n_tokens = 1386
slot      release: id  0 | task 0 | stop processing: n_past = 1710, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =     893.83 ms /  1386 tokens (    0.64 ms per token,  1550.64 tokens per second)
       eval time =    8495.26 ms /   325 tokens (   26.14 ms per token,    38.26 tokens per second)
      total time =    9389.08 ms /  1711 tokens
srv  update_slots: all slots are idle


broadbit-hu commented May 12, 2025

This bug is reproducible on Nvidia cards as well, with different models, like Qwen Coder:

Build options:

cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_RPC=ON
cmake --build build --config Release

Log:

load_tensors: offloading 28 repeating layers to GPU
load_tensors: offloading output layer to GPU
load_tensors: offloaded 29/29 layers to GPU
load_tensors:        CUDA0 model buffer size =  7165.44 MiB
load_tensors:   CPU_Mapped model buffer size =   552.23 MiB
.......................................................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 32768
llama_context: n_ctx_per_seq = 32768
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 1
llama_context: freq_base     = 1000000.0
llama_context: freq_scale    = 1
llama_context:  CUDA_Host  output buffer size =     0.58 MiB
llama_kv_cache_unified: kv_size = 32768, type_k = 'f16', type_v = 'f16', n_layer = 28, can_shift = 1, padding = 256
llama_kv_cache_unified:      CUDA0 KV buffer size =  1792.00 MiB
llama_kv_cache_unified: KV self size  = 1792.00 MiB, K (f16):  896.00 MiB, V (f16):  896.00 MiB
llama_context:      CUDA0 compute buffer size =   304.00 MiB
llama_context:  CUDA_Host compute buffer size =    71.01 MiB
llama_context: graph nodes  = 931
llama_context: graph splits = 2
common_init_from_params: setting dry_penalty_last_n to ctx_size = 32768
srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 32768
main: model loaded
main: chat template, chat_template: {%- for message in messages -%}
  {{- '<|im_start|>' + message.role + '
' + message.content + '<|im_end|>
' -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
  {{- '<|im_start|>assistant
' -}}
{%- endif -%}, example_format: '<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
'
main: server is listening on http://0.0.0.0:8012 - starting the main loop
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /api/show 192.168.37.63 200
srv  log_server_r: request: GET /props 192.168.38.14 200
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 32768, n_keep = 0, n_prompt_tokens = 17892
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.114465
llama.cpp/tools/server/utils.hpp:1157: GGML_ASSERT(n <= tokens.size()) failed
slot update_slots: id  0 | task 0 | kv cache rm [2048, end)

@broadbit-hu broadbit-hu changed the title Misc. bug: GGML_ASSERT(n <= tokens.size()) failed Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use May 12, 2025
@broadbit-hu broadbit-hu changed the title Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use (-DGGML_RPC=ON) May 12, 2025
@broadbit-hu broadbit-hu changed the title Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use (-DGGML_RPC=ON) Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use May 12, 2025
@ngxson ngxson self-assigned this May 13, 2025

ngxson commented May 13, 2025

Can you try Llama-3.2-1B-Instruct-Q4_K_M to see if it has the same problem?

llama-server -m Llama-3.2-1B-Instruct-Q4_K_M.gguf -c 22000

https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF


With the model above, it works fine even when I send 12000 tokens:

slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 12010
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.170525
slot update_slots: id  0 | task 0 | kv cache rm [2048, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 4096, n_tokens = 2048, progress = 0.341049
slot update_slots: id  0 | task 0 | kv cache rm [4096, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 6144, n_tokens = 2048, progress = 0.511574
slot update_slots: id  0 | task 0 | kv cache rm [6144, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 8192, n_tokens = 2048, progress = 0.682098
slot update_slots: id  0 | task 0 | kv cache rm [8192, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 10240, n_tokens = 2048, progress = 0.852623
slot update_slots: id  0 | task 0 | kv cache rm [10240, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 12010, n_tokens = 1770, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 12010, n_tokens = 1770
srv  cancel_tasks: cancel task, id_task = 0
srv  log_server_r: request: POST /v1/chat/completions 127.0.0.1 200


broadbit-hu commented May 13, 2025

@ngxson: Thanks for your support!

I'm using the '/completion' endpoint, not '/v1/chat/completions'.

Same results with this small model (with or without flash-attention and GPU offloading):

Command:
./build/bin/llama-server -hf ggml-org/SmolVLM-500M-Instruct-GGUF -c 22000 --host 0.0.0.0 --port 18080

llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 22000
llama_context: n_ctx_per_seq = 22000
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: freq_base     = 100000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (22000) > n_ctx_train (8192) -- possible training context overflow
llama_context:        CPU  output buffer size =     0.19 MiB
llama_kv_cache_unified: kv_size = 22016, type_k = 'f16', type_v = 'f16', n_layer = 32, can_shift = 1, padding = 32
llama_kv_cache_unified:        CPU KV buffer size =   860.00 MiB
llama_kv_cache_unified: KV self size  =  860.00 MiB, K (f16):  430.00 MiB, V (f16):  430.00 MiB
llama_context:      ROCm0 compute buffer size =   723.31 MiB
llama_context:  ROCm_Host compute buffer size =    44.88 MiB
llama_context: graph nodes  = 1094
llama_context: graph splits = 356 (with bs=512), 1 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 22016
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Failed to generate tool call example: Value is not callable: null at row 1, column 72:
<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
                                                                       ^
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}
 at row 1, column 42:
<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
                                         ^
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}
 at row 1, column 42:
<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
                                         ^
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}
 at row 1, column 13:
<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
            ^
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}
 at row 1, column 1:
<|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
^
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}

clip_ctx: CLIP using ROCm0 backend
clip_model_loader: model name:   SmolVLM 500M Instruct
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    198
clip_model_loader: n_kv:         40

load_hparams: projector:          idefics3
load_hparams: n_embd:             768
load_hparams: n_head:             12
load_hparams: n_ff:               3072
load_hparams: n_layer:            12
load_hparams: projection_dim:     960
load_hparams: image_size:         512
load_hparams: patch_size:         16

load_hparams: has_llava_proj:     0
load_hparams: minicpmv_version:   0
load_hparams: proj_scale_factor:  4
load_hparams: n_wa_pattern:       0
load_hparams: ffn_op:             gelu
load_hparams: model size:         103.73 MiB
load_hparams: metadata size:      0.07 MiB
alloc_compute_meta:      ROCm0 compute buffer size =    60.00 MiB
alloc_compute_meta:        CPU compute buffer size =     3.00 MiB
srv    load_model: loaded multimodal model, '/root/.cache/llama.cpp/ggml-org_SmolVLM-500M-Instruct-GGUF_mmproj-SmolVLM-500M-Instruct-Q8_0.gguf'
srv    load_model: ctx_shift is not supported by multimodal, it will be disabled
srv          init: initializing slots, n_slots = 1
slot         init: id  0 | task -1 | new slot n_ctx_slot = 22016
main: model loaded
main: chat template, chat_template: <|im_start|>{% for message in messages %}{{message['role'] | capitalize}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>
{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}, example_format: '<|im_start|>You are a helpful assistant

User: Hello<end_of_utterance>
Assistant: Hi there<end_of_utterance>
User: How are you?<end_of_utterance>
Assistant:'
main: server is listening on http://0.0.0.0:18080 - starting the main loop
srv  update_slots: all slots are idle
srv  log_server_r: request: GET /props 192.168.253.130 200
slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 11575
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.176933
/opt/text/llama.cpp/tools/server/utils.hpp:1157: GGML_ASSERT(n <= tokens.size()) failed
slot update_slots: id  0 | task 0 | kv cache rm [2048, end)
Memory critical error by agent node-0 (Agent handle: 0x580e1c919910) on address 0x7f4d04400000. Reason: Memory in use. 
Aborted (core dumped)


ngxson commented May 13, 2025

I think it's likely that your source tree is not clean; try a fresh git clone and rebuild the project from scratch.

Also, you may be using a build from another source (as indicated by the agent node-0 message); try using vanilla llama.cpp instead.


broadbit-hu commented May 13, 2025

I'm using the llama.cpp repo only; the last working release is b5329.

prompt.json

Here's a complete test to reproduce:

  1. Download and build:
rm -rf llama.cpp
git clone --branch b5368 https://github.com/ggerganov/llama.cpp.git
GGML_HIP="-DGGML_HIP=ON"
GGML_RPC="-DGGML_RPC=ON"
AMDGPU_TARGETS="-DGPU_TARGETS=gfx1100,gfx1101"
ROCM_FLASHATTENTION="-DGGML_HIP_ROCWMMA_FATTN=ON"

HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
    cmake -S . -B build ${GGML_HIP} ${GGML_RPC} ${AMDGPU_TARGETS} ${ROCM_FLASHATTENTION} -DCMAKE_BUILD_TYPE=Release \
    && cmake --build build --target clean --config Release \
    && cmake --build build --config Release -- -j 16
  2. Start the server:
./build/bin/llama-server -hf ggml-org/SmolVLM-500M-Instruct-GGUF -c 22000 --host 0.0.0.0 --port 18080
  3. Post the uploaded "prompt.json" (5379 tokens):
curl -X POST "http://127.0.0.1:18080/completion" -H "Content-Type: application/json" -d @prompt.json
  4. The result:
srv  log_server_r: request: POST /completion 127.0.0.1 200
slot launch_slot_: id  0 | task 2049 | processing task
slot update_slots: id  0 | task 2049 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 5379
slot update_slots: id  0 | task 2049 | kv cache rm [0, end)
slot update_slots: id  0 | task 2049 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.380740
/opt/text/llama.cpp/tools/server/utils.hpp:1157: GGML_ASSERT(n <= tokens.size()) failed
slot update_slots: id  0 | task 2049 | kv cache rm [2048, end)
Memory critical error by agent node-0 (Agent handle: 0x6285a2216910) on address 0x765cf8c00000. Reason: Memory in use. 
Aborted (core dumped)


broadbit-hu commented May 13, 2025

prompt-v1-chat-completions.json

The same llama.cpp build was also tested against the '/v1/chat/completions' endpoint - it works fine:

curl -X POST "http://192.168.253.167:18080/v1/chat/completions" -H "Content-Type: application/json" -d @prompt-v1-chat-completions.json

Results:

slot launch_slot_: id  0 | task 0 | processing task
slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 22016, n_keep = 0, n_prompt_tokens = 5387
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 2048, n_tokens = 2048, progress = 0.380174
slot update_slots: id  0 | task 0 | kv cache rm [2048, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 4096, n_tokens = 2048, progress = 0.760349
slot update_slots: id  0 | task 0 | kv cache rm [4096, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 5387, n_tokens = 1291, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 5387, n_tokens = 1291
slot process_toke: id  0 | task 0 | n_predict (-1) is set for infinite generation. Limiting generated tokens to n_ctx_train (8192) to avoid EOS-less generation infinite loop
slot      release: id  0 | task 0 | stop processing: n_past = 8191, truncated = 1
slot print_timing: id  0 | task 0 | 
prompt eval time =    1902.14 ms /  5387 tokens (    0.35 ms per token,  2832.07 tokens per second)
       eval time =   78704.18 ms /  2805 tokens (   28.06 ms per token,    35.64 tokens per second)
      total time =   80606.32 ms /  8192 tokens
srv  update_slots: all slots are idle
srv  log_server_r: request: POST /v1/chat/completions 192.168.253.130 200

@broadbit-hu broadbit-hu changed the title Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use ('/completion' endpoint) May 13, 2025

broadbit-hu commented May 13, 2025

So the prompt cache seems to be the cause: if 'cache_prompt' is set to true, the '/completion' endpoint (with 'prompt.json') works again, but that is not a solution for this particular case:

"cache_prompt": true,

@broadbit-hu broadbit-hu changed the title Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use ('/completion' endpoint) Misc. bug: GGML_ASSERT(n <= tokens.size()) failed - Memory in use ('/completion' endpoint and 'cache_prompt=false') May 13, 2025

ngxson commented May 13, 2025

Thanks for the info. Yes, this seems to be a valid bug in all versions (even the old builds).

The reason the old version doesn't crash is that std::vector allows resizing to an arbitrary number of elements. If we resize to a number larger than the number of tokens currently in the cache, std::vector fills the "added" entries with zeros, which is technically incorrect.

For example, if the first request does NOT use the cache, this logic fills up cache_tokens with wrong values. If the second request then decides to use the cache, it gets an incorrect list of cache_tokens.
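
To make the std::vector behaviour concrete, here is a tiny standalone sketch (illustrative only, not the actual server code; the cache_tokens contents and token values are made up):

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cache_tokens = {101, 102, 103}; // 3 real tokens in the cache

    // Old behaviour: resizing past the current size "succeeds" and value-initializes
    // the new elements to 0, silently padding the cache with bogus token ids.
    size_t n = 5;
    cache_tokens.resize(n);

    for (size_t i = 0; i < cache_tokens.size(); ++i) {
        std::printf("cache_tokens[%zu] = %d\n", i, cache_tokens[i]);
    }
    // Prints 101, 102, 103, 0, 0 -- a later request that trusts the cache would
    // treat the two zero entries as real tokens.
    return 0;
}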

The newer version crashes because I added a check to prevent this case from happening.
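
Roughly, the check turns that silent padding into a hard failure. A sketch of the idea (the keep_first helper and the llama_token alias are hypothetical here; only the assertion condition matches the message in the log above):

#include <cassert>
#include <vector>

using llama_token = int; // placeholder for the real token type

// Keep only the first n cached tokens. With the assertion in place, asking to
// "keep" more tokens than are actually cached aborts instead of padding with zeros.
static void keep_first(std::vector<llama_token> & tokens, size_t n) {
    assert(n <= tokens.size()); // mirrors GGML_ASSERT(n <= tokens.size())
    tokens.resize(n);           // can now only shrink the cache, never grow it
}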

I'll push a fix for this, cc @ggerganov


broadbit-hu commented May 13, 2025

@ngxson Thanks for your support!

@broadbit-hu
Author

Tested with release b5379 - works like a charm, thanks for your work, guys! :)
