Skip to content

Commit ebca2d9

Browse files
committed
Implement PodSets inference for RayClusters
1 parent cb4a4e3 commit ebca2d9

File tree

3 files changed

+53
-8
lines changed

3 files changed

+53
-8
lines changed

internal/webhook/appwrapper_fixtures_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,20 @@ func rayCluster(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperCompo
396396
}
397397
}
398398

399+
func rayClusterForInference(workerCount int, milliCPU int64) workloadv1beta2.AppWrapperComponent {
400+
workerCPU := resource.NewMilliQuantity(milliCPU, resource.DecimalSI)
401+
yamlString := fmt.Sprintf(rayClusterYAML,
402+
randName("raycluster"),
403+
workerCount, workerCount, workerCount,
404+
workerCPU)
405+
406+
jsonBytes, err := yaml.YAMLToJSON([]byte(yamlString))
407+
Expect(err).NotTo(HaveOccurred())
408+
return workloadv1beta2.AppWrapperComponent{
409+
Template: runtime.RawExtension{Raw: jsonBytes},
410+
}
411+
}
412+
399413
const jobSetYAML = `
400414
apiVersion: jobset.x-k8s.io/v1alpha2
401415
kind: JobSet

internal/webhook/appwrapper_webhook_test.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,23 @@ var _ = Describe("AppWrapper Webhook Tests", func() {
191191
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
192192
})
193193

194-
It("PodSets are inferred for known GVKs", func() {
195-
aw := toAppWrapper(pod(100), deploymentForInference(1, 100), podForInference(100),
196-
jobForInference(2, 4, 100), jobForInference(8, 4, 100), pytorchJobForInference(100, 4, 100))
194+
Context("PodSets are inferred for known GVKs", func() {
195+
It("PodSets are inferred for common kinds", func() {
196+
aw := toAppWrapper(pod(100), deploymentForInference(1, 100), podForInference(100),
197+
jobForInference(2, 4, 100), jobForInference(8, 4, 100))
198+
199+
Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred")
200+
Expect(aw.Spec.Suspend).Should(BeTrue())
201+
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
202+
})
197203

198-
Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets for deployments and pods should be inferred")
199-
Expect(aw.Spec.Suspend).Should(BeTrue())
200-
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
204+
It("PodSets are inferred for PyTorchJobs and RayClusters", func() {
205+
aw := toAppWrapper(pytorchJobForInference(100, 4, 100), rayClusterForInference(7, 100))
206+
207+
Expect(k8sClient.Create(ctx, aw)).To(Succeed(), "PodSets should be inferred")
208+
Expect(aw.Spec.Suspend).Should(BeTrue())
209+
Expect(k8sClient.Delete(ctx, aw)).To(Succeed())
210+
})
201211
})
202212
})
203213

pkg/utils/utils.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,34 @@ func InferPodSets(obj *unstructured.Unstructured) ([]workloadv1beta2.AppWrapperP
237237
for _, replicaType := range []string{"Master", "Worker"} {
238238
prefix := "template.spec.pytorchReplicaSpecs." + replicaType + "."
239239
// validate path to replica template
240-
if _, err := getValueAtPath(obj.UnstructuredContent(), prefix+"template"); err == nil {
240+
if _, err := getValueAtPath(obj.UnstructuredContent(), prefix+templateString); err == nil {
241241
// infer replica count
242242
replicas, err := InferReplicas(obj.UnstructuredContent(), prefix+"replicas")
243243
if err != nil {
244244
return nil, err
245245
}
246-
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: prefix + "template"})
246+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: prefix + templateString})
247+
}
248+
}
249+
250+
case schema.GroupVersionKind{Group: "ray.io", Version: "v1", Kind: "RayCluster"}:
251+
if _, err := getValueAtPath(obj.UnstructuredContent(), "template.spec.headGroupSpec.template"); err == nil {
252+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(int32(1)), Path: "template.spec.headGroupSpec.template"})
253+
}
254+
if workers, err := getValueAtPath(obj.UnstructuredContent(), "template.spec.workerGroupSpecs"); err == nil {
255+
if workers, ok := workers.([]interface{}); ok {
256+
for i := range workers {
257+
prefix := fmt.Sprintf("template.spec.workerGroupSpecs[%v].", i)
258+
// validate path to replica template
259+
if _, err := getValueAtPath(obj.UnstructuredContent(), prefix+templateString); err == nil {
260+
// infer replica count
261+
replicas, err := InferReplicas(obj.UnstructuredContent(), prefix+"replicas")
262+
if err != nil {
263+
return nil, err
264+
}
265+
podSets = append(podSets, workloadv1beta2.AppWrapperPodSet{Replicas: ptr.To(replicas), Path: prefix + templateString})
266+
}
267+
}
247268
}
248269
}
249270

0 commit comments

Comments
 (0)