Skip to content

Conversation

@aendk
Copy link
Contributor

@aendk aendk commented Dec 5, 2025

[DRAFT]
This PR suggest to remove some superfluous synchronization calls between tokens to be faster on CUDA backends. I see between 1% and 2% perf gain depending on the model, GPU and settings.

Mechanism

The performance impact is best explained visually. Here are the "before" and "after" nsight system traces. They are not to scale. The relevant part is the row with the green and red bubbles. Both images show the overhead between the GPU execution of two tokens. The generation of the n-th token ends on the left hand side of the screenshot, in the green bubble titled cudaStreamSynchronize. The calculation of the next/n+1st token starts on the right hand side, at the green bar titled cudaGraphLaunch. In between, there is CPU orchestration overhead. This PR aims to shrink the time spent in the middle, between GPU token generation. Original:

Screenshot 2025-12-05 at 14 45 35

In the middle of the above image, we see red and green bubbles alternating. In this case, the green bubbles are synchronization steps, the red bubbles are asynchronous copy calls from host to device. If async operations are immediately followed by synchronization calls, they are executed synchronously. This is not efficient. Removing the green synchronization operations between asynchronous copy calls leads to asynchronous copies and reduced overhead between GPU token generation:

Screenshot 2025-12-05 at 14 45 10

Performance

I benchmarked on a RTX Pro 6000 Blackwell using ./llama-bench -m $models -p 0 -n 128,256,512 -fa 1.
My testing shows around 1% improvement, with gpt-oss-20b gaining up to 1.4%. llama 3B Q4_K - Medium shows very high variance, prompting me to run the tests again with -r 100. At -r 100, a clearer trend of improved performance for gemma3n E2B Q8_0 is also visible.

Details with default `-r 5`

Baseline:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        392.24 ± 1.07 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        392.72 ± 0.35 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        387.72 ± 0.38 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        464.85 ± 0.55 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        465.39 ± 0.59 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        461.87 ± 0.74 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        231.59 ± 0.09 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        231.47 ± 0.03 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        228.21 ± 0.46 |

build: 909072abc (7176)

PR:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        397.14 ± 1.50 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        398.36 ± 0.45 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        393.25 ± 0.65 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |        472.48 ± 3.71 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |        468.81 ± 0.19 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        463.62 ± 1.28 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        232.84 ± 0.18 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        232.82 ± 0.08 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        229.62 ± 0.25 |

build: f6b408d84 (7178)

Speedup:

1.01249
1.01436
1.01426
1.01641
1.00735
1.00379
1.0054
1.00583
1.00618
Details with `-r 100`

Baseline:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        393.24 ± 0.45 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        393.33 ± 2.97 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        381.93 ± 2.40 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |       446.41 ± 40.17 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |       451.55 ± 21.34 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        454.89 ± 0.33 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        231.90 ± 0.27 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        231.93 ± 0.21 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        228.47 ± 0.14 |

build: 909072abc (7176)

PR:

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg128 |        398.52 ± 0.41 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg256 |        397.32 ± 5.71 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | CUDA       |  99 |  1 |           tg512 |        383.53 ± 3.06 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg128 |       441.09 ± 50.39 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg256 |       456.69 ± 20.91 |
| llama 3B Q4_K - Medium         |   1.87 GiB |     3.21 B | CUDA       |  99 |  1 |           tg512 |        458.19 ± 0.32 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg128 |        233.98 ± 0.13 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg256 |        233.65 ± 0.25 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |  1 |           tg512 |        230.18 ± 0.14 |

build: aebcdf119 (7178)

Speedup:

1.01366
1.01025
1.00408
0.982033
1.00875
1.00819
1.00893
1.00876
1.00938

Implementation Concerns

The approach here aims to minimize changes in the general backend, and to other backends. However, the synchronization calls originate from the general backend. Some changes there are unavoidable, as well as retaining a synchronization call after the last copy to ensure correctness across backends.

Additionally, AFAIK there is no documentation on the functional guarantees of a function like ggml_copy_tensor, and it could be that the current design proposal violates existing assumptions, or practices around potentially breaking ABIs between ggml and llama.cpp. For this reason, this PR is a draft.
I also have not copy-pasted my additions ggml_backend_buffer_i interface changes (added set_tensor_async + whitespace) to the the other backends just yet. This causes the unrelated tests to fail.

Please advise on the architectural choices of this implementation.

For example, we could make the set_tensor in the CUDA backend default async. This would avoid interface changes, but change the behavior of similar functions between backends.

@ggerganov @JohannesGaessler

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Dec 5, 2025
@wishstudio
Copy link
Contributor

How about using existing .set_tensor_async in ggml_backend_i? There is also a ggml_backend_tensor_set_async helper.

I believe the set_tensor, get_tensor etc interfaces guarantees synchronization behavior, though I'm not sure which part depends on it. A better plan might be changing each set_tensor use to the async version one-by-one instead of changing them all at once which will likely be ABI change and lead to problems.

Besides, I believe the current way of synchronizing both input_backend and split_backend is still not optimal as some synchronizations can be completely skipped. For example, for a CPU split copying data from GPU, we have to do a synchronization to make sure all data is copied to RAM before use. But for a GPU split copying data from CPU, there is no synchronization needed as the cudaMemcpyAsync call will be automatically serialized with following kernel launches.

For synchronizations at the beginning of this code section (ggml_backend_synchronize(split_backend);, etc) I think the intention is to synchronize before the following steps. The behavior is changed in this patch, though I feel these synchronizations are unnecessary in many cases (especially single GPU).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants