@@ -138,18 +138,25 @@ def parse_args():
138
138
if args .model_name not in MODEL_NAME_TO_MODEL :
139
139
raise RuntimeError (f"Available models are { list (MODEL_NAME_TO_MODEL .keys ())} ." )
140
140
141
- llm_config = LlmConfig ()
142
141
if args .model_name == "llama2" :
142
+ # Building LLM example.
143
+ llm_config = LlmConfig ()
143
144
if args .checkpoint :
144
145
llm_config .base .checkpoint = args .checkpoint
145
146
if args .params :
146
147
llm_config .base .params = args .params
147
148
llm_config .model .use_kv_cache = True
148
- model , example_inputs , _ , _ = EagerModelFactory .create_model (
149
- module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
150
- model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
151
- llm_config = llm_config ,
152
- )
149
+ model , example_inputs , _ , _ = EagerModelFactory .create_model (
150
+ module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
151
+ model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
152
+ llm_config = llm_config ,
153
+ )
154
+ else :
155
+ # Building non-LLM example.
156
+ model , example_inputs , _ , _ = EagerModelFactory .create_model (
157
+ module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
158
+ model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
159
+ )
153
160
154
161
model = model .eval ()
155
162
0 commit comments