import torch
from torch.utils.data import DataLoader
import time
- import cotracker.models.build_cotracker
- from cotracker.datasets.tap_vid_datasets import TapVidDataset
- import os
- from cotracker.datasets.utils import collate_fn
from datasets import load_dataset
+ import fire
+ from torch.profiler import profile, record_function, ProfilerActivity

### Setup ###
BATCH_SIZE = 1
BATCH_COUNT = 5
NUM_WORKERS = 1
+ PROFILE_MEMORY = True
+
+ # https://huggingface.co/datasets/gsm8k
+ HUGGING_FACE_GSMK_DATASET_ID = "gsm8k"

# Manual seed for reproducibility
SEED = 42

DEVICE_CUDA = 'cuda'
DEVICE_CPU = 'cpu'

+ from llama import Llama
+ from typing import List
+
+ def get_device():
+     return torch.device(DEVICE_CUDA if torch.cuda.is_available() else DEVICE_CPU)

def get_data_loader(num_workers=1):
-     dataset = load_dataset("HuggingFaceH4/no_robots")
+     dataset = load_dataset(HUGGING_FACE_GSMK_DATASET_ID, 'main')['train']
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
-         num_workers=num_workers,
-         collate_fn=collate_fn,
+         num_workers=num_workers
    )
    return dataloader


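For reference, each batch produced by this loader is a dict keyed by the GSM8K columns. A minimal sketch of pulling and inspecting one batch, assuming the standard `question`/`answer` fields of the `gsm8k` `main` config:

```python
# Minimal sketch: pull one batch from the GSM8K loader and inspect it.
# Assumes the standard "question"/"answer" columns of the gsm8k "main" config.
loader = get_data_loader()
batch = next(iter(loader))        # dict of lists under the default collate_fn
print(batch.keys())               # e.g. dict_keys(['question', 'answer'])
print(batch["question"][0][:80])  # first question, truncated for display
```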
- def get_model(checkpoint_path=CHECKPOINT_S4_W12):
-     return cotracker.models.build_cotracker.build_cotracker(checkpoint_path)
+ def get_model(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size):
+     generator = Llama.build(
+         ckpt_dir=ckpt_dir,
+         tokenizer_path=tokenizer_path,
+         max_seq_len=max_seq_len,
+         max_batch_size=max_batch_size,
+     )
+     return generator


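`Llama.build` is the loader from Meta's llama reference repo; it reads the checkpoint shards and tokenizer and returns a `Llama` generator. A hypothetical call, with placeholder paths for wherever the weights were downloaded:

```python
# Hypothetical invocation; the checkpoint directory and tokenizer path are
# placeholders that depend on where the Llama weights were downloaded.
generator = get_model(
    ckpt_dir="llama-2-7b/",
    tokenizer_path="tokenizer.model",
    max_seq_len=128,
    max_batch_size=4,
)
```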
- def run_inference(dataloader, model, cuda=True):
+ def run_benchmark(dataloader, model):
    load_time_per_batch = torch.zeros(BATCH_COUNT)
    inference_time_per_batch = torch.zeros(BATCH_COUNT)
    total_time_per_batch = torch.zeros(BATCH_COUNT)

-     device = DEVICE_CUDA if cuda else DEVICE_CPU
+     device = get_device()
+     # model.to(device)
    print("Working on device: {}".format(device))
-     model.to(device)
+

    for batch_idx in range(BATCH_COUNT):
-         print("Starting BATCHs {} of {}".format(batch_idx + 1, BATCH_COUNT))
-         (output, load_time, train_time), batch_time = measure_runtime(run_batch_inference,
+         print("Starting BATCH {} of {}".format(batch_idx + 1, BATCH_COUNT))
+         (output, load_time, inference_time), batch_time = measure_runtime(run_batch_inference,
            dataloader,
-             model,
-             cuda)
+             model)
        load_time_per_batch[batch_idx] = load_time
-         inference_time_per_batch[batch_idx] = train_time
+         inference_time_per_batch[batch_idx] = inference_time
        total_time_per_batch[batch_idx] = batch_time

        print("Finished Batch {} of {}".format(batch_idx + 1, BATCH_COUNT))
        print("Batch load time: {}".format(load_time))
-         print("Batch inference time: {}".format(train_time))
+         print("Batch inference time: {}".format(inference_time))
        print("Batch total time: {}".format(batch_time))
    return model, load_time_per_batch, inference_time_per_batch, total_time_per_batch

@@ -71,46 +83,85 @@ def measure_runtime(func, *func_args):
    return result, elapsed


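The body of `measure_runtime` sits outside this hunk; judging from the signature in the hunk header and its call sites, it is a wall-clock timing wrapper along these lines (a sketch, not necessarily the file's exact implementation):

```python
# Sketch of the timing helper inferred from its signature and call sites:
# runs func(*func_args) and returns the result plus the elapsed time in seconds.
def measure_runtime(func, *func_args):
    start = time.time()
    result = func(*func_args)
    elapsed = time.time() - start
    return result, elapsed
```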
- def run_batch_inference(dataloader, model, cuda=True):
-     (x, y), load_time = measure_runtime(
+ def run_batch_inference(dataloader, model):
+     (question, answer), load_time = measure_runtime(
        __get_next_batch, dataloader)

-     if cuda:
-         x = x.to(DEVICE_CUDA)
-         y = y.to(DEVICE_CUDA)
-
-     output, train_time = measure_runtime(
+
+     # print("question: ", question, "\nanswer: ", answer)
+     # print("question type: ", type(question), "answer type", type(answer))
+     # print("question shape: ", len(question), "answer shape", len(answer))
+     # device = get_device()
+     # x = x.to(device)
+     # y = y.to(device)
+
+     output, inference_time = measure_runtime(
+         inference,
        model,
-         x)
+         [question])

-     return output, load_time, train_time
+     return output, load_time, inference_time
+
+ def inference(
+     generator: Llama,
+     prompts: List[str],
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     max_gen_len: int = 64,
+ ):
+     with torch.no_grad():
+         results = generator.text_completion(
+             prompts,
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+         )
+     return zip(prompts, results)

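`inference` returns `zip(prompts, results)`; in Meta's llama reference code each element of `results` from `text_completion` is a dict with a `generation` field. A minimal sketch of consuming that output, assuming a `generator` built via `get_model` and an illustrative prompt string:

```python
# Minimal sketch: iterate over the (prompt, completion) pairs from inference().
# Assumes text_completion yields one dict per prompt with a "generation" key,
# as in Meta's llama reference repo, and that `generator` came from get_model().
for prompt, completion in inference(generator, ["Natalia sold clips to 48 friends..."]):
    print(prompt)
    print("> " + completion["generation"])
```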
def __get_next_batch(dataloader):
    return next(iter(dataloader))


- def benchmark():
+ def benchmark(ckpt_dir,
+               tokenizer_path,
+               max_seq_len,
+               max_batch_size):
    print("Starting up...")

    print("Building data loaders...")
    data_loader = get_data_loader()

    print("Initializing Model...")
-     net = get_model()
+     net = get_model(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size)

    print("Running inference benchmark...\n")
-     _, load, inference, total = run_batch_inference(data_loader, net)
-
-     print("Results...")
-     print("C2.1: Data-loading times")
+
+     with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=PROFILE_MEMORY) as prof:
+         # with record_function("run_benchmark"):
+         #     _, load, inference, total = run_benchmark(data_loader, net)
+         _, load, inference, total = run_benchmark(data_loader, net)
+
+     print("\n\nManual Profile Results...")
+     print("Data-loading times")
    print("> per epoch: ", load)
    print("> average: ", torch.mean(load))
-     print("C2.2: Training time for each epoch")
+     print("\nInference time for each epoch")
    print("> per epoch", inference)
    print("> average", torch.mean(inference))
-     print("C2.3: Total time for each epoch")
+     print("\nTotal time for each epoch")
    print("> per epoch", total)
    print("> average", torch.mean(total))

+     print("\n\n")
+     print("Profiling sorted by CUDA time total")
+     profile_cuda_time = prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)
+     print(profile_cuda_time)
+
+     print("\n\n")
+     print("Profiling sorted by CUDA memory usage")
+     profile_cuda_mem = prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10)
+     print(profile_cuda_mem)
+
+
if __name__ == "__main__":
-     benchmark()
+     fire.Fire(benchmark)
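Since the llama reference code initializes a distributed process group, the script is normally launched with `torchrun`, and `fire` maps the CLI flags onto `benchmark`'s arguments. A hypothetical invocation (the script name and paths are placeholders):

```bash
torchrun --nproc_per_node 1 benchmark.py \
    --ckpt_dir llama-2-7b/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 128 \
    --max_batch_size 4
```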