diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index cbf7e359..62ea8a6d 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -110,9 +110,17 @@ struct Args : public ArgsBase { "Path to model weights (.bin) file.\n" " Required argument."); visitor(model_type_str, "model", std::string(), - "Model type\n 2b-it = 2B parameters, instruction-tuned\n " - "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " + "Model type\n " + "gemma-2b-it = Gemma 2B parameters, instruction-tuned\n " + "gemma-2b-pt = Gemma 2B parameters, pretrained\n " + "gemma-7b-it = Gemma 7B parameters, instruction-tuned\n " + "gemma-7b-pt = Gemma 7B parameters, pretrained\n " + "gemma2-2b-it = Gemma2 2B parameters, instruction-tuned\n " + "gemma2-2b-pt = Gemma2 2B parameters, pretrained\n " + "gemma2-9b-it = Gemma2 9B parameters, instruction-tuned\n " + "gemma2-9b-pt = Gemma2 9B parameters, pretrained\n " + "gemma2-27b-it = Gemma2 27B parameters, instruction-tuned\n " + "gemma2-27b-pt = Gemma2 27B parameters, pretrained\n " "gr2b-it = griffin 2B parameters, instruction-tuned\n " "gr2b-pt = griffin 2B parameters, pretrained\n " " Required argument."); diff --git a/gemma/common.cc b/gemma/common.cc index dc37e3ec..43b5ab26 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = { "paligemma2-3b-448", // PaliGemma2 3B 448 "paligemma2-10b-224", // PaliGemma2 10B 224 "paligemma2-10b-448", // PaliGemma2 10B 448 + "gemma-2b-pt", "gemma-2b-it", + "gemma-7b-pt", "gemma-7b-it", + "gemma2-9b-pt", "gemma2-9b-it", + "gemma2-27b-pt", "gemma2-27b-it", }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B @@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = { Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448 Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448 + Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B + Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B + Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B + Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B }; constexpr PromptWrapping kPromptWrapping[] = { PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B @@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = { PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448 + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B + PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B }; constexpr size_t kNumModelFlags = std::size(kModelFlags); diff --git a/util/app.h b/util/app.h index 5c0698f0..393e0dd7 100644 --- a/util/app.h +++ b/util/app.h @@ -206,9 +206,17 @@ struct LoaderArgs : public ArgsBase { visitor(compressed_weights, "compressed_weights", Path(), "Alias for --weights."); visitor(model_type_str, "model", std::string(), - "Model type\n 2b-it = 2B parameters, instruction-tuned\n " - "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " + "Model type\n " + "gemma-2b-it = Gemma 2B parameters, instruction-tuned\n " + "gemma-2b-pt = Gemma 2B parameters, pretrained\n " + "gemma-7b-it = Gemma 7B parameters, instruction-tuned\n " + "gemma-7b-pt = Gemma 7B parameters, pretrained\n " + "gemma2-2b-it = Gemma2 2B parameters, instruction-tuned\n " + "gemma2-2b-pt = Gemma2 2B parameters, pretrained\n " + "gemma2-9b-it = Gemma2 9B parameters, instruction-tuned\n " + "gemma2-9b-pt = Gemma2 9B parameters, pretrained\n " + "gemma2-27b-it = Gemma2 27B parameters, instruction-tuned\n " + "gemma2-27b-pt = Gemma2 27B parameters, pretrained\n " "gr2b-it = griffin 2B parameters, instruction-tuned\n " "gr2b-pt = griffin 2B parameters, pretrained."); visitor(weight_type_str, "weight_type", std::string("sfp"),