Skip to content

The naming scheme between the gemma and gemma2 variants on the command line was not consistent #506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions compression/compress_weights.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,17 @@ struct Args : public ArgsBase<Args> {
"Path to model weights (.bin) file.\n"
" Required argument.");
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"Model type\n "
"gemma-2b-it = Gemma 2B parameters, instruction-tuned\n "
"gemma-2b-pt = Gemma 2B parameters, pretrained\n "
"gemma-7b-it = Gemma 7B parameters, instruction-tuned\n "
"gemma-7b-pt = Gemma 7B parameters, pretrained\n "
"gemma2-2b-it = Gemma2 2B parameters, instruction-tuned\n "
"gemma2-2b-pt = Gemma2 2B parameters, pretrained\n "
"gemma2-9b-it = Gemma2 9B parameters, instruction-tuned\n "
"gemma2-9b-pt = Gemma2 9B parameters, pretrained\n "
"gemma2-27b-it = Gemma2 27B parameters, instruction-tuned\n "
"gemma2-27b-pt = Gemma2 27B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
Expand Down
12 changes: 12 additions & 0 deletions gemma/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = {
"paligemma2-3b-448", // PaliGemma2 3B 448
"paligemma2-10b-224", // PaliGemma2 10B 224
"paligemma2-10b-448", // PaliGemma2 10B 448
"gemma-2b-pt", "gemma-2b-it",
"gemma-7b-pt", "gemma-7b-it",
"gemma2-9b-pt", "gemma2-9b-it",
"gemma2-27b-pt", "gemma2-27b-it",
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Expand All @@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = {
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
};
constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
Expand All @@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
};

constexpr size_t kNumModelFlags = std::size(kModelFlags);
Expand Down
14 changes: 11 additions & 3 deletions util/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,17 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
visitor(compressed_weights, "compressed_weights", Path(),
"Alias for --weights.");
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"Model type\n "
"gemma-2b-it = Gemma 2B parameters, instruction-tuned\n "
"gemma-2b-pt = Gemma 2B parameters, pretrained\n "
"gemma-7b-it = Gemma 7B parameters, instruction-tuned\n "
"gemma-7b-pt = Gemma 7B parameters, pretrained\n "
"gemma2-2b-it = Gemma2 2B parameters, instruction-tuned\n "
"gemma2-2b-pt = Gemma2 2B parameters, pretrained\n "
"gemma2-9b-it = Gemma2 9B parameters, instruction-tuned\n "
"gemma2-9b-pt = Gemma2 9B parameters, pretrained\n "
"gemma2-27b-it = Gemma2 27B parameters, instruction-tuned\n "
"gemma2-27b-pt = Gemma2 27B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
Expand Down
Loading