Skip to content

Commit e5f0408

Browse files
committed
one more yolo
1 parent 7dfe9e9 commit e5f0408

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

inference_benchmark.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,14 @@ def inference(
109109
top_p: float = 0.9,
110110
max_gen_len: int = 64,
111111
):
112-
results = generator.text_completion(
113-
prompts,
114-
max_gen_len=max_gen_len,
115-
temperature=temperature,
116-
top_p=top_p,
117-
)
118-
for prompt, result in zip(prompts, results):
119-
print(prompt)
120-
print(f"> {result['generation']}")
121-
print("\n==================================\n")
112+
with torch.no_grad():
113+
results = generator.text_completion(
114+
prompts,
115+
max_gen_len=max_gen_len,
116+
temperature=temperature,
117+
top_p=top_p,
118+
)
119+
return zip(prompts, results)
122120

123121
def __get_next_batch(dataloader):
124122
return next(iter(dataloader))
@@ -139,8 +137,9 @@ def benchmark(ckpt_dir,
139137
print("Running inference benchmark...\n")
140138

141139
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=PROFILE_MEMORY) as prof:
142-
with record_function("run_benchmark"):
143-
_, load, inference, total = run_benchmark(data_loader, net)
140+
# with record_function("run_benchmark"):
141+
# _, load, inference, total = run_benchmark(data_loader, net)
142+
_, load, inference, total = run_benchmark(data_loader, net)
144143
profile_cuda_time = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)
145144
profile_cuda_mem = prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)
146145

0 commit comments

Comments (0)