@@ -61,6 +61,7 @@ def __init__(
61
61
max_retries : int = 0 ,
62
62
mounts : Optional [List [str ]] = None ,
63
63
rdzv_port : int = 29500 ,
64
+ rdzv_backend : str = None ,
64
65
scheduler_args : Optional [Dict [str , str ]] = None ,
65
66
image : Optional [str ] = None ,
66
67
):
@@ -81,6 +82,7 @@ def __init__(
81
82
self .max_retries = max_retries
82
83
self .mounts : List [str ] = mounts if mounts is not None else []
83
84
self .rdzv_port = rdzv_port
85
+ self .rdzv_backend = rdzv_backend
84
86
self .scheduler_args : Dict [str , str ] = (
85
87
scheduler_args if scheduler_args is not None else dict ()
86
88
)
@@ -104,6 +106,9 @@ def _dry_run(self, cluster: "Cluster"):
104
106
env = self .env ,
105
107
max_retries = self .max_retries ,
106
108
rdzv_port = self .rdzv_port ,
109
+ rdzv_backend = self .rdzv_backend
110
+ if self .rdzv_backend is not None
111
+ else "static" ,
107
112
mounts = self .mounts ,
108
113
),
109
114
scheduler = cluster .torchx_scheduler ,
@@ -142,6 +147,9 @@ def _dry_run_no_cluster(self):
142
147
env = self .env , # should this still exist?
143
148
max_retries = self .max_retries ,
144
149
rdzv_port = self .rdzv_port , # should this still exist?
150
+ rdzv_backend = self .rdzv_backend
151
+ if self .rdzv_backend is not None
152
+ else "c10d" ,
145
153
mounts = self .mounts ,
146
154
image = self .image
147
155
if self .image is not None
0 commit comments