
1.31x batch prefill, 1.24x batch decode speedup: NUMA binding #569

Merged 1 commit on May 16, 2025
BUILD.bazel (3 changes: 1 addition & 2 deletions)
@@ -172,7 +172,6 @@ cc_library(
"//io:fields",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)

@@ -217,6 +216,7 @@ cc_library(
":configs",
":mat",
":model_store",
":ops",
":tensor_info",
":threading_context",
"//compression:compress",
@@ -281,7 +281,6 @@ cc_library(
":allocator",
":basics",
":mat",
":threading",
":threading_context",
"//compression:compress",
"@highway//:algo",
backprop/test_util.h (3 changes: 2 additions & 1 deletion)
@@ -20,6 +20,7 @@

#include <cmath>
#include <complex>
#include <vector>

#include "gtest/gtest.h"
#include "gemma/configs.h"
@@ -75,7 +76,7 @@ class WeightsWrapper {
ModelWeightsPtrs<T>& get() { return weights_; }

private:
MatOwners owners_;
std::vector<MatOwner> owners_;
ModelWeightsPtrs<T> weights_;
};

compression/python/compression_clif_aux.cc (5 changes: 3 additions & 2 deletions)
@@ -67,7 +67,8 @@ class SbsWriterImpl : public ISbsWriter {
}

mat.AppendTo(serialized_mat_ptrs_);
mat_owners_.AllocateFor(mat, MatPadding::kPacked);
mat_owners_.push_back(MatOwner());
mat_owners_.back().AllocateFor(mat, MatPadding::kPacked);

// Handle gemma_export_test's MockArray. Write blobs so that the test
// succeeds, but we only have 10 floats, not the full tensor.
@@ -121,7 +122,7 @@ class SbsWriterImpl : public ISbsWriter {
}

hwy::ThreadPool& pool_;
MatOwners mat_owners_;
std::vector<MatOwner> mat_owners_;
CompressWorkingSet working_set_;
BlobWriter writer_;
std::vector<uint32_t> serialized_mat_ptrs_;
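A pattern that recurs throughout this PR: the MatOwners collection type is replaced by a plain std::vector<MatOwner>, and each call site appends one owner per tensor before allocating through it. A minimal sketch of the new call-site shape, assuming MatOwner, MatPtrT, and MatPadding come from util/mat.h as in the weights.cc includes below; the wrapper function itself is illustrative, not part of the PR:

#include <vector>

#include "util/mat.h"  // MatOwner, MatPtrT, MatPadding (assumed location)

// One MatOwner per tensor, held in a plain vector. Appending a fresh
// owner and allocating through it replaces the old MatOwners helper.
void AppendAndAllocate(std::vector<MatOwner>& owners, MatPtrT<float>& mat) {
  owners.push_back(MatOwner());
  owners.back().AllocateFor(mat, MatPadding::kPacked);
}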
evals/gemma_batch_bench.cc (28 changes: 14 additions & 14 deletions)
@@ -22,13 +22,13 @@
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
#include "hwy/tests/hwy_gtest.h"

// This test can be run manually with the downloaded gemma weights.
// To run the test, pass the following flags:
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
// --tokenizer <tokenizer_path> --weights <weights_path>
// It should pass for the following models:
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,

namespace gcpp {
@@ -76,26 +76,25 @@ class GemmaTest : public ::testing::Test {
return replies;
}

void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
void GenerateTokens(const std::vector<std::string>& questions) {
ASSERT_NE(s_env->GetGemma(), nullptr);

// Fills prompts round robin from `questions` until the desired batch size.
std::vector<std::string> inputs;
inputs.reserve(num_questions);
for (size_t i = 0; i < num_questions; ++i) {
inputs.push_back(kQA[i]);
inputs.reserve(s_env->MutableConfig().decode_qbatch_size);
size_t qpos = 0;
for (size_t i = 0; i < inputs.capacity(); ++i) {
inputs.push_back(questions[qpos++]);
if (qpos == questions.size()) qpos = 0;
}
std::vector<std::string> responses = BatchGemmaReply(inputs);
for (size_t i = 0; i < num_questions; ++i) {
std::string response = responses.at(i);
fprintf(stderr, "Batch answer %zu '%s'\n\n", i + 1, response.c_str());
for (size_t i = 0; i < inputs.size(); ++i) {
fprintf(stderr, "Batch answer %zu '%s'\n\n", i, responses[i].c_str());
}
}
};

TEST_F(GemmaTest, RandomQuestionsBatched) {
s_env->MutableConfig().decode_qbatch_size = 3;
s_env->MutableConfig().verbosity = 5;

static std::vector<std::string> kQA = {
{"Write me a poem about Australia?"},
{"What's the history of Denmark?"},
@@ -130,8 +129,9 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
{"Tell me about space travel."},
{"Explain to me how electric cars work."},
};
static const size_t kNum = kQA.size();
GenerateTokens(kQA, kNum);
s_env->MutableConfig().verbosity = 5;
GenerateTokens(kQA);
PROFILER_PRINT_RESULTS();
}
} // namespace
} // namespace gcpp
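For reference, the rewritten GenerateTokens fills its input batch round robin: it cycles through the question list until decode_qbatch_size prompts are queued, so the benchmark always runs a full batch regardless of how many distinct questions exist. The same logic in isolation, as a self-contained sketch (the function name is illustrative):

#include <cstddef>
#include <string>
#include <vector>

// Repeat `questions` cyclically until the batch holds `batch_size` prompts.
std::vector<std::string> FillBatchRoundRobin(
    const std::vector<std::string>& questions, size_t batch_size) {
  std::vector<std::string> inputs;
  inputs.reserve(batch_size);
  size_t qpos = 0;
  for (size_t i = 0; i < batch_size; ++i) {
    inputs.push_back(questions[qpos++]);
    if (qpos == questions.size()) qpos = 0;  // wrap around
  }
  return inputs;
}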
gemma/activations.h (2 changes: 2 additions & 0 deletions)
@@ -83,6 +83,8 @@ struct Activations {

env(env) {
HWY_ASSERT(batch_size != 0);

// Note that BindC on any MatMul output considerably slows down Prefill.
}

void SetBatchSize(size_t batch_size) {
gemma/gemma-inl.h (29 changes: 14 additions & 15 deletions)
@@ -21,7 +21,7 @@
#include <stdio.h>

#include <algorithm> // std::min
#include <cstdio>
#include <memory> // std::make_unique
#include <vector>

#include "gemma/activations.h"
@@ -1055,7 +1055,8 @@ HWY_NOINLINE void Prefill(
// intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
const size_t max_tbatch_size = activations.x.Rows();
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
HWY_DASSERT(max_tbatch_size <= activations.x.Rows());

// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
@@ -1429,25 +1430,18 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
// Prefill stops before min_prompt_size - 1 because the last prompt
// token is the first input token for generation.
timing_info.prefill_start = hwy::platform::Now();
// If tbatch is larger than the qbatch we already have in `activations`, then
// allocate prefill_activations, otherwise reuse.
const bool use_prefill_activations =
runtime_config.prefill_tbatch_size > activations.x.Rows();
Activations prefill_activations(
weights.weights_config,
use_prefill_activations ? runtime_config.prefill_tbatch_size : 0,
activations.env);
// Note that Prefill calls activations.SetBatchSize, so we reset it below.
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights,
use_prefill_activations ? prefill_activations : activations,
runtime_config, div_seq_len, kv_caches);
query_idx_start, weights, activations, runtime_config, div_seq_len,
kv_caches);
// Compute the number of tokens that were prefilled and notify timing_info.
size_t prefilled_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
prefilled_tokens += queries_prompt[qi].size() - 1;
}
timing_info.NotifyPrefill(prefilled_tokens);
// queries_pos are incremented by Prefill.
activations.SetBatchSize(num_queries);

// Storage for the last generated token from each query, passed to the next
// Transformer() call.
@@ -1489,8 +1483,10 @@ void GenerateSingleT(const ModelStore& model,
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;

const size_t max_batch_size =
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
// TODO: move into Gemma?
Activations activations(model.Config(), kNumQueries, env);
Activations activations(model.Config(), max_batch_size, env);

const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
@@ -1523,7 +1519,9 @@ void GenerateBatchT(const ModelStore& model,
}
}

Activations activations(model.Config(), max_qbatch_size, env);
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(model.Config(), max_batch_size, env);

for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
@@ -1557,6 +1555,7 @@ void GenerateImageTokensT(const ModelStore& model,
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, vit_config.seq_len, env);
prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(weights, prefill_runtime_config, image, image_tokens,
prefill_activations);
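The net effect of the gemma-inl.h changes: GenerateT no longer allocates a separate prefill_activations; instead, GenerateSingleT and GenerateBatchT size a single Activations for the larger of the decode query batch and the prefill token batch, and Prefill asserts that its token batch fits. Condensed from the diff above (HWY_MAX is Highway's max macro), not a standalone program:

// One buffer sized for both phases replaces the second, prefill-only
// Activations that GenerateT used to allocate.
const size_t max_batch_size =
    HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(model.Config(), max_batch_size, env);

// Prefill now reads its token-batch size from the config and checks it fits:
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
HWY_DASSERT(max_tbatch_size <= activations.x.Rows());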
gemma/weights.cc (39 changes: 29 additions & 10 deletions)
@@ -16,10 +16,10 @@
#include "gemma/weights.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <random>
#include <string>
@@ -30,6 +30,7 @@
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "io/blob_store.h"
#include "ops/matmul.h" // MMParallel
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
Expand All @@ -46,7 +47,7 @@ namespace gcpp {
static void InitAttWeightsNUQ(const LayerConfig& layer_config,
MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights,
MatOwners& mat_owners) {
std::vector<MatOwner>& mat_owners) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);

@@ -91,11 +92,29 @@ static void SplitW1NUQ(const LayerConfig& layer_config) {
}

template <>
void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
SplitW1NUQ(layer_config);
}

// Allocates multiple in parallel and binds to NUMA nodes.
static void AllocateAndBindAll(const std::vector<MatPtr*>& mats,
MatPadding padding,
std::vector<MatOwner>& owners,
hwy::ThreadPool& pool) {
const size_t start = owners.size();
owners.resize(start + mats.size());

MMParallel parallel(ThreadingContext::Get());

// Allocate in parallel because faulting in large tensors is slow.
pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
owners[start + task].AllocateFor(*mats[task], padding);
// TODO(janwas): MatMul outputs will later also be BF16.
BindB(*mats[task], sizeof(float), parallel);
});
}

// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
enum class Mode { kRead, kMap };
@@ -209,10 +228,10 @@ static void ReadBatches(const BlobReader& reader,
}

// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges, Tristate map,
MatOwners& mat_owners, const MatPadding padding,
hwy::ThreadPool& pool) {
static void MapOrReadAll(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges, Tristate map,
std::vector<MatOwner>& mat_owners,
const MatPadding padding, hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());

if (ChooseMode(reader.file_bytes(), map) == Mode::kMap) {
@@ -226,7 +245,7 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
{
PROFILER_ZONE("Startup.Weights.Allocate");
// NOTE: this changes the stride of `mats`!
mat_owners.AllocateFor(mats, padding, pool);
AllocateAndBindAll(mats, padding, mat_owners, pool);
}

const std::vector<IOBatch> batches =
Expand Down Expand Up @@ -259,7 +278,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
});
});

MapOrRead(mats, reader, ranges, map, mat_owners_, padding, pool);
MapOrReadAll(mats, reader, ranges, map, mat_owners_, padding, pool);

Fixup(pool);
}
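Where the headline speedup appears to come from: the new AllocateAndBindAll faults in all weight tensors in parallel on a thread pool (the diff notes that faulting in large tensors serially is slow), then calls BindB on each so its memory is bound to NUMA nodes for MatMul. As background on what binding means here, a generic, self-contained sketch using libnuma directly; gemma.cpp uses its own topology code via MMParallel, so this AllocOnNode helper is purely hypothetical and assumes libnuma is available:

#include <numa.h>  // libnuma; link with -lnuma

#include <cstddef>
#include <cstring>

// Place a buffer's pages on one NUMA node so that node's cores read it
// at local-memory latency and bandwidth.
float* AllocOnNode(size_t num_floats, int node) {
  if (numa_available() < 0) return nullptr;  // kernel lacks NUMA support
  void* p = numa_alloc_onnode(num_floats * sizeof(float), node);
  if (p == nullptr) return nullptr;
  std::memset(p, 0, num_floats * sizeof(float));  // fault pages in up front
  return static_cast<float*>(p);  // release later with numa_free()
}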