@@ -31,7 +31,7 @@ def format_type_name(tp):
     return type_str
 
 
-def get_torchrun_params(args: dict):
+def get_torchrun_params(args: dict) -> dict[str, str | int]:
     """
     Parse and load PyTorch distributed training parameters with hierarchical precedence.
 
@@ -104,21 +104,18 @@ def validate_nproc_per_node(value: int | str) -> int | str:
         raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got type {type(value).__name__}")
     if isinstance(value, int):
         return value
-
+
     value_lower = value.lower().strip()
     if value_lower not in ['auto', 'gpu'] and not value_lower.isdigit():
         raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value!r}")
     if value_lower.isdigit():
         return int(value_lower)
-    elif value_lower == 'gpu':
-        return 'gpu'
 
-    # otherwise just handle auto logic
-    # convert 'auto' to 'gpu' if CUDA is available
-    if torch.cuda.is_available():
+    # handle 'auto' and 'gpu' - both require CUDA
+    if value_lower in ['auto', 'gpu'] and torch.cuda.is_available():
         return 'gpu'
     else:
-        raise ValueError("nproc_per_node='auto' requires CUDA GPUs, but none are available")
+        raise ValueError(f"nproc_per_node='{value_lower}' requires CUDA GPUs, but none are available")
 
 def get_param_reference(param_name: str, source: str) -> str:
     """Format parameter reference based on source (args vs env)."""
@@ -151,7 +148,13 @@ def get_param_reference(param_name: str, source: str) -> str:
         # we know the final values in this case must be integers, so any non-None value here
         # should be castable to `int`.
         value, _ = get_param_value(param)
-        torchrun_args[param] = int(value) if value is not None else default
+        if value is not None:
+            try:
+                torchrun_args[param] = int(value)
+            except (ValueError, TypeError) as e:
+                raise ValueError(f"Invalid value for {param}: {value!r}. Must be an integer.") from e
+        else:
+            torchrun_args[param] = default
 
 
     # rdzv_id will be either a str or int; we just perform some cleanup before
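The effect of the new guarded cast, sketched in isolation (the parameter name and raw value here are hypothetical):

```python
param, value = "nnodes", "two"   # hypothetical malformed CLI/env input
try:
    try:
        parsed = int(value)
    except (ValueError, TypeError) as e:
        # Re-raise with context; `from e` keeps the original cast failure
        # attached as __cause__ for debugging.
        raise ValueError(f"Invalid value for {param}: {value!r}. Must be an integer.") from e
except ValueError as err:
    print(err)            # Invalid value for nnodes: 'two'. Must be an integer.
    print(err.__cause__)  # invalid literal for int() with base 10: 'two'
```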
@@ -212,8 +215,13 @@ def get_param_reference(param_name: str, source: str) -> str:
         # validate env conflicts only when we're actually using master_port
         if master_port_source == 'env':
             validate_env_conflict('master_port')
-        torchrun_args['master_port'] = int(master_port_val)
+        try:
+            torchrun_args['master_port'] = int(master_port_val)
+        except (ValueError, TypeError) as e:
+            raise ValueError(f"Invalid value for master_port: {master_port_val!r}. Must be an integer.") from e
 
+    # Note: If neither master_addr nor rdzv_endpoint is set, torchrun will use
+    # its default behavior (typically localhost or other configured defaults)
     elif rdzv_endpoint_val:
         torchrun_args['rdzv_endpoint'] = rdzv_endpoint_val
 
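A hedged sketch of the endpoint-selection precedence around this hunk; `master_addr_val` and the enclosing if-branch are assumptions, since the diff only shows the `elif` and the trailing note:

```python
# Hypothetical values standing in for parsed args/env; only the elif branch
# and the fall-through note appear in the diff itself.
torchrun_args = {}
master_addr_val, master_port_val, rdzv_endpoint_val = "10.0.0.1", "29500", None

if master_addr_val:                     # explicit master_addr takes precedence
    torchrun_args["master_addr"] = master_addr_val
    if master_port_val:
        torchrun_args["master_port"] = int(master_port_val)
elif rdzv_endpoint_val:                 # otherwise fall back to rdzv_endpoint
    torchrun_args["rdzv_endpoint"] = rdzv_endpoint_val
# else: neither is set, and torchrun applies its own defaults (typically localhost)
```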