diff --git a/api/inference/v1alpha1/service_types.go b/api/inference/v1alpha1/service_types.go index 39ee036a..b29cce3c 100644 --- a/api/inference/v1alpha1/service_types.go +++ b/api/inference/v1alpha1/service_types.go @@ -43,8 +43,9 @@ type ServiceSpec struct { WorkloadTemplate lws.LeaderWorkerTemplate `json:"workloadTemplate"` // RolloutStrategy defines the strategy that will be applied to update replicas // when a revision is made to the leaderWorkerTemplate. + // +kubebuilder:default:={type: "RollingUpdate", rollingUpdateConfiguration: {"maxUnavailable": 1, "maxSurge": 0}} // +optional - RolloutStrategy lws.RolloutStrategy `json:"rolloutStrategy,omitempty"` + RolloutStrategy *lws.RolloutStrategy `json:"rolloutStrategy,omitempty"` } const ( diff --git a/api/inference/v1alpha1/zz_generated.deepcopy.go b/api/inference/v1alpha1/zz_generated.deepcopy.go index 9950e0cd..eecffe54 100644 --- a/api/inference/v1alpha1/zz_generated.deepcopy.go +++ b/api/inference/v1alpha1/zz_generated.deepcopy.go @@ -26,6 +26,7 @@ import ( "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" runtime "k8s.io/apimachinery/pkg/runtime" + leaderworkersetv1 "sigs.k8s.io/lws/api/leaderworkerset/v1" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -541,7 +542,11 @@ func (in *ServiceSpec) DeepCopyInto(out *ServiceSpec) { **out = **in } in.WorkloadTemplate.DeepCopyInto(&out.WorkloadTemplate) - in.RolloutStrategy.DeepCopyInto(&out.RolloutStrategy) + if in.RolloutStrategy != nil { + in, out := &in.RolloutStrategy, &out.RolloutStrategy + *out = new(leaderworkersetv1.RolloutStrategy) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceSpec. diff --git a/config/crd/bases/inference.llmaz.io_services.yaml b/config/crd/bases/inference.llmaz.io_services.yaml index 7d2a2f8b..908ef21a 100644 --- a/config/crd/bases/inference.llmaz.io_services.yaml +++ b/config/crd/bases/inference.llmaz.io_services.yaml @@ -92,6 +92,11 @@ spec: format: int32 type: integer rolloutStrategy: + default: + rollingUpdateConfiguration: + maxSurge: 0 + maxUnavailable: 1 + type: RollingUpdate description: |- RolloutStrategy defines the strategy that will be applied to update replicas when a revision is made to the leaderWorkerTemplate. diff --git a/docs/examples/multi-nodes/model.yaml b/docs/examples/multi-nodes/model.yaml index fd39abd1..5d623a58 100644 --- a/docs/examples/multi-nodes/model.yaml +++ b/docs/examples/multi-nodes/model.yaml @@ -7,8 +7,3 @@ spec: source: modelHub: modelID: meta-llama/Llama-3.1-405B - inferenceConfig: - flavors: - - name: a100-80gb - limits: - nvidia.com/gpu: 8 # single node request diff --git a/docs/examples/multi-nodes/service.yaml b/docs/examples/multi-nodes/service.yaml index f50ba149..711fad23 100644 --- a/docs/examples/multi-nodes/service.yaml +++ b/docs/examples/multi-nodes/service.yaml @@ -3,10 +3,9 @@ kind: Service metadata: name: llama3-405b-instruct spec: - modelClaim: - modelName: llama3-405b-instruct - inferenceFlavors: - - a100-80gb # actually no need to specify this since we have only one flavor + modelClaims: + models: + - name: llama3-405b-instruct replicas: 2 workloadTemplate: size: 2 @@ -17,7 +16,7 @@ spec: role: leader spec: containers: - - name: sglang-leader + - name: model-runner image: lmsysorg/sglang:latest env: - name: HUGGING_FACE_HUB_TOKEN @@ -31,7 +30,7 @@ spec: - -m - sglang.launch_server - --model-path - - meta-llama/Meta-Llama-3.1-8B-Instruct + - /workspace/models/models--meta-llama--Meta-Llama-3.1-8B-Instruct - --tp - "2" # Size of Tensor Parallelism - --dist-init-addr @@ -65,7 +64,7 @@ spec: workerTemplate: spec: containers: - - name: sglang-worker + - name: model-runner image: lmsysorg/sglang:latest env: - name: HUGGING_FACE_HUB_TOKEN @@ -79,7 +78,7 @@ spec: - -m - sglang.launch_server - --model-path - - meta-llama/Meta-Llama-3.1-8B-Instruct + - /workspace/models/models--meta-llama--Meta-Llama-3.1-8B-Instruct - --tp - "2" # Size of Tensor Parallelism - --dist-init-addr diff --git a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go index e7ae64ad..8688671d 100644 --- a/pkg/controller/inference/service_controller.go +++ b/pkg/controller/inference/service_controller.go @@ -149,10 +149,15 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co spec.WithLeaderWorkerTemplate(leaderWorkerTemplate) spec.LeaderWorkerTemplate.WithSize(*service.Spec.WorkloadTemplate.Size) spec.WithReplicas(*service.Spec.Replicas) - spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type)) - if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil { - spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge) - spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable) + if service.Spec.RolloutStrategy != nil { + spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type)) + if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil { + spec.RolloutStrategy.WithRollingUpdateConfiguration( + applyconfigurationv1.RollingUpdateConfiguration(). + WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge). + WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable), + ) + } } spec.WithStartupPolicy(lws.LeaderReadyStartupPolicy) diff --git a/test/integration/webhook/service_test.go b/test/integration/webhook/service_test.go index 2b59f400..20678bd9 100644 --- a/test/integration/webhook/service_test.go +++ b/test/integration/webhook/service_test.go @@ -17,10 +17,12 @@ limitations under the License. package webhook import ( + "github.com/google/go-cmp/cmp/cmpopts" "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + lws "sigs.k8s.io/lws/api/leaderworkerset/v1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/test/util" @@ -44,6 +46,32 @@ var _ = ginkgo.Describe("service default and validation", func() { gomega.Expect(k8sClient.Delete(ctx, ns)).To(gomega.Succeed()) }) + type testDefaultingCase struct { + service func() *inferenceapi.Service + wantService func() *inferenceapi.Service + } + ginkgo.DescribeTable("Defaulting test", + func(tc *testDefaultingCase) { + svc := tc.service() + gomega.Expect(k8sClient.Create(ctx, svc)).To(gomega.Succeed()) + gomega.Expect(svc).To(gomega.BeComparableTo(tc.wantService(), + cmpopts.IgnoreTypes(inferenceapi.ServiceStatus{}), + cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields"))) + }, + ginkgo.Entry("apply service rollingUpdate strategy", &testDefaultingCase{ + service: func() *inferenceapi.Service { + return wrapper.MakeService("service-llama3-8b", ns.Name).WorkerTemplate().Obj() + }, + wantService: func() *inferenceapi.Service { + return wrapper.MakeService("service-llama3-8b", ns.Name). + RolloutStrategy(string(lws.RollingUpdateStrategyType), 1, 0). + RestartPolicy("RecreateGroupOnPodRestart"). + Replicas(1).Size(1). + WorkerTemplate().Obj() + }, + }), + ) + type testValidatingCase struct { service func() *inferenceapi.Service failed bool diff --git a/test/util/wrapper/service.go b/test/util/wrapper/service.go index 8f37321e..8f224c09 100644 --- a/test/util/wrapper/service.go +++ b/test/util/wrapper/service.go @@ -19,6 +19,8 @@ package wrapper import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/utils/ptr" lws "sigs.k8s.io/lws/api/leaderworkerset/v1" coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" @@ -65,9 +67,6 @@ func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavor } func (w *ServiceWrapper) WorkerTemplate() *ServiceWrapper { - w.Spec.RolloutStrategy = lws.RolloutStrategy{ - Type: lws.RollingUpdateStrategyType, - } w.Spec.WorkloadTemplate.WorkerTemplate = corev1.PodTemplateSpec{ Spec: corev1.PodSpec{ Containers: []corev1.Container{ @@ -90,3 +89,30 @@ func (w *ServiceWrapper) InitContainerName(name string) *ServiceWrapper { w.Spec.WorkloadTemplate.WorkerTemplate.Spec.InitContainers[0].Name = name return w } + +func (w *ServiceWrapper) RolloutStrategy(typ string, maxUnavailable int, maxSurge int) *ServiceWrapper { + if w.Spec.RolloutStrategy == nil { + w.Spec.RolloutStrategy = &lws.RolloutStrategy{} + } + w.Spec.RolloutStrategy.Type = lws.RolloutStrategyType(typ) + w.Spec.RolloutStrategy.RollingUpdateConfiguration = &lws.RollingUpdateConfiguration{ + MaxUnavailable: intstr.FromInt(maxUnavailable), + MaxSurge: intstr.FromInt(maxSurge), + } + return w +} + +func (w *ServiceWrapper) Size(size int32) *ServiceWrapper { + w.Spec.WorkloadTemplate.Size = ptr.To[int32](size) + return w +} + +func (w *ServiceWrapper) Replicas(replicas int32) *ServiceWrapper { + w.Spec.Replicas = ptr.To[int32](replicas) + return w +} + +func (w *ServiceWrapper) RestartPolicy(policy string) *ServiceWrapper { + w.Spec.WorkloadTemplate.RestartPolicy = lws.RestartPolicyType(policy) + return w +}