Skip to content

Commit 72a095f

Browse files
authored
Fix mps example for non-LLMs (#11538)
Fix broken ios app demo tests
1 parent 04710d4 commit 72a095f

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

examples/apple/mps/scripts/mps_example.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,25 @@ def parse_args():
138138
if args.model_name not in MODEL_NAME_TO_MODEL:
139139
raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")
140140

141-
llm_config = LlmConfig()
142141
if args.model_name == "llama2":
142+
# Building LLM example.
143+
llm_config = LlmConfig()
143144
if args.checkpoint:
144145
llm_config.base.checkpoint = args.checkpoint
145146
if args.params:
146147
llm_config.base.params = args.params
147148
llm_config.model.use_kv_cache = True
148-
model, example_inputs, _, _ = EagerModelFactory.create_model(
149-
module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
150-
model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
151-
llm_config=llm_config,
152-
)
149+
model, example_inputs, _, _ = EagerModelFactory.create_model(
150+
module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
151+
model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
152+
llm_config=llm_config,
153+
)
154+
else:
155+
# Building non-LLM example.
156+
model, example_inputs, _, _ = EagerModelFactory.create_model(
157+
module_name=MODEL_NAME_TO_MODEL[args.model_name][0],
158+
model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1],
159+
)
153160

154161
model = model.eval()
155162

0 commit comments

Comments
 (0)