3434PROCESS_NAME = "selfdrive.modeld.modeld"
3535SEND_RAW_PRED = os .getenv ('SEND_RAW_PRED' )
3636
37- VISION_PKL_PATH = Path (__file__ ).parent / 'models/driving_vision_tinygrad.pkl'
38- POLICY_PKL_PATH = Path (__file__ ).parent / 'models/driving_policy_tinygrad.pkl'
39- VISION_METADATA_PATH = Path (__file__ ).parent / 'models/driving_vision_metadata.pkl'
40- POLICY_METADATA_PATH = Path (__file__ ).parent / 'models/driving_policy_metadata.pkl'
4137MODELS_DIR = Path (__file__ ).parent / 'models'
38+ VISION_PKL_PATH = MODELS_DIR / 'driving_vision_tinygrad.pkl'
39+ VISION_METADATA_PATH = MODELS_DIR / 'driving_vision_metadata.pkl'
40+ ON_POLICY_PKL_PATH = MODELS_DIR / 'driving_on_policy_tinygrad.pkl'
41+ ON_POLICY_METADATA_PATH = MODELS_DIR / 'driving_on_policy_metadata.pkl'
42+ OFF_POLICY_PKL_PATH = MODELS_DIR / 'driving_off_policy_tinygrad.pkl'
43+ OFF_POLICY_METADATA_PATH = MODELS_DIR / 'driving_off_policy_metadata.pkl'
4244
4345LAT_SMOOTH_SECONDS = 0.0
4446LONG_SMOOTH_SECONDS = 0.3
@@ -151,7 +153,13 @@ def __init__(self):
151153 self .vision_output_slices = vision_metadata ['output_slices' ]
152154 vision_output_size = vision_metadata ['output_shapes' ]['outputs' ][1 ]
153155
154- with open (POLICY_METADATA_PATH , 'rb' ) as f :
156+ with open (OFF_POLICY_METADATA_PATH , 'rb' ) as f :
157+ off_policy_metadata = pickle .load (f )
158+ self .off_policy_input_shapes = off_policy_metadata ['input_shapes' ]
159+ self .off_policy_output_slices = off_policy_metadata ['output_slices' ]
160+ off_policy_output_size = off_policy_metadata ['output_shapes' ]['outputs' ][1 ]
161+
162+ with open (ON_POLICY_METADATA_PATH , 'rb' ) as f :
155163 policy_metadata = pickle .load (f )
156164 self .policy_input_shapes = policy_metadata ['input_shapes' ]
157165 self .policy_output_slices = policy_metadata ['output_slices' ]
@@ -175,11 +183,13 @@ def __init__(self):
175183 self .vision_output = np .zeros (vision_output_size , dtype = np .float32 )
176184 self .policy_inputs = {k : Tensor (v , device = 'NPY' ).realize () for k ,v in self .numpy_inputs .items ()}
177185 self .policy_output = np .zeros (policy_output_size , dtype = np .float32 )
186+ self .off_policy_output = np .zeros (off_policy_output_size , dtype = np .float32 )
178187 self .parser = Parser ()
179188 self .frame_buf_params : dict [str , tuple [int , int , int , int ]] = {}
180189 self .update_imgs = None
181190 self .vision_run = pickle .loads (read_file_chunked (str (VISION_PKL_PATH )))
182- self .policy_run = pickle .loads (read_file_chunked (str (POLICY_PKL_PATH )))
191+ self .policy_run = pickle .loads (read_file_chunked (str (ON_POLICY_PKL_PATH )))
192+ self .off_policy_run = pickle .loads (read_file_chunked (str (OFF_POLICY_PKL_PATH )))
183193
184194 def slice_outputs (self , model_outputs : np .ndarray , output_slices : dict [str , slice ]) -> dict [str , np .ndarray ]:
185195 parsed_model_outputs = {k : model_outputs [np .newaxis , v ] for k ,v in output_slices .items ()}
@@ -228,9 +238,17 @@ def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray],
228238
229239 self .policy_output = self .policy_run (** self .policy_inputs ).contiguous ().realize ().uop .base .buffer .numpy ().flatten ()
230240 policy_outputs_dict = self .parser .parse_policy_outputs (self .slice_outputs (self .policy_output , self .policy_output_slices ))
231- combined_outputs_dict = {** vision_outputs_dict , ** policy_outputs_dict }
241+
242+ self .off_policy_output = self .off_policy_run (** self .policy_inputs ).contiguous ().realize ().uop .base .buffer .numpy ()
243+ off_policy_outputs_dict = self .parser .parse_off_policy_outputs (self .slice_outputs (self .off_policy_output , self .off_policy_output_slices ))
244+ off_policy_outputs_dict .pop ('plan' )
245+
246+
247+ combined_outputs_dict = {** vision_outputs_dict , ** off_policy_outputs_dict , ** policy_outputs_dict }
248+ if 'planplus' in combined_outputs_dict and 'plan' in combined_outputs_dict :
249+ combined_outputs_dict ['plan' ] = combined_outputs_dict ['plan' ] + combined_outputs_dict ['planplus' ]
232250 if SEND_RAW_PRED :
233- combined_outputs_dict ['raw_pred' ] = np .concatenate ([self .vision_output .copy (), self .policy_output .copy ()])
251+ combined_outputs_dict ['raw_pred' ] = np .concatenate ([self .vision_output .copy (), self .policy_output .copy (), self . off_policy_output . copy () ])
234252
235253 return combined_outputs_dict
236254
0 commit comments