Skip to content

Commit 23f5aab

Browse files
committed
add: Support for image pull secrets for Ray Cluster images
1 parent 3b41a22 commit 23f5aab

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

src/codeflare_sdk/cluster/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def create_app_wrapper(self):
8484
instascale = self.config.instascale
8585
instance_types = self.config.machine_types
8686
env = self.config.envs
87+
pull_secret = self.config.pull_secret
8788
return generate_appwrapper(
8889
name=name,
8990
namespace=namespace,
@@ -98,6 +99,7 @@ def create_app_wrapper(self):
9899
instascale=instascale,
99100
instance_types=instance_types,
100101
env=env,
102+
pull_secret=pull_secret,
101103
)
102104

103105
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ class ClusterConfiguration:
4848
instascale: bool = False
4949
envs: dict = field(default_factory=dict)
5050
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
51+
pull_secret: dict = field(default_factory=dict)

src/codeflare_sdk/utils/generate_yaml.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
import uuid
2424

2525

26+
class NoAliasDumper(yaml.Dumper):
27+
def ignore_aliases(self, data):
28+
return True
29+
30+
2631
def read_template(template):
2732
with open(template, "r") as stream:
2833
try:
@@ -130,6 +135,13 @@ def update_image(spec, image):
130135
container["image"] = image
131136

132137

138+
def update_pull_secret(spec, pull_secret):
139+
if pull_secret:
140+
if "imagePullSecrets" not in spec:
141+
spec["imagePullSecrets"] = []
142+
spec["imagePullSecrets"].append(pull_secret)
143+
144+
133145
def update_env(spec, env):
134146
containers = spec.get("containers")
135147
for container in containers:
@@ -167,6 +179,7 @@ def update_nodes(
167179
image,
168180
instascale,
169181
env,
182+
pull_secret,
170183
):
171184
if "generictemplate" in item.keys():
172185
head = item.get("generictemplate").get("spec").get("headGroupSpec")
@@ -182,6 +195,7 @@ def update_nodes(
182195
for comp in [head, worker]:
183196
spec = comp.get("template").get("spec")
184197
update_affinity(spec, appwrapper_name, instascale)
198+
update_pull_secret(spec, pull_secret)
185199
update_image(spec, image)
186200
update_env(spec, env)
187201
if comp == head:
@@ -192,7 +206,7 @@ def update_nodes(
192206

193207
def write_user_appwrapper(user_yaml, output_file_name):
194208
with open(output_file_name, "w") as outfile:
195-
yaml.dump(user_yaml, outfile, default_flow_style=False)
209+
yaml.dump(user_yaml, outfile, default_flow_style=False, Dumper=NoAliasDumper)
196210
print(f"Written to: {output_file_name}")
197211

198212

@@ -210,6 +224,7 @@ def generate_appwrapper(
210224
instascale: bool,
211225
instance_types: list,
212226
env,
227+
pull_secret: str,
213228
):
214229
user_yaml = read_template(template)
215230
appwrapper_name, cluster_name = gen_names(name)
@@ -233,6 +248,7 @@ def generate_appwrapper(
233248
image,
234249
instascale,
235250
env,
251+
pull_secret,
236252
)
237253
update_dashboard_route(route_item, cluster_name, namespace)
238254
outfile = appwrapper_name + ".yaml"
@@ -314,6 +330,12 @@ def main(): # pragma: no cover
314330
default="default",
315331
help="Set the kubernetes namespace you want to deploy your cluster to. Default. If left blank, uses the 'default' namespace",
316332
)
333+
parser.add_argument(
334+
"--pull-secret",
335+
required=False,
336+
default="",
337+
help="Set pull secret for a private registry",
338+
)
317339

318340
args = parser.parse_args()
319341
name = args.name
@@ -329,6 +351,7 @@ def main(): # pragma: no cover
329351
instance_types = args.instance_types
330352
namespace = args.namespace
331353
env = {}
354+
pull_secret = args.pull_secret
332355

333356
outfile = generate_appwrapper(
334357
name,
@@ -344,6 +367,7 @@ def main(): # pragma: no cover
344367
instascale,
345368
instance_types,
346369
env,
370+
pull_secret,
347371
)
348372
return outfile
349373

0 commit comments

Comments
 (0)