Skip to content

Commit 9a7ab71

Browse files
committed
add: Support for image pull secrets for Ray Cluster images
1 parent b4d84c1 commit 9a7ab71

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def create_app_wrapper(self):
8484
instascale = self.config.instascale
8585
instance_types = self.config.machine_types
8686
env = self.config.envs
87+
image_pull_secrets = self.config.image_pull_secrets
8788
return generate_appwrapper(
8889
name=name,
8990
namespace=namespace,
@@ -98,6 +99,7 @@ def create_app_wrapper(self):
9899
instascale=instascale,
99100
instance_types=instance_types,
100101
env=env,
102+
image_pull_secrets=image_pull_secrets,
101103
)
102104

103105
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,4 @@ class ClusterConfiguration:
4848
instascale: bool = False
4949
envs: dict = field(default_factory=dict)
5050
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
51+
image_pull_secrets: list = field(default_factory=list)

src/codeflare_sdk/utils/generate_yaml.py

+20
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,14 @@ def update_image(spec, image):
130130
container["image"] = image
131131

132132

133+
def update_image_pull_secrets(spec, image_pull_secrets):
134+
if image_pull_secrets:
135+
if "imagePullSecrets" not in spec:
136+
spec["imagePullSecrets"] = []
137+
for image_pull_secret in image_pull_secrets:
138+
spec["imagePullSecrets"].append({"name": image_pull_secret})
139+
140+
133141
def update_env(spec, env):
134142
containers = spec.get("containers")
135143
for container in containers:
@@ -167,6 +175,7 @@ def update_nodes(
167175
image,
168176
instascale,
169177
env,
178+
image_pull_secrets,
170179
):
171180
if "generictemplate" in item.keys():
172181
head = item.get("generictemplate").get("spec").get("headGroupSpec")
@@ -182,6 +191,7 @@ def update_nodes(
182191
for comp in [head, worker]:
183192
spec = comp.get("template").get("spec")
184193
update_affinity(spec, appwrapper_name, instascale)
194+
update_image_pull_secrets(spec, image_pull_secrets)
185195
update_image(spec, image)
186196
update_env(spec, env)
187197
if comp == head:
@@ -211,6 +221,7 @@ def generate_appwrapper(
211221
instascale: bool,
212222
instance_types: list,
213223
env,
224+
image_pull_secrets: list,
214225
):
215226
user_yaml = read_template(template)
216227
appwrapper_name, cluster_name = gen_names(name)
@@ -234,6 +245,7 @@ def generate_appwrapper(
234245
image,
235246
instascale,
236247
env,
248+
image_pull_secrets,
237249
)
238250
update_dashboard_route(route_item, cluster_name, namespace)
239251
outfile = appwrapper_name + ".yaml"
@@ -315,6 +327,12 @@ def main(): # pragma: no cover
315327
default="default",
316328
help="Set the kubernetes namespace you want to deploy your cluster to. Default. If left blank, uses the 'default' namespace",
317329
)
330+
parser.add_argument(
331+
"--image-pull-secrets",
332+
required=False,
333+
default=[],
334+
help="Set image pull secrets for private registries",
335+
)
318336

319337
args = parser.parse_args()
320338
name = args.name
@@ -330,6 +348,7 @@ def main(): # pragma: no cover
330348
instance_types = args.instance_types
331349
namespace = args.namespace
332350
env = {}
351+
image_pull_secrets = args.image_pull_secrets
333352

334353
outfile = generate_appwrapper(
335354
name,
@@ -345,6 +364,7 @@ def main(): # pragma: no cover
345364
instascale,
346365
instance_types,
347366
env,
367+
image_pull_secrets,
348368
)
349369
return outfile
350370

0 commit comments

Comments
 (0)