Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/inference/v1alpha1/service_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ type ServiceSpec struct {
WorkloadTemplate lws.LeaderWorkerTemplate `json:"workloadTemplate"`
// RolloutStrategy defines the strategy that will be applied to update replicas
// when a revision is made to the leaderWorkerTemplate.
// +kubebuilder:default:={type: "RollingUpdate", rollingUpdateConfiguration: {"maxUnavailable": 1, "maxSurge": 0}}
// +optional
RolloutStrategy lws.RolloutStrategy `json:"rolloutStrategy,omitempty"`
RolloutStrategy *lws.RolloutStrategy `json:"rolloutStrategy,omitempty"`
}

const (
Expand Down
7 changes: 6 additions & 1 deletion api/inference/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions config/crd/bases/inference.llmaz.io_services.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ spec:
format: int32
type: integer
rolloutStrategy:
default:
rollingUpdateConfiguration:
maxSurge: 0
maxUnavailable: 1
type: RollingUpdate
description: |-
RolloutStrategy defines the strategy that will be applied to update replicas
when a revision is made to the leaderWorkerTemplate.
Expand Down
5 changes: 0 additions & 5 deletions docs/examples/multi-nodes/model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,3 @@ spec:
source:
modelHub:
modelID: meta-llama/Llama-3.1-405B
inferenceConfig:
flavors:
- name: a100-80gb
limits:
nvidia.com/gpu: 8 # single node request
15 changes: 7 additions & 8 deletions docs/examples/multi-nodes/service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ kind: Service
metadata:
name: llama3-405b-instruct
spec:
modelClaim:
modelName: llama3-405b-instruct
inferenceFlavors:
- a100-80gb # actually no need to specify this since we have only one flavor
modelClaims:
models:
- name: llama3-405b-instruct
replicas: 2
workloadTemplate:
size: 2
Expand All @@ -17,7 +16,7 @@ spec:
role: leader
spec:
containers:
- name: sglang-leader
- name: model-runner
image: lmsysorg/sglang:latest
env:
- name: HUGGING_FACE_HUB_TOKEN
Expand All @@ -31,7 +30,7 @@ spec:
- -m
- sglang.launch_server
- --model-path
- meta-llama/Meta-Llama-3.1-8B-Instruct
- /workspace/models/models--meta-llama--Meta-Llama-3.1-8B-Instruct
- --tp
- "2" # Size of Tensor Parallelism
- --dist-init-addr
Expand Down Expand Up @@ -65,7 +64,7 @@ spec:
workerTemplate:
spec:
containers:
- name: sglang-worker
- name: model-runner
image: lmsysorg/sglang:latest
env:
- name: HUGGING_FACE_HUB_TOKEN
Expand All @@ -79,7 +78,7 @@ spec:
- -m
- sglang.launch_server
- --model-path
- meta-llama/Meta-Llama-3.1-8B-Instruct
- /workspace/models/models--meta-llama--Meta-Llama-3.1-8B-Instruct
- --tp
- "2" # Size of Tensor Parallelism
- --dist-init-addr
Expand Down
13 changes: 9 additions & 4 deletions pkg/controller/inference/service_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,15 @@ func buildWorkloadApplyConfiguration(service *inferenceapi.Service, models []*co
spec.WithLeaderWorkerTemplate(leaderWorkerTemplate)
spec.LeaderWorkerTemplate.WithSize(*service.Spec.WorkloadTemplate.Size)
spec.WithReplicas(*service.Spec.Replicas)
spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type))
if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil {
spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge)
spec.RolloutStrategy.RollingUpdateConfiguration.WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable)
if service.Spec.RolloutStrategy != nil {
spec.WithRolloutStrategy(applyconfigurationv1.RolloutStrategy().WithType(service.Spec.RolloutStrategy.Type))
if service.Spec.RolloutStrategy.RollingUpdateConfiguration != nil {
spec.RolloutStrategy.WithRollingUpdateConfiguration(
applyconfigurationv1.RollingUpdateConfiguration().
WithMaxSurge(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxSurge).
WithMaxUnavailable(service.Spec.RolloutStrategy.RollingUpdateConfiguration.MaxUnavailable),
)
}
}
spec.WithStartupPolicy(lws.LeaderReadyStartupPolicy)

Expand Down
28 changes: 28 additions & 0 deletions test/integration/webhook/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ limitations under the License.
package webhook

import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/onsi/ginkgo/v2"
"github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
lws "sigs.k8s.io/lws/api/leaderworkerset/v1"

inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1"
"github.com/inftyai/llmaz/test/util"
Expand All @@ -44,6 +46,32 @@ var _ = ginkgo.Describe("service default and validation", func() {
gomega.Expect(k8sClient.Delete(ctx, ns)).To(gomega.Succeed())
})

type testDefaultingCase struct {
service func() *inferenceapi.Service
wantService func() *inferenceapi.Service
}
ginkgo.DescribeTable("Defaulting test",
func(tc *testDefaultingCase) {
svc := tc.service()
gomega.Expect(k8sClient.Create(ctx, svc)).To(gomega.Succeed())
gomega.Expect(svc).To(gomega.BeComparableTo(tc.wantService(),
cmpopts.IgnoreTypes(inferenceapi.ServiceStatus{}),
cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields")))
},
ginkgo.Entry("apply service rollingUpdate strategy", &testDefaultingCase{
service: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).WorkerTemplate().Obj()
},
wantService: func() *inferenceapi.Service {
return wrapper.MakeService("service-llama3-8b", ns.Name).
RolloutStrategy(string(lws.RollingUpdateStrategyType), 1, 0).
RestartPolicy("RecreateGroupOnPodRestart").
Replicas(1).Size(1).
WorkerTemplate().Obj()
},
}),
)

type testValidatingCase struct {
service func() *inferenceapi.Service
failed bool
Expand Down
32 changes: 29 additions & 3 deletions test/util/wrapper/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package wrapper
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
lws "sigs.k8s.io/lws/api/leaderworkerset/v1"

coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
Expand Down Expand Up @@ -65,9 +67,6 @@ func (w *ServiceWrapper) ModelClaims(modelNames []string, roles []string, flavor
}

func (w *ServiceWrapper) WorkerTemplate() *ServiceWrapper {
w.Spec.RolloutStrategy = lws.RolloutStrategy{
Type: lws.RollingUpdateStrategyType,
}
w.Spec.WorkloadTemplate.WorkerTemplate = corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
Expand All @@ -90,3 +89,30 @@ func (w *ServiceWrapper) InitContainerName(name string) *ServiceWrapper {
w.Spec.WorkloadTemplate.WorkerTemplate.Spec.InitContainers[0].Name = name
return w
}

// RolloutStrategy sets the service's rollout strategy type and its rolling
// update bounds (maxUnavailable / maxSurge), creating the strategy struct on
// first use. It returns the wrapper to allow call chaining.
func (w *ServiceWrapper) RolloutStrategy(typ string, maxUnavailable int, maxSurge int) *ServiceWrapper {
	strategy := w.Spec.RolloutStrategy
	if strategy == nil {
		strategy = &lws.RolloutStrategy{}
		w.Spec.RolloutStrategy = strategy
	}
	strategy.Type = lws.RolloutStrategyType(typ)
	// Both bounds are always (re)set together as plain integer values.
	cfg := &lws.RollingUpdateConfiguration{}
	cfg.MaxUnavailable = intstr.FromInt(maxUnavailable)
	cfg.MaxSurge = intstr.FromInt(maxSurge)
	strategy.RollingUpdateConfiguration = cfg
	return w
}

// Size sets the workload template's group size (pods per replica group)
// and returns the wrapper for chaining.
func (w *ServiceWrapper) Size(size int32) *ServiceWrapper {
	// Taking the address of the parameter copies it to the heap,
	// equivalent to ptr.To[int32](size).
	w.Spec.WorkloadTemplate.Size = &size
	return w
}

// Replicas sets the desired number of service replicas and returns the
// wrapper for chaining.
func (w *ServiceWrapper) Replicas(replicas int32) *ServiceWrapper {
	// Taking the address of the parameter copies it to the heap,
	// equivalent to ptr.To[int32](replicas).
	w.Spec.Replicas = &replicas
	return w
}

// RestartPolicy sets the workload template's restart policy (e.g.
// "RecreateGroupOnPodRestart") and returns the wrapper for chaining.
func (w *ServiceWrapper) RestartPolicy(policy string) *ServiceWrapper {
	rp := lws.RestartPolicyType(policy)
	w.Spec.WorkloadTemplate.RestartPolicy = rp
	return w
}
Loading