|
| 1 | +import os |
| 2 | +from curses.ascii import isdigit |
| 3 | +from importlib.metadata import pass_none |
1 | 4 | from typing import get_origin, get_args |
2 | 5 |
|
3 | 6 | def format_type_name(tp): |
@@ -26,3 +29,73 @@ def format_type_name(tp): |
26 | 29 | return type_str[8:-2] |
27 | 30 |
|
28 | 31 | return type_str |
| 32 | + |
| 33 | + |
| 34 | +def get_torchrun_params(args: dict): |
| 35 | + """ |
| 36 | + Parse and load PyTorch variables from dict with fallback to environment variables. |
| 37 | +
|
| 38 | + Args: |
| 39 | + args (dict): Dictionary containing PyTorch configuration parameters |
| 40 | +
|
| 41 | + Returns: |
| 42 | + dict: Dictionary with PyTorch parameters loaded from args or environment |
| 43 | + """ |
| 44 | + pytorch_vars = ['nproc_per_node', 'nnodes', 'node_rank', 'rdzv_id', 'rdzv_endpoint', 'master_addr', 'master_port'] |
| 45 | + torchrun_args = {} |
| 46 | + |
| 47 | + def validate_nproc_per_node(value): |
| 48 | + """Validate and convert nproc_per_node value.""" |
| 49 | + if isinstance(value, str): |
| 50 | + if value.lower() == 'auto': |
| 51 | + return 'gpu' |
| 52 | + elif value.lower() == 'gpu': |
| 53 | + return 'gpu' |
| 54 | + else: |
| 55 | + try: |
| 56 | + return int(value) |
| 57 | + except ValueError: |
| 58 | + raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value}") |
| 59 | + elif isinstance(value, int): |
| 60 | + return value |
| 61 | + else: |
| 62 | + raise ValueError(f"nproc_per_node must be 'auto', 'gpu', or an integer, got: {value}") |
| 63 | + |
| 64 | + def get_env_var_name(var_name): |
| 65 | + """Get environment variable name based on PyTorch convention.""" |
| 66 | + return var_name.upper() if var_name in ['master_addr', 'master_port'] else f"PET_{var_name.upper()}" |
| 67 | + |
| 68 | + for var_name in pytorch_vars: |
| 69 | + # Try args dict first |
| 70 | + if var_name in args and args[var_name] is not None and args[var_name] != "": |
| 71 | + value = args[var_name] |
| 72 | + if var_name == 'nproc_per_node': |
| 73 | + torchrun_args[var_name] = validate_nproc_per_node(value) |
| 74 | + elif var_name in ['nnodes', 'node_rank', 'rdzv_id', 'master_port']: |
| 75 | + torchrun_args[var_name] = int(value) if isinstance(value, (str, int)) else value |
| 76 | + else: |
| 77 | + torchrun_args[var_name] = value |
| 78 | + else: |
| 79 | + # Fallback to environment variable |
| 80 | + env_value = os.getenv(get_env_var_name(var_name)) |
| 81 | + if env_value is not None: |
| 82 | + if var_name == 'nproc_per_node': |
| 83 | + torchrun_args[var_name] = validate_nproc_per_node(env_value) |
| 84 | + elif var_name in ['nnodes', 'node_rank', 'rdzv_id', 'master_port']: |
| 85 | + try: |
| 86 | + torchrun_args[var_name] = int(env_value) |
| 87 | + except ValueError: |
| 88 | + torchrun_args[var_name] = env_value |
| 89 | + else: |
| 90 | + torchrun_args[var_name] = env_value |
| 91 | + else: |
| 92 | + # Set defaults |
| 93 | + defaults = {'nnodes': 1, 'rdzv_id': 0} |
| 94 | + torchrun_args[var_name] = defaults.get(var_name, "") |
| 95 | + |
| 96 | + # Validate mutually exclusive parameters |
| 97 | + if (torchrun_args.get('rdzv_endpoint', '') != "" and |
| 98 | + (torchrun_args.get('master_addr', '') != "" or torchrun_args.get('master_port', '') != "")): |
| 99 | + raise ValueError("Cannot specify both rdzv_endpoint and master_addr/master_port. These are mutually exclusive parameters.") |
| 100 | + |
| 101 | + return torchrun_args |
0 commit comments