File tree Expand file tree Collapse file tree 3 files changed +15
-0
lines changed Expand file tree Collapse file tree 3 files changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -178,6 +178,8 @@ def create_pytorch_ray_engine(
178178    is_disaggregated : bool  =  False ,
179179    num_hosts : int  =  0 ,
180180    decode_pod_slice_name : str  =  None ,
181+     enable_jax_profiler : bool  =  False ,
182+     jax_profiler_port : int  =  9999 ,
181183) ->  Any :
182184
183185  # Return tuple as reponse: issues/107 
@@ -218,6 +220,8 @@ def create_pytorch_ray_engine(
218220        quantize_kv = quantize_kv ,
219221        max_cache_length = max_cache_length ,
220222        sharding_config = sharding_config ,
223+         enable_jax_profiler = enable_jax_profiler ,
224+         jax_profiler_port = jax_profiler_port ,
221225    )
222226    engine_workers .append (engine_worker )
223227
Original file line number Diff line number Diff line change @@ -114,6 +114,8 @@ def __init__(
114114      quantize_kv = False ,
115115      max_cache_length = 1024 ,
116116      sharding_config = None ,
117+       enable_jax_profiler : bool  =  False ,
118+       jax_profiler_port : int  =  9999 ,
117119  ):
118120
119121    jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
@@ -130,6 +132,10 @@ def __init__(
130132        f"---Jax device_count:{ device_count }  , local_device_count{ local_device_count }   " 
131133    )
132134
135+     if  enable_jax_profiler :
136+       jax .profiler .start_server (jax_profiler_port )
137+       print (f"Started JAX profiler server on port { jax_profiler_port }  " )
138+ 
133139    checkpoint_format  =  "" 
134140    checkpoint_path  =  "" 
135141
Original file line number Diff line number Diff line change 3434flags .DEFINE_integer ("prometheus_port" , 0 , "" )
3535flags .DEFINE_integer ("tpu_chips" , 16 , "device tpu_chips" )
3636
37+ flags .DEFINE_bool ("enable_jax_profiler" , False , "enable jax profiler" )
38+ flags .DEFINE_integer ("jax_profiler_port" , 9999 , "port of JAX profiler server" )
39+ 
3740
3841def  create_engine ():
3942  """create a pytorch engine""" 
@@ -53,6 +56,8 @@ def create_engine():
5356      quantize_kv = FLAGS .quantize_kv_cache ,
5457      max_cache_length = FLAGS .max_cache_length ,
5558      sharding_config = FLAGS .sharding_config ,
59+       enable_jax_profiler = FLAGS .enable_jax_profiler ,
60+       jax_profiler_port = FLAGS .jax_profiler_port ,
5661  )
5762
5863  print ("Initialize engine" , time .perf_counter () -  start )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments