diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go index ec16682c..199510a9 100644 --- a/api/core/v1alpha1/model_types.go +++ b/api/core/v1alpha1/model_types.go @@ -181,7 +181,7 @@ type ModelStatus struct { //+genclient //+kubebuilder:object:root=true //+kubebuilder:subresource:status -//+kubebuilder:resource:scope=Cluster +//+kubebuilder:resource:shortName=om,scope=Cluster // OpenModel is the Schema for the open models API type OpenModel struct { diff --git a/api/inference/v1alpha1/backendruntime_types.go b/api/inference/v1alpha1/backendruntime_types.go index e588fa3c..db1bb813 100644 --- a/api/inference/v1alpha1/backendruntime_types.go +++ b/api/inference/v1alpha1/backendruntime_types.go @@ -24,8 +24,8 @@ import ( type InferenceMode string const ( - DefaultInferenceMode InferenceMode = "default" - SpeculativeDecodingInferenceMode InferenceMode = "speculative-decoding" + DefaultInferenceMode InferenceMode = "Default" + SpeculativeDecodingInferenceMode InferenceMode = "SpeculativeDecoding" ) type BackendRuntimeArg struct { @@ -47,6 +47,7 @@ type BackendRuntimeSpec struct { // They can be appended or overwritten by the Playground args. // The key is the inference option, like default one or advanced // speculativeDecoding, the values are the corresponding args. + // A flag wrapped with {{ .XXX }} is a placeholder waiting to be rendered. Args []BackendRuntimeArg `json:"args,omitempty"` // Envs represents the environments set to the container. // +optional @@ -65,7 +66,7 @@ type BackendRuntimeStatus struct { //+kubebuilder:object:root=true //+kubebuilder:subresource:status -//+kubebuilder:resource:scope=Cluster +//+kubebuilder:resource:shortName=br,scope=Cluster // BackendRuntime is the Schema for the backendRuntime API type BackendRuntime struct { diff --git a/api/inference/v1alpha1/config_types.go b/api/inference/v1alpha1/config_types.go index 51e00d52..62070986 100644 --- a/api/inference/v1alpha1/config_types.go +++ b/api/inference/v1alpha1/config_types.go @@ -30,7 +30,6 @@ const ( type BackendRuntimeConfig struct { // Name represents the inference backend under the hood, e.g. vLLM. - // +kubebuilder:validation:Enum={vllm,sglang,llamacpp} // +kubebuilder:default=vllm // +optional Name *BackendName `json:"name,omitempty"` diff --git a/api/inference/v1alpha1/playground_types.go b/api/inference/v1alpha1/playground_types.go index 1ad7aed4..4aa3707a 100644 --- a/api/inference/v1alpha1/playground_types.go +++ b/api/inference/v1alpha1/playground_types.go @@ -61,6 +61,7 @@ type PlaygroundStatus struct { //+genclient //+kubebuilder:object:root=true //+kubebuilder:subresource:status +//+kubebuilder:resource:shortName={pl} // Playground is the Schema for the playgrounds API type Playground struct { diff --git a/config/crd/bases/inference.llmaz.io_backendruntimes.yaml b/config/crd/bases/inference.llmaz.io_backendruntimes.yaml index 0c5e26ba..a9654fbd 100644 --- a/config/crd/bases/inference.llmaz.io_backendruntimes.yaml +++ b/config/crd/bases/inference.llmaz.io_backendruntimes.yaml @@ -11,6 +11,8 @@ spec: kind: BackendRuntime listKind: BackendRuntimeList plural: backendruntimes + shortNames: + - br singular: backendruntime scope: Cluster versions: @@ -45,6 +47,7 @@ spec: They can be appended or overwritten by the Playground args. The key is the inference option, like default one or advanced speculativeDecoding, the values are the corresponding args. + A flag wrapped with {{ .XXX }} is a placeholder waiting to be rendered.
items: properties: flags: diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml index ef071eb3..ee9d86e8 100644 --- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml +++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml @@ -11,6 +11,8 @@ spec: kind: Playground listKind: PlaygroundList plural: playgrounds + shortNames: + - pl singular: playground scope: Namespaced versions: @@ -179,10 +181,6 @@ spec: default: vllm description: Name represents the inference backend under the hood, e.g. vLLM. - enum: - - vllm - - sglang - - llamacpp type: string resources: description: |- diff --git a/config/crd/bases/llmaz.io_openmodels.yaml b/config/crd/bases/llmaz.io_openmodels.yaml index 7b3f0734..27013365 100644 --- a/config/crd/bases/llmaz.io_openmodels.yaml +++ b/config/crd/bases/llmaz.io_openmodels.yaml @@ -11,6 +11,8 @@ spec: kind: OpenModel listKind: OpenModelList plural: openmodels + shortNames: + - om singular: openmodel scope: Cluster versions: diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 52830493..0a67a841 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -34,7 +34,7 @@ rules: - apiGroups: - inference.llmaz.io resources: - - backends + - backendruntimes verbs: - create - delete @@ -46,13 +46,13 @@ rules: - apiGroups: - inference.llmaz.io resources: - - backends/finalizers + - backendruntimes/finalizers verbs: - update - apiGroups: - inference.llmaz.io resources: - - backends/status + - backendruntimes/status verbs: - get - patch diff --git a/docs/assets/arch.png b/docs/assets/arch.png index a1ffdd49..15ae1211 100644 Binary files a/docs/assets/arch.png and b/docs/assets/arch.png differ diff --git a/docs/examples/llamacpp/playground.yaml b/docs/examples/llamacpp/playground.yaml index bcd0e287..bf62f9a4 100644 --- a/docs/examples/llamacpp/playground.yaml +++ b/docs/examples/llamacpp/playground.yaml @@ -6,7 +6,7 @@ spec: replicas: 1 modelClaim: modelName: qwen2-0--5b-gguf - backendConfig: + backendRuntimeConfig: name: llamacpp args: - -fa # use flash attention diff --git a/docs/examples/sglang/playground.yaml b/docs/examples/sglang/playground.yaml index 4b0b6810..8bb8601c 100644 --- a/docs/examples/sglang/playground.yaml +++ b/docs/examples/sglang/playground.yaml @@ -6,5 +6,5 @@ spec: replicas: 1 modelClaim: modelName: qwen2-05b - backendConfig: + backendRuntimeConfig: name: sglang diff --git a/docs/examples/speculative-decoding/llamacpp/playground.yaml b/docs/examples/speculative-decoding/llamacpp/playground.yaml index e237503e..405b3577 100644 --- a/docs/examples/speculative-decoding/llamacpp/playground.yaml +++ b/docs/examples/speculative-decoding/llamacpp/playground.yaml @@ -13,7 +13,7 @@ spec: role: main - name: llama2-7b-q2-k-gguf # the draft model role: draft - backendConfig: + backendRuntimeConfig: name: llamacpp args: - -fa # use flash attention diff --git a/docs/examples/speculative-decoding/vllm/playground.yaml b/docs/examples/speculative-decoding/vllm/playground.yaml index 0be6c615..922e6423 100644 --- a/docs/examples/speculative-decoding/vllm/playground.yaml +++ b/docs/examples/speculative-decoding/vllm/playground.yaml @@ -10,13 +10,7 @@ spec: role: main - name: opt-125m # the draft model role: draft - backendConfig: - args: - - --use-v2-block-manager - - --num_speculative_tokens - - "5" - - -tp - - "1" + backendRuntimeConfig: resources: limits: cpu: 8 diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go index 
597eecbf..1956f851 100644 --- a/pkg/controller/inference/playground_controller.go +++ b/pkg/controller/inference/playground_controller.go @@ -44,7 +44,7 @@ import ( inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" coreclientgo "github.com/inftyai/llmaz/client-go/applyconfiguration/core/v1alpha1" inferenceclientgo "github.com/inftyai/llmaz/client-go/applyconfiguration/inference/v1alpha1" - "github.com/inftyai/llmaz/pkg/controller_helper/backend" + helper "github.com/inftyai/llmaz/pkg/controller_helper" modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" "github.com/inftyai/llmaz/pkg/util" ) @@ -94,32 +94,27 @@ func (r *PlaygroundReconciler) Reconcile(ctx context.Context, req ctrl.Request) } } - var serviceApplyConfiguration *inferenceclientgo.ServiceApplyConfiguration - - models := []*coreapi.OpenModel{} - if playground.Spec.ModelClaim != nil { - model := &coreapi.OpenModel{} - if err := r.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaim.ModelName)}, model); err != nil { - if apierrors.IsNotFound(err) && handleUnexpectedCondition(playground, false, false) { - return ctrl.Result{}, r.Client.Status().Update(ctx, playground) - } - return ctrl.Result{}, err - } - models = append(models, model) - } else if playground.Spec.ModelClaims != nil { - for _, mr := range playground.Spec.ModelClaims.Models { - model := &coreapi.OpenModel{} - if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { - if apierrors.IsNotFound(err) && handleUnexpectedCondition(playground, false, false) { - return ctrl.Result{}, r.Client.Status().Update(ctx, playground) - } - return ctrl.Result{}, err - } - models = append(models, model) + models, err := helper.FetchModelsByPlayground(ctx, r.Client, playground) + if err != nil { + if apierrors.IsNotFound(err) && handleUnexpectedCondition(playground, false, false) { + return ctrl.Result{}, r.Client.Status().Update(ctx, playground) } + return ctrl.Result{}, err } - serviceApplyConfiguration = buildServiceApplyConfiguration(models, playground) + backendRuntimeName := inferenceapi.VLLM + if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Name != nil { + backendRuntimeName = *playground.Spec.BackendRuntimeConfig.Name + } + backendRuntime := &inferenceapi.BackendRuntime{} + if err := r.Get(ctx, types.NamespacedName{Name: string(backendRuntimeName)}, backendRuntime); err != nil { + return ctrl.Result{}, err + } + + serviceApplyConfiguration, err := buildServiceApplyConfiguration(models, playground, backendRuntime) + if err != nil { + return ctrl.Result{}, err + } if err := setControllerReferenceForService(playground, serviceApplyConfiguration, r.Scheme); err != nil { return ctrl.Result{}, err @@ -185,19 +180,19 @@ func (r *PlaygroundReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } -func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inferenceapi.Playground) *inferenceclientgo.ServiceApplyConfiguration { +func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inferenceapi.Playground, backendRuntime *inferenceapi.BackendRuntime) (*inferenceclientgo.ServiceApplyConfiguration, error) { // Build metadata serviceApplyConfiguration := inferenceclientgo.Service(playground.Name, playground.Namespace) // Build spec. 
spec := inferenceclientgo.ServiceSpec() - claim := &coreclientgo.ModelClaimsApplyConfiguration{} + var claim *coreclientgo.ModelClaimsApplyConfiguration if playground.Spec.ModelClaim != nil { claim = coreclientgo.ModelClaims(). WithModels(coreclientgo.ModelRepresentative().WithName(playground.Spec.ModelClaim.ModelName).WithRole(coreapi.MainRole)). WithInferenceFlavors(playground.Spec.ModelClaim.InferenceFlavors...) - } else if playground.Spec.ModelClaims != nil { + } else { mrs := []*coreclientgo.ModelRepresentativeApplyConfiguration{} for _, model := range playground.Spec.ModelClaims.Models { role := coreapi.MainRole @@ -214,10 +209,15 @@ func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inf } spec.WithModelClaims(claim) - spec.WithWorkloadTemplate(buildWorkloadTemplate(models, playground)) + template, err := buildWorkloadTemplate(models, playground, backendRuntime) + if err != nil { + return nil, err + } + + spec.WithWorkloadTemplate(template) serviceApplyConfiguration.WithSpec(spec) - return serviceApplyConfiguration + return serviceApplyConfiguration, nil // TODO: handle MultiModelsClaims in the future. } @@ -226,7 +226,7 @@ func buildServiceApplyConfiguration(models []*coreapi.OpenModel, playground *inf // to cover both single-host and multi-host cases. There're some shortages for lws like can not force rolling // update when one replica failed, we'll fix this in the kubernetes upstream. // Model flavors will not be considered but in inferenceService controller to support accelerator fungibility. -func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) lws.LeaderWorkerSetSpec { +func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground, backendRuntime *inferenceapi.BackendRuntime) (lws.LeaderWorkerSetSpec, error) { // TODO: this should be leaderWorkerSetTemplateSpec, we should support in the lws upstream. workload := lws.LeaderWorkerSetSpec{ // Use the default policy defined in lws. @@ -240,52 +240,36 @@ func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi // TODO: handle multi-host scenarios, e.g. nvidia.com/gpu: 32, means we'll split into 4 hosts. // Do we need another configuration for playground for multi-host use case? I guess no currently. 
- workload.LeaderWorkerTemplate.WorkerTemplate = buildWorkerTemplate(models, playground) - - return workload -} - -func involveRole(playground *inferenceapi.Playground) coreapi.ModelRole { - if playground.Spec.ModelClaim != nil { - return coreapi.MainRole - } else if playground.Spec.ModelClaims != nil { - for _, mr := range playground.Spec.ModelClaims.Models { - if *mr.Role != coreapi.MainRole { - return *mr.Role - } - } + template, err := buildWorkerTemplate(models, playground, backendRuntime) + if err != nil { + return lws.LeaderWorkerSetSpec{}, err } + workload.LeaderWorkerTemplate.WorkerTemplate = template - return coreapi.MainRole + return workload, nil } -func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground) corev1.PodTemplateSpec { - backendName := inferenceapi.DefaultBackend - if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Name != nil { - backendName = *playground.Spec.BackendRuntimeConfig.Name - } - bkd := backend.SwitchBackend(backendName) +func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground, backendRuntime *inferenceapi.BackendRuntime) (corev1.PodTemplateSpec, error) { + parser := helper.NewBackendRuntimeParser(backendRuntime) - version := bkd.DefaultVersion() - if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Version != nil { - version = *playground.Spec.BackendRuntimeConfig.Version + args, err := parser.Args(helper.InferenceMode(playground), models) + if err != nil { + return corev1.PodTemplateSpec{}, err } + envs := parser.Envs() - args := bkd.Args(models, involveRole(playground)) - - var envs []corev1.EnvVar if playground.Spec.BackendRuntimeConfig != nil { args = append(args, playground.Spec.BackendRuntimeConfig.Args...) - envs = playground.Spec.BackendRuntimeConfig.Envs + envs = append(envs, playground.Spec.BackendRuntimeConfig.Envs...) } resources := corev1.ResourceRequirements{ - Limits: bkd.DefaultResources().Limits, - Requests: bkd.DefaultResources().Requests, + Requests: parser.Resources().Requests, + Limits: parser.Resources().Limits, } if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Resources != nil { - limits := util.MergeResources(playground.Spec.BackendRuntimeConfig.Resources.Limits, resources.Limits) - requests := util.MergeResources(playground.Spec.BackendRuntimeConfig.Resources.Requests, resources.Requests) + limits := util.MergeResources(playground.Spec.BackendRuntimeConfig.Resources.Limits, parser.Resources().Limits) + requests := util.MergeResources(playground.Spec.BackendRuntimeConfig.Resources.Requests, parser.Resources().Requests) resources = corev1.ResourceRequirements{ Limits: limits, @@ -302,6 +286,11 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P } } + version := parser.Version() + if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Version != nil { + version = *playground.Spec.BackendRuntimeConfig.Version + } + template := corev1.PodTemplateSpec{ Spec: corev1.PodSpec{ // TODO: should we support image pull secret here? 
@@ -309,9 +298,9 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P Containers: []corev1.Container{ { Name: modelSource.MODEL_RUNNER_CONTAINER_NAME, - Image: bkd.Image(version), + Image: parser.Image(version), Resources: resources, - Command: bkd.DefaultCommand(), + Command: parser.Commands(), Args: args, Env: envs, Ports: []corev1.ContainerPort{ @@ -326,7 +315,7 @@ func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.P }, } - return template + return template, nil } func handleUnexpectedCondition(playground *inferenceapi.Playground, modelExists bool, serviceWithSameNameExists bool) (changed bool) { diff --git a/pkg/controller/inference/service_controller.go b/pkg/controller/inference/service_controller.go index a8575929..3a55ba97 100644 --- a/pkg/controller/inference/service_controller.go +++ b/pkg/controller/inference/service_controller.go @@ -43,6 +43,7 @@ import ( coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" + helper "github.com/inftyai/llmaz/pkg/controller_helper" modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" "github.com/inftyai/llmaz/pkg/util" ) @@ -80,20 +81,9 @@ func (r *ServiceReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ct logger.V(10).Info("reconcile Service", "Playground", klog.KObj(service)) - models := []*coreapi.OpenModel{} - for _, mr := range service.Spec.ModelClaims.Models { - model := &coreapi.OpenModel{} - if err := r.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { - return ctrl.Result{}, err - } - // Make sure the main model is always the 0-index model. - // We only have one main model right now, if this changes, - // the logic may also change here. - if *mr.Role == coreapi.MainRole { - models = append([]*coreapi.OpenModel{model}, models...) - } else { - models = append(models, model) - } + models, err := helper.FetchModelsByService(ctx, r.Client, service) + if err != nil { + return ctrl.Result{}, err } workloadApplyConfiguration := buildWorkloadApplyConfiguration(service, models) diff --git a/pkg/controller_helper/backend/backend.go b/pkg/controller_helper/backend/backend.go deleted file mode 100644 index 249ec9ca..00000000 --- a/pkg/controller_helper/backend/backend.go +++ /dev/null @@ -1,64 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package backend - -import ( - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" -) - -const ( - DEFAULT_BACKEND_PORT = 8080 -) - -// Backend represents the inference engine, such as vllm. -type Backend interface { - // Name returns the inference backend name in this project. - Name() inferenceapi.BackendName - // Image returns the container image for the inference backend. 
- Image(version string) string - - // DefaultVersion returns the default version for the inference backend. - DefaultVersion() string - // DefaultResources returns the default resources set for the container. - DefaultResources() inferenceapi.ResourceRequirements - // DefaultCommand returns the command to start the inference backend. - DefaultCommand() []string - // Args returns the bootstrap arguments to start the backend. - // The second parameter represents which particular modelRole involved, like draft. - Args([]*coreapi.OpenModel, coreapi.ModelRole) []string -} - -// SpeculativeBackend represents backend supports speculativeDecoding inferenceMode. -type SpeculativeBackend interface { - // speculativeArgs returns the bootstrap arguments when inferenceMode is speculativeDecoding. - speculativeArgs([]*coreapi.OpenModel) []string -} - -func SwitchBackend(name inferenceapi.BackendName) Backend { - switch name { - case inferenceapi.VLLM: - return &VLLM{} - case inferenceapi.SGLANG: - return &SGLANG{} - case inferenceapi.LLAMACPP: - return &LLAMACPP{} - default: - // We should not reach here because apiserver already did validation. - return nil - } -} diff --git a/pkg/controller_helper/backend/backend_test.go b/pkg/controller_helper/backend/backend_test.go deleted file mode 100644 index c516167b..00000000 --- a/pkg/controller_helper/backend/backend_test.go +++ /dev/null @@ -1,71 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package backend - -import ( - "testing" - - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" -) - -func TestSwitchBackend(t *testing.T) { - testCases := []struct { - name string - backendName inferenceapi.BackendName - expectedBackendName inferenceapi.BackendName - shouldErr bool - }{ - { - name: "vllm should support", - backendName: "vllm", - expectedBackendName: inferenceapi.VLLM, - shouldErr: false, - }, - { - name: "sglang should support", - backendName: "sglang", - expectedBackendName: inferenceapi.SGLANG, - shouldErr: false, - }, - { - name: "llamacpp should support", - backendName: "llamacpp", - expectedBackendName: inferenceapi.LLAMACPP, - shouldErr: false, - }, - { - name: "tgi should not support", - backendName: "tgi", - expectedBackendName: inferenceapi.BackendName(""), - shouldErr: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - backend := SwitchBackend(tc.backendName) - - if !tc.shouldErr && backend == nil { - t.Fatal("unexpected error") - } - - if !tc.shouldErr && backend.Name() != tc.expectedBackendName { - t.Fatalf("unexpected backend, want %s, got %s", tc.expectedBackendName, backend.Name()) - } - }) - } -} diff --git a/pkg/controller_helper/backend/llamacpp.go b/pkg/controller_helper/backend/llamacpp.go deleted file mode 100644 index cc2de38d..00000000 --- a/pkg/controller_helper/backend/llamacpp.go +++ /dev/null @@ -1,85 +0,0 @@ -/* -Copyright 2024. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package backend - -import ( - "strconv" - - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" - modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" -) - -var _ Backend = (*LLAMACPP)(nil) - -type LLAMACPP struct{} - -const ( - llama_cpp_image_registry = "ghcr.io/ggerganov/llama.cpp" -) - -func (l *LLAMACPP) Name() inferenceapi.BackendName { - return inferenceapi.LLAMACPP -} - -func (l *LLAMACPP) Image(version string) string { - return llama_cpp_image_registry + ":" + version -} - -func (l *LLAMACPP) DefaultVersion() string { - return "server" -} - -func (l *LLAMACPP) DefaultResources() inferenceapi.ResourceRequirements { - return inferenceapi.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("4Gi"), - }, - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("2"), - corev1.ResourceMemory: resource.MustParse("4Gi"), - }, - } -} - -func (l *LLAMACPP) DefaultCommand() []string { - return []string{"./llama-server"} -} - -func (l *LLAMACPP) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { - targetModelSource := modelSource.NewModelSourceProvider(models[0]) - - if involvedRole == coreapi.DraftRole { - draftModelSource := modelSource.NewModelSourceProvider(models[1]) - return []string{ - "-m", targetModelSource.ModelPath(), - "-md", draftModelSource.ModelPath(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - } - } - - return []string{ - "-m", targetModelSource.ModelPath(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - } -} diff --git a/pkg/controller_helper/backend/llamacpp_test.go b/pkg/controller_helper/backend/llamacpp_test.go deleted file mode 100644 index b7402597..00000000 --- a/pkg/controller_helper/backend/llamacpp_test.go +++ /dev/null @@ -1,98 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func Test_llamacpp(t *testing.T) { - backend := LLAMACPP{} - models := []*coreapi.OpenModel{ - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-1", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-1"), - ModelID: "hub/model-1", - }, - }, - }, - }, - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-2", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-2"), - ModelID: "hub/model-2", - }, - }, - }, - }, - } - - testCases := []struct { - name string - involvedRole coreapi.ModelRole - wantCommand []string - wantArgs []string - }{ - { - name: "one main model", - involvedRole: coreapi.MainRole, - wantCommand: []string{"./llama-server"}, - wantArgs: []string{ - "-m", "/workspace/models/models--hub--model-1", - "--host", "0.0.0.0", - "--port", "8080", - }, - }, - { - name: "speculative decoding", - involvedRole: coreapi.DraftRole, - wantCommand: []string{"./llama-server"}, - wantArgs: []string{ - "-m", "/workspace/models/models--hub--model-1", - "-md", "/workspace/models/models--hub--model-2", - "--host", "0.0.0.0", - "--port", "8080", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) - } - if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) - } - }) - } -} diff --git a/pkg/controller_helper/backend/sglang.go b/pkg/controller_helper/backend/sglang.go deleted file mode 100644 index 12a7307b..00000000 --- a/pkg/controller_helper/backend/sglang.go +++ /dev/null @@ -1,81 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "strconv" - - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" - modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" -) - -var _ Backend = (*SGLANG)(nil) - -type SGLANG struct{} - -const ( - sglang_image_registry = "lmsysorg/sglang" -) - -func (s *SGLANG) Name() inferenceapi.BackendName { - return inferenceapi.SGLANG -} - -func (s *SGLANG) Image(version string) string { - return sglang_image_registry + ":" + version -} - -func (s *SGLANG) DefaultVersion() string { - return "v0.2.10-cu121" -} - -func (s *SGLANG) DefaultResources() inferenceapi.ResourceRequirements { - return inferenceapi.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("4"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - }, - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("4"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - }, - } -} - -func (s *SGLANG) DefaultCommand() []string { - return []string{"python3", "-m", "sglang.launch_server"} -} - -func (s *SGLANG) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { - targetModelSource := modelSource.NewModelSourceProvider(models[0]) - - if involvedRole == coreapi.DraftRole { - // TODO: support speculative decoding - return nil - } - - return []string{ - "--model-path", targetModelSource.ModelPath(), - "--served-model-name", targetModelSource.ModelName(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - } -} diff --git a/pkg/controller_helper/backend/sglang_test.go b/pkg/controller_helper/backend/sglang_test.go deleted file mode 100644 index ce1bb50a..00000000 --- a/pkg/controller_helper/backend/sglang_test.go +++ /dev/null @@ -1,94 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func Test_SGLANG(t *testing.T) { - backend := SGLANG{} - models := []*coreapi.OpenModel{ - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-1", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-1"), - ModelID: "hub/model-1", - }, - }, - }, - }, - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-2", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-2"), - ModelID: "hub/model-2", - }, - }, - }, - }, - } - - testCases := []struct { - name string - involvedRole coreapi.ModelRole - wantCommand []string - wantArgs []string - }{ - { - name: "one main model", - involvedRole: coreapi.MainRole, - wantCommand: []string{"python3", "-m", "sglang.launch_server"}, - wantArgs: []string{ - "--model-path", "/workspace/models/models--hub--model-1", - "--served-model-name", "model-1", - "--host", "0.0.0.0", - "--port", "8080", - }, - }, - { - name: "speculative decoding", - involvedRole: coreapi.DraftRole, - wantCommand: []string{"python3", "-m", "sglang.launch_server"}, - wantArgs: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) - } - if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) - } - }) - } -} diff --git a/pkg/controller_helper/backend/vllm.go b/pkg/controller_helper/backend/vllm.go deleted file mode 100644 index 467bfbc4..00000000 --- a/pkg/controller_helper/backend/vllm.go +++ /dev/null @@ -1,87 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "strconv" - - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/resource" - - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" - modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" -) - -var _ Backend = (*VLLM)(nil) - -type VLLM struct{} - -const ( - vllm_image_registry = "vllm/vllm-openai" -) - -func (v *VLLM) Name() inferenceapi.BackendName { - return inferenceapi.VLLM -} - -func (v *VLLM) Image(version string) string { - return vllm_image_registry + ":" + version -} - -func (v *VLLM) DefaultVersion() string { - return "v0.5.1" -} - -func (v *VLLM) DefaultResources() inferenceapi.ResourceRequirements { - return inferenceapi.ResourceRequirements{ - Limits: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("4"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - }, - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("4"), - corev1.ResourceMemory: resource.MustParse("8Gi"), - }, - } -} - -func (v *VLLM) DefaultCommand() []string { - return []string{"python3", "-m", "vllm.entrypoints.openai.api_server"} -} - -func (v *VLLM) Args(models []*coreapi.OpenModel, involvedRole coreapi.ModelRole) []string { - targetModelSource := modelSource.NewModelSourceProvider(models[0]) - - if involvedRole == coreapi.DraftRole { - draftModelSource := modelSource.NewModelSourceProvider(models[1]) - return []string{ - "--model", targetModelSource.ModelPath(), - "--speculative_model", draftModelSource.ModelPath(), - "--served-model-name", targetModelSource.ModelName(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - } - } - - return []string{ - "--model", targetModelSource.ModelPath(), - "--served-model-name", targetModelSource.ModelName(), - "--host", "0.0.0.0", - "--port", strconv.Itoa(DEFAULT_BACKEND_PORT), - } -} diff --git a/pkg/controller_helper/backend/vllm_test.go b/pkg/controller_helper/backend/vllm_test.go deleted file mode 100644 index d75fe4e1..00000000 --- a/pkg/controller_helper/backend/vllm_test.go +++ /dev/null @@ -1,100 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package backend - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/utils/ptr" -) - -func Test_vllm(t *testing.T) { - backend := VLLM{} - models := []*coreapi.OpenModel{ - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-1", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-1"), - ModelID: "hub/model-1", - }, - }, - }, - }, - { - ObjectMeta: v1.ObjectMeta{ - Name: "model-2", - }, - Spec: coreapi.ModelSpec{ - Source: coreapi.ModelSource{ - ModelHub: &coreapi.ModelHub{ - Name: ptr.To[string]("model-2"), - ModelID: "hub/model-2", - }, - }, - }, - }, - } - - testCases := []struct { - name string - involvedRole coreapi.ModelRole - wantCommand []string - wantArgs []string - }{ - { - name: "one main model", - involvedRole: coreapi.MainRole, - wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, - wantArgs: []string{ - "--model", "/workspace/models/models--hub--model-1", - "--served-model-name", "model-1", - "--host", "0.0.0.0", - "--port", "8080", - }, - }, - { - name: "speculative decoding", - involvedRole: coreapi.DraftRole, - wantCommand: []string{"python3", "-m", "vllm.entrypoints.openai.api_server"}, - wantArgs: []string{ - "--model", "/workspace/models/models--hub--model-1", - "--speculative_model", "/workspace/models/models--hub--model-2", - "--served-model-name", "model-1", - "--host", "0.0.0.0", - "--port", "8080", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - if diff := cmp.Diff(backend.DefaultCommand(), tc.wantCommand); diff != "" { - t.Fatalf("unexpected command, want %v, got %v", tc.wantCommand, backend.DefaultCommand()) - } - if diff := cmp.Diff(backend.Args(models, tc.involvedRole), tc.wantArgs); diff != "" { - t.Fatalf("unexpected args, want %v, got %v", tc.wantArgs, backend.Args(models, tc.involvedRole)) - } - }) - } -} diff --git a/pkg/controller_helper/backendruntime.go b/pkg/controller_helper/backendruntime.go new file mode 100644 index 00000000..61f6463f --- /dev/null +++ b/pkg/controller_helper/backendruntime.go @@ -0,0 +1,115 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package helper + +import ( + "fmt" + "regexp" + + corev1 "k8s.io/api/core/v1" + + coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" + inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" + modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" +) + +// TODO: add unit tests. 
+type BackendRuntimeParser struct { + backendRuntime *inferenceapi.BackendRuntime +} + +func NewBackendRuntimeParser(backendRuntime *inferenceapi.BackendRuntime) *BackendRuntimeParser { + return &BackendRuntimeParser{backendRuntime} +} + +func (p *BackendRuntimeParser) Commands() []string { + return p.backendRuntime.Spec.Commands +} + +func (p *BackendRuntimeParser) Envs() []corev1.EnvVar { + return p.backendRuntime.Spec.Envs +} + +func (p *BackendRuntimeParser) Args(mode inferenceapi.InferenceMode, models []*coreapi.OpenModel) ([]string, error) { + if mode == inferenceapi.SpeculativeDecodingInferenceMode && len(models) != 2 { + return nil, fmt.Errorf("models number not right, want 2, got %d", len(models)) + } + + modelInfo := map[string]string{} + + if mode == inferenceapi.DefaultInferenceMode { + source := modelSource.NewModelSourceProvider(models[0]) + modelInfo = map[string]string{ + "ModelPath": source.ModelPath(), + "ModelName": source.ModelName(), + } + } + + if mode == inferenceapi.SpeculativeDecodingInferenceMode { + targetSource := modelSource.NewModelSourceProvider(models[0]) + draftSource := modelSource.NewModelSourceProvider(models[1]) + modelInfo = map[string]string{ + "ModelPath": targetSource.ModelPath(), + "ModelName": targetSource.ModelName(), + "DraftModelPath": draftSource.ModelPath(), + } + } + + for _, arg := range p.backendRuntime.Spec.Args { + if arg.Mode == mode { + return renderFlags(arg.Flags, modelInfo) + } + } + // We should not reach here. + return nil, fmt.Errorf("backendRuntime %s not supported", p.backendRuntime.Name) +} + +func (p *BackendRuntimeParser) Image(version string) string { + return p.backendRuntime.Spec.Image + ":" + version +} + +func (p *BackendRuntimeParser) Version() string { + return p.backendRuntime.Spec.Version +} + +func (p *BackendRuntimeParser) Resources() inferenceapi.ResourceRequirements { + return p.backendRuntime.Spec.Resources +} + +func renderFlags(flags []string, modelInfo map[string]string) ([]string, error) { + // Capture the word. + re := regexp.MustCompile(`\{\{\s*\.(\w+)\s*\}\}`) + res := []string{} + var value string + + for _, flag := range flags { + value = flag + match := re.FindStringSubmatch(flag) + if len(match) > 1 { + // Return the matched word. + value = modelInfo[match[1]] + + if value == "" { + return nil, fmt.Errorf("missing flag or the flag has format error: %s", flag) + } + } + + res = append(res, value) + } + + return res, nil +} diff --git a/pkg/controller_helper/backendruntime_test.go b/pkg/controller_helper/backendruntime_test.go new file mode 100644 index 00000000..ed311a6b --- /dev/null +++ b/pkg/controller_helper/backendruntime_test.go @@ -0,0 +1,82 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package helper + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestRenderFlags(t *testing.T) { + testCases := []struct { + name string + flags []string + modelInfo map[string]string + wantFlags []string + wantError bool + }{ + { + name: "normal parse", + flags: []string{"-m", "{{ .ModelPath }}", "--served-model-name", "{{ .ModelName }}", "--host", "0.0.0.0"}, + modelInfo: map[string]string{ + "ModelPath": "path/to/model", + "ModelName": "foo", + }, + wantFlags: []string{"-m", "path/to/model", "--served-model-name", "foo", "--host", "0.0.0.0"}, + }, + { + name: "miss some info", + flags: []string{"-m", "{{ .ModelPath }}", "--served-model-name", "{{ .ModelName }}", "--host", "0.0.0.0"}, + modelInfo: map[string]string{ + "ModelPath": "path/to/model", + }, + wantError: true, + }, + { + name: "missing . with flag", + flags: []string{"-m", "{{ ModelPath }}", "--served-model-name", "{{ .ModelName }}", "--host", "0.0.0.0"}, + modelInfo: map[string]string{ + "ModelPath": "path/to/model", + "ModelName": "foo", + }, + wantFlags: []string{"-m", "{{ ModelPath }}", "--served-model-name", "foo", "--host", "0.0.0.0"}, + }, + { + name: "no empty space between {{}}", + flags: []string{"-m", "{{.ModelPath}}", "--served-model-name", "{{.ModelName}}", "--host", "0.0.0.0"}, + modelInfo: map[string]string{ + "ModelPath": "path/to/model", + "ModelName": "foo", + }, + wantFlags: []string{"-m", "path/to/model", "--served-model-name", "foo", "--host", "0.0.0.0"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotFlags, err := renderFlags(tc.flags, tc.modelInfo) + if tc.wantError && err == nil { + t.Fatal("test should fail") + } + + if !tc.wantError && cmp.Diff(tc.wantFlags, gotFlags) != "" { + t.Fatalf("want flags %v, got flags %v", tc.wantFlags, gotFlags) + } + }) + } +} diff --git a/pkg/controller_helper/helper.go b/pkg/controller_helper/helper.go new file mode 100644 index 00000000..5e353c0a --- /dev/null +++ b/pkg/controller_helper/helper.go @@ -0,0 +1,79 @@ +/* +Copyright 2024. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package helper + +import ( + "context" + + coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" + inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func InferenceMode(playground *inferenceapi.Playground) inferenceapi.InferenceMode { + if playground.Spec.ModelClaim != nil { + return inferenceapi.DefaultInferenceMode + } + + if playground.Spec.ModelClaims != nil { + for _, mr := range playground.Spec.ModelClaims.Models { + if *mr.Role == coreapi.DraftRole { + return inferenceapi.SpeculativeDecodingInferenceMode + } + } + } + + // We should not reach here. 
+ return inferenceapi.DefaultInferenceMode +} + +func FetchModelsByService(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) (models []*coreapi.OpenModel, err error) { + return fetchModels(ctx, k8sClient, service.Spec.ModelClaims.Models) +} + +func FetchModelsByPlayground(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) (models []*coreapi.OpenModel, err error) { + mainRole := coreapi.MainRole + mrs := []coreapi.ModelRepresentative{} + + if playground.Spec.ModelClaim != nil { + mrs = append(mrs, coreapi.ModelRepresentative{Name: playground.Spec.ModelClaim.ModelName, Role: &mainRole}) + } else { + mrs = playground.Spec.ModelClaims.Models + } + + return fetchModels(ctx, k8sClient, mrs) +} + +func fetchModels(ctx context.Context, k8sClient client.Client, mrs []coreapi.ModelRepresentative) (models []*coreapi.OpenModel, err error) { + for _, mr := range mrs { + model := &coreapi.OpenModel{} + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(mr.Name)}, model); err != nil { + return nil, err + } + // Make sure the main model is always the 0-index model. + // We only have one main model right now, if this changes, + // the logic may also change here. + if *mr.Role == coreapi.MainRole { + models = append([]*coreapi.OpenModel{model}, models...) + } else { + models = append(models, model) + } + } + + return models, nil +} diff --git a/pkg/util/util.go b/pkg/util/util.go index 4a330f6c..896b2352 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -46,3 +46,13 @@ func MergeKVs(toMerge map[string]string, toBeMerged map[string]string) map[strin } return toMerge } + +// TODO: add unit tests. +func In(strings []string, s string) bool { + for _, str := range strings { + if str == s { + return true + } + } + return false +} diff --git a/pkg/webhook/backendruntime_webhook.go b/pkg/webhook/backendruntime_webhook.go index 32f86e63..6f353547 100644 --- a/pkg/webhook/backendruntime_webhook.go +++ b/pkg/webhook/backendruntime_webhook.go @@ -27,6 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook/admission" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" + "github.com/inftyai/llmaz/pkg/util" ) type BackendRuntimeWebhook struct{} @@ -70,6 +71,7 @@ func (w *BackendRuntimeWebhook) ValidateDelete(ctx context.Context, obj runtime. return nil, nil } +// TODO: the mode name should not be duplicated. func (w *BackendRuntimeWebhook) generateValidate(obj runtime.Object) field.ErrorList { backend := obj.(*inferenceapi.BackendRuntime) specPath := field.NewPath("spec") @@ -85,10 +87,18 @@ func (w *BackendRuntimeWebhook) generateValidate(obj runtime.Object) field.Error } } + modes := []string{} + for _, arg := range backend.Spec.Args { + if util.In(modes, string(arg.Mode)) { + allErrs = append(allErrs, field.Forbidden(specPath.Child("args", "mode"), fmt.Sprintf("duplicated mode %s", arg.Mode))) + } + // TODO: this may change in the future if the user wants to customize their flags for easy usage.
+ // See https://github.com/InftyAI/llmaz/issues/140 if !(arg.Mode == inferenceapi.DefaultInferenceMode || arg.Mode == inferenceapi.SpeculativeDecodingInferenceMode) { allErrs = append(allErrs, field.Forbidden(specPath.Child("args", "mode"), fmt.Sprintf("inferenceMode of %s is forbidden", arg.Mode))) } + modes = append(modes, string(arg.Mode)) } return allErrs } diff --git a/test/config/backends/llamacpp.yaml b/test/config/backends/llamacpp.yaml new file mode 100644 index 00000000..2360973b --- /dev/null +++ b/test/config/backends/llamacpp.yaml @@ -0,0 +1,39 @@ +apiVersion: inference.llmaz.io/v1alpha1 +kind: BackendRuntime +metadata: + labels: + app.kubernetes.io/name: backendruntime + app.kubernetes.io/part-of: llmaz + app.kubernetes.io/created-by: llmaz + name: llamacpp +spec: + commands: + - ./llama-server + image: ghcr.io/ggerganov/llama.cpp + version: server + args: + - mode: Default + flags: + - -m + - "{{ .ModelPath }}" + - --host + - "0.0.0.0" + - --port + - "8080" + - mode: SpeculativeDecoding + flags: + - -m + - "{{ .ModelPath }}" + - -md + - "{{ .DraftModelPath }}" + - --host + - "0.0.0.0" + - --port + - "8080" + resources: + requests: + cpu: 2 + memory: 4Gi + limits: + cpu: 2 + memory: 4Gi diff --git a/test/config/backends/sglang.yaml b/test/config/backends/sglang.yaml new file mode 100644 index 00000000..7716d952 --- /dev/null +++ b/test/config/backends/sglang.yaml @@ -0,0 +1,33 @@ +apiVersion: inference.llmaz.io/v1alpha1 +kind: BackendRuntime +metadata: + labels: + app.kubernetes.io/name: backendruntime + app.kubernetes.io/part-of: llmaz + app.kubernetes.io/created-by: llmaz + name: sglang +spec: + commands: + - python3 + - -m + - sglang.launch_server + image: lmsysorg/sglang + version: v0.2.10-cu121 + args: + - mode: Default + flags: + - --model-path + - "{{ .ModelPath }}" + - --served-model-name + - "{{ .ModelName }}" + - --host + - "0.0.0.0" + - --port + - "8080" + resources: + requests: + cpu: 4 + memory: 8Gi + limits: + cpu: 4 + memory: 8Gi diff --git a/test/config/backends/vllm.yaml b/test/config/backends/vllm.yaml new file mode 100644 index 00000000..14ca8b79 --- /dev/null +++ b/test/config/backends/vllm.yaml @@ -0,0 +1,50 @@ +apiVersion: inference.llmaz.io/v1alpha1 +kind: BackendRuntime +metadata: + labels: + app.kubernetes.io/name: backendruntime + app.kubernetes.io/part-of: llmaz + app.kubernetes.io/created-by: llmaz + name: vllm +spec: + commands: + - python3 + - -m + - vllm.entrypoints.openai.api_server + image: vllm/vllm-openai + version: v0.6.0 + args: + - mode: Default + flags: + - --model + - "{{ .ModelPath }}" + - --served-model-name + - "{{ .ModelName }}" + - --host + - "0.0.0.0" + - --port + - "8080" + - mode: SpeculativeDecoding + flags: + - --model + - "{{ .ModelPath }}" + - --served-model-name + - "{{ .ModelName }}" + - --speculative_model + - "{{ .DraftModelPath }}" + - --host + - "0.0.0.0" + - --port + - "8080" + - --use-v2-block-manager + - --num_speculative_tokens + - "5" + - -tp + - "1" + resources: + requests: + cpu: 4 + memory: 8Gi + limits: + cpu: 4 + memory: 8Gi diff --git a/test/e2e/playground_test.go b/test/e2e/playground_test.go index 3537b19d..344fe2d7 100644 --- a/test/e2e/playground_test.go +++ b/test/e2e/playground_test.go @@ -54,7 +54,32 @@ var _ = ginkgo.Describe("playground e2e tests", func() { gomega.Expect(k8sClient.Delete(ctx, model)).To(gomega.Succeed()) }() - playground := wrapper.MakePlayground("qwen2-0-5b-gguf", ns.Name).ModelClaim("qwen2-0-5b-gguf").Backend("llamacpp").Replicas(3).Obj() + playground := 
wrapper.MakePlayground("qwen2-0-5b-gguf", ns.Name).ModelClaim("qwen2-0-5b-gguf").BackendRuntime("llamacpp").Replicas(3).Obj() + gomega.Expect(k8sClient.Create(ctx, playground)).To(gomega.Succeed()) + validation.ValidatePlayground(ctx, k8sClient, playground) + validation.ValidatePlaygroundStatusEqualTo(ctx, k8sClient, playground, inferenceapi.PlaygroundAvailable, "PlaygroundReady", metav1.ConditionTrue) + + service := &inferenceapi.Service{} + gomega.Expect(k8sClient.Get(ctx, types.NamespacedName{Name: playground.Name, Namespace: playground.Namespace}, service)).To(gomega.Succeed()) + validation.ValidateService(ctx, k8sClient, service) + validation.ValidateServiceStatusEqualTo(ctx, k8sClient, service, inferenceapi.ServiceAvailable, "ServiceReady", metav1.ConditionTrue) + validation.ValidateServicePods(ctx, k8sClient, service) + }) + ginkgo.It("Deploy a huggingface model with customized backendRuntime", func() { + backendRuntime := wrapper.MakeBackendRuntime("llmaz-llamacpp"). + Image("ghcr.io/ggerganov/llama.cpp").Version("server"). + Command([]string{"./llama-server"}). + Arg("Default", []string{"-m", "{{.ModelPath}}", "--host", "0.0.0.0", "--port", "8080"}). + Request("cpu", "2").Request("memory", "4Gi").Limit("cpu", "4").Limit("memory", "4Gi").Obj() + gomega.Expect(k8sClient.Create(ctx, backendRuntime)).To(gomega.Succeed()) + + model := wrapper.MakeModel("qwen2-0-5b-gguf").FamilyName("qwen2").ModelSourceWithModelHub("Huggingface").ModelSourceWithModelID("Qwen/Qwen2-0.5B-Instruct-GGUF", "qwen2-0_5b-instruct-q5_k_m.gguf").Obj() + gomega.Expect(k8sClient.Create(ctx, model)).To(gomega.Succeed()) + defer func() { + gomega.Expect(k8sClient.Delete(ctx, model)).To(gomega.Succeed()) + }() + + playground := wrapper.MakePlayground("qwen2-0-5b-gguf", ns.Name).ModelClaim("qwen2-0-5b-gguf").BackendRuntime("llmaz-llamacpp").Replicas(1).Obj() gomega.Expect(k8sClient.Create(ctx, playground)).To(gomega.Succeed()) validation.ValidatePlayground(ctx, k8sClient, playground) validation.ValidatePlaygroundStatusEqualTo(ctx, k8sClient, playground, inferenceapi.PlaygroundAvailable, "PlaygroundReady", metav1.ConditionTrue) @@ -80,7 +105,7 @@ var _ = ginkgo.Describe("playground e2e tests", func() { // playground := wrapper.MakePlayground("llamacpp-speculator", ns.Name). // MultiModelsClaim([]string{"llama2-7b-q8-gguf", "llama2-7b-q2-k-gguf"}, coreapi.SpeculativeDecoding). - // Backend("llamacpp").BackendLimit("cpu", "4").BackendRequest("memory", "8Gi"). + // BackendRuntime("llamacpp").BackendLimit("cpu", "4").BackendRequest("memory", "8Gi"). // Replicas(1). 
// Obj() // gomega.Expect(k8sClient.Create(ctx, playground)).To(gomega.Succeed()) diff --git a/test/e2e/suit_test.go b/test/e2e/suit_test.go index ab2af7a3..ec869008 100644 --- a/test/e2e/suit_test.go +++ b/test/e2e/suit_test.go @@ -83,6 +83,8 @@ var _ = BeforeSuite(func() { readyForTesting(k8sClient) Expect(os.Setenv("TEST_TYPE", "E2E")).Should(Succeed()) + + Expect(util.Setup(ctx, k8sClient, "../config/backends")).To(Succeed()) }) var _ = AfterSuite(func() { diff --git a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go index e0130d41..a13b382d 100644 --- a/test/integration/controller/inference/playground_test.go +++ b/test/integration/controller/inference/playground_test.go @@ -157,7 +157,8 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with sglang", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - Backend("sglang").BackendVersion("main").BackendArgs([]string{"--foo", "bar"}).BackendEnv("FOO", "BAR").BackendRequest("cpu", "1").BackendLimit("cpu", "10"). + BackendRuntime("sglang").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). Obj() }, updates: []*update{ @@ -210,7 +211,8 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with llamacpp", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - Backend("llamacpp").BackendVersion("main").BackendArgs([]string{"--foo", "bar"}).BackendEnv("FOO", "BAR").BackendRequest("cpu", "1").BackendLimit("cpu", "10"). + BackendRuntime("llamacpp").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). Obj() }, updates: []*update{ @@ -324,10 +326,10 @@ var _ = ginkgo.Describe("playground controller test", func() { }, }, }), - ginkgo.Entry("Playground with backendConfig's resource requests greater than limits", &testValidatingCase{ + ginkgo.Entry("Playground with backendRuntimeConfig's resource requests greater than limits", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - BackendRequest("cpu", "10"). + BackendRuntimeRequest("cpu", "10"). Obj() }, updates: []*update{ diff --git a/test/integration/controller/inference/suit_test.go b/test/integration/controller/inference/suit_test.go index 7fe7afe1..97e2169e 100644 --- a/test/integration/controller/inference/suit_test.go +++ b/test/integration/controller/inference/suit_test.go @@ -42,6 +42,7 @@ import ( inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/pkg/controller" inferencecontroller "github.com/inftyai/llmaz/pkg/controller/inference" + "github.com/inftyai/llmaz/test/util" ) // These tests use Ginkgo (BDD-style Go testing framework). 
Refer to @@ -116,6 +117,8 @@ var _ = BeforeSuite(func() { serviceController := inferencecontroller.NewServiceReconciler(mgr.GetClient(), mgr.GetScheme(), mgr.GetEventRecorderFor("service")) Expect(serviceController.SetupWithManager(mgr)).NotTo(HaveOccurred()) + Expect(util.Setup(ctx, k8sClient, "../../../config/backends")).To(Succeed()) + go func() { defer GinkgoRecover() err = mgr.Start(ctx) diff --git a/test/integration/webhook/backendruntime_test.go b/test/integration/webhook/backendruntime_test.go index 8124d59f..a9f55d50 100644 --- a/test/integration/webhook/backendruntime_test.go +++ b/test/integration/webhook/backendruntime_test.go @@ -20,44 +20,61 @@ import ( "github.com/onsi/ginkgo/v2" "github.com/onsi/gomega" - inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" "github.com/inftyai/llmaz/test/util" ) var _ = ginkgo.Describe("BackendRuntime default and validation", func() { type testValidatingCase struct { - backendRuntime func() *inferenceapi.BackendRuntime - failed bool + creationFunc func() error + failed bool } ginkgo.DescribeTable("test validating", func(tc *testValidatingCase) { if tc.failed { - gomega.Expect(k8sClient.Create(ctx, tc.backendRuntime())).To(gomega.HaveOccurred()) + gomega.Expect(tc.creationFunc()).To(gomega.HaveOccurred()) } else { - gomega.Expect(k8sClient.Create(ctx, tc.backendRuntime())).To(gomega.Succeed()) + gomega.Expect(tc.creationFunc()).To(gomega.Succeed()) } }, ginkgo.Entry("normal BackendRuntime creation", &testValidatingCase{ - backendRuntime: func() *inferenceapi.BackendRuntime { - return util.MockASampleBackendRuntime().Obj() + creationFunc: func() error { + runtime := util.MockASampleBackendRuntime().Obj() + return k8sClient.Create(ctx, runtime) }, failed: false, }), ginkgo.Entry("BackendRuntime creation with no image", &testValidatingCase{ - backendRuntime: func() *inferenceapi.BackendRuntime { - return util.MockASampleBackendRuntime().Image("").Obj() + creationFunc: func() error { + runtime := util.MockASampleBackendRuntime().Image("").Obj() + return k8sClient.Create(ctx, runtime) }, failed: true, }), ginkgo.Entry("BackendRuntime creation with limits less than requests", &testValidatingCase{ - backendRuntime: func() *inferenceapi.BackendRuntime { - return util.MockASampleBackendRuntime().Limit("cpu", "1").Obj() + creationFunc: func() error { + runtime := util.MockASampleBackendRuntime().Limit("cpu", "1").Obj() + return k8sClient.Create(ctx, runtime) }, failed: true, }), - ginkgo.Entry("BackendRuntime creation with unsupported inferenceOption", &testValidatingCase{ - backendRuntime: func() *inferenceapi.BackendRuntime { - return util.MockASampleBackendRuntime().Arg("unknown", []string{"foo", "bar"}).Obj() + ginkgo.Entry("BackendRuntime creation with unsupported inferenceMode", &testValidatingCase{ + creationFunc: func() error { + runtime := util.MockASampleBackendRuntime().Arg("unknown", []string{"foo", "bar"}).Obj() + return k8sClient.Create(ctx, runtime) + }, + failed: true, + }), + ginkgo.Entry("BackendRuntime creation with duplicated inferenceMode", &testValidatingCase{ + creationFunc: func() error { + runtime := util.MockASampleBackendRuntime().Obj() + if err := k8sClient.Create(ctx, runtime); err != nil { + return err + } + anotherRuntime := util.MockASampleBackendRuntime().Name("another-vllm").Obj() + if err := k8sClient.Create(ctx, anotherRuntime); err != nil { + return err + } + return nil }, failed: true, }), diff --git 
a/test/integration/webhook/playground_test.go b/test/integration/webhook/playground_test.go index 4f704c59..8df0f8f7 100644 --- a/test/integration/webhook/playground_test.go +++ b/test/integration/webhook/playground_test.go @@ -75,21 +75,21 @@ var _ = ginkgo.Describe("Playground default and validation", func() { }, failed: true, }), - ginkgo.Entry("sglang backend supporeted", &testValidatingCase{ + ginkgo.Entry("sglang backendRuntime supported", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaim("llama3-8b").Backend(string(inferenceapi.SGLANG)).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaim("llama3-8b").BackendRuntime(string(inferenceapi.SGLANG)).Obj() }, failed: false, }), - ginkgo.Entry("llamacpp backend supporeted", &testValidatingCase{ + ginkgo.Entry("llamacpp backendRuntime supported", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaim("llama3-8b").Backend(string(inferenceapi.LLAMACPP)).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaim("llama3-8b").BackendRuntime(string(inferenceapi.LLAMACPP)).Obj() }, failed: false, }), ginkgo.Entry("speculativeDecoding with SGLang is not allowed", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).Backend(string(inferenceapi.SGLANG)).Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaims([]string{"llama3-405b", "llama3-8b"}, []string{"main", "draft"}).BackendRuntime(string(inferenceapi.SGLANG)).Obj() }, failed: true, }), @@ -99,11 +99,11 @@ var _ = ginkgo.Describe("Playground default and validation", func() { }, failed: true, }), - ginkgo.Entry("unknown backend configured", &testValidatingCase{ + ginkgo.Entry("unknown backendRuntime configured", &testValidatingCase{ playground: func() *inferenceapi.Playground { - return wrapper.MakePlayground("playground", ns.Name).Replicas(1).Backend("unknown").Obj() + return wrapper.MakePlayground("playground", ns.Name).Replicas(1).ModelClaim("llama3-8b").BackendRuntime("unknown").Obj() }, - failed: true, + failed: false, }), ginkgo.Entry("no main model", &testValidatingCase{ playground: func() *inferenceapi.Playground { diff --git a/test/util/mock.go b/test/util/mock.go index 548d9601..6642b9b3 100644 --- a/test/util/mock.go +++ b/test/util/mock.go @@ -42,8 +42,8 @@ func MockASampleService(ns string) *inferenceapi.Service { func MockASampleBackendRuntime() *wrapper.BackendRuntimeWrapper { return wrapper.MakeBackendRuntime("vllm"). - Image("vllm").Version("v0.6.0"). + Image("vllm/vllm-openai").Version("v0.6.0"). Command([]string{"python3", "-m", "vllm.entrypoints.openai.api_server"}). - Arg("default", []string{"--model", "{{.ModelPath}}", "--served-model-name", "{{.ModelName}}", "--host", "0.0.0.0", "--port", "8080"}). + Arg("Default", []string{"--model", "{{.ModelPath}}", "--served-model-name", "{{.ModelName}}", "--host", "0.0.0.0", "--port", "8080"}). 
Request("cpu", "4").Limit("cpu", "4") } diff --git a/test/util/util.go b/test/util/util.go index f634a948..1311bff8 100644 --- a/test/util/util.go +++ b/test/util/util.go @@ -17,11 +17,15 @@ package util import ( "context" + "fmt" + "os" "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" apimeta "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/serializer/yaml" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" lws "sigs.k8s.io/lws/api/leaderworkerset/v1" @@ -68,3 +72,37 @@ func UpdateLwsToUnReady(ctx context.Context, k8sClient client.Client, name, name return nil }, IntegrationTimeout, Interval).Should(gomega.Succeed()) } + +func applyYaml(ctx context.Context, k8sClient client.Client, file string) error { + yamlFile, err := os.ReadFile(file) + if err != nil { + return fmt.Errorf("failed to read YAML file: %v", err) + } + + decode := yaml.NewDecodingSerializer(unstructured.UnstructuredJSONScheme) + obj := &unstructured.Unstructured{} + _, _, err = decode.Decode(yamlFile, nil, obj) + if err != nil { + return fmt.Errorf("failed to decode YAML into Unstructured object: %v", err) + } + + if err = k8sClient.Create(ctx, obj); err != nil { + return fmt.Errorf("failed to create resource: %v", err) + } + + return nil +} + +func Setup(ctx context.Context, k8sClient client.Client, path string) error { + entries, err := os.ReadDir(path) + if err != nil { + return err + } + + for _, entry := range entries { + if err := applyYaml(ctx, k8sClient, path+"/"+entry.Name()); err != nil { + return err + } + } + return nil +} diff --git a/test/util/validation/validate_playground.go b/test/util/validation/validate_playground.go index ce9021e4..91236f89 100644 --- a/test/util/validation/validate_playground.go +++ b/test/util/validation/validate_playground.go @@ -32,20 +32,15 @@ import ( coreapi "github.com/inftyai/llmaz/api/core/v1alpha1" inferenceapi "github.com/inftyai/llmaz/api/inference/v1alpha1" - "github.com/inftyai/llmaz/pkg/controller_helper/backend" + helper "github.com/inftyai/llmaz/pkg/controller_helper" modelSource "github.com/inftyai/llmaz/pkg/controller_helper/model_source" "github.com/inftyai/llmaz/test/util" "github.com/inftyai/llmaz/test/util/format" ) -func validateModelClaim(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground, service inferenceapi.Service) error { - model := coreapi.OpenModel{} - +func validateModelClaim(models []*coreapi.OpenModel, playground *inferenceapi.Playground, service inferenceapi.Service) error { + // Make sure the first model is the main model, or the test may fail. 
if playground.Spec.ModelClaim != nil { - if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaim.ModelName), Namespace: playground.Namespace}, &model); err != nil { - return errors.New("failed to get model") - } - if playground.Spec.ModelClaim.ModelName != service.Spec.ModelClaims.Models[0].Name { return fmt.Errorf("expected modelName %s, got %s", playground.Spec.ModelClaim.ModelName, service.Spec.ModelClaims.Models[0].Name) } @@ -53,16 +48,16 @@ func validateModelClaim(ctx context.Context, k8sClient client.Client, playground return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaim.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors) } } else if playground.Spec.ModelClaims != nil { - if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(playground.Spec.ModelClaims.Models[0].Name), Namespace: playground.Namespace}, &model); err != nil { - return errors.New("failed to get model") - } if diff := cmp.Diff(*playground.Spec.ModelClaims, service.Spec.ModelClaims); diff != "" { return fmt.Errorf("expected modelClaims, want %v, got %v", *playground.Spec.ModelClaims, service.Spec.ModelClaims) } + if diff := cmp.Diff(playground.Spec.ModelClaims.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors); diff != "" { + return fmt.Errorf("unexpected flavors, want %v, got %v", playground.Spec.ModelClaims.InferenceFlavors, service.Spec.ModelClaims.InferenceFlavors) + } } - if playground.Labels[coreapi.ModelNameLabelKey] != model.Name { - return fmt.Errorf("unexpected Playground label value, want %v, got %v", model.Name, playground.Labels[coreapi.ModelNameLabelKey]) + if playground.Labels[coreapi.ModelNameLabelKey] != models[0].Name { + return fmt.Errorf("unexpected Playground label value, want %v, got %v", models[0].Name, playground.Labels[coreapi.ModelNameLabelKey]) } return nil @@ -75,7 +70,12 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground return errors.New("failed to get inferenceService") } - if err := validateModelClaim(ctx, k8sClient, playground, service); err != nil { + models, err := helper.FetchModelsByPlayground(ctx, k8sClient, playground) + if err != nil { + return err + } + + if err := validateModelClaim(models, playground, service); err != nil { return err } @@ -83,26 +83,31 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground return fmt.Errorf("expected replicas: %d, got %d", *playground.Spec.Replicas, *service.Spec.WorkloadTemplate.Replicas) } - backendName := inferenceapi.DefaultBackend + backendRuntimeName := inferenceapi.DefaultBackend if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Name != nil { - backendName = *playground.Spec.BackendRuntimeConfig.Name + backendRuntimeName = *playground.Spec.BackendRuntimeConfig.Name } - bkd := backend.SwitchBackend(backendName) + backendRuntime := inferenceapi.BackendRuntime{} + if err := k8sClient.Get(ctx, types.NamespacedName{Name: string(backendRuntimeName)}, &backendRuntime); err != nil { + return errors.New("failed to get backendRuntime") + } + + parser := helper.NewBackendRuntimeParser(&backendRuntime) if service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name != modelSource.MODEL_RUNNER_CONTAINER_NAME { return fmt.Errorf("container name not right, want %s, got %s", modelSource.MODEL_RUNNER_CONTAINER_NAME, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Name) } - if diff := 
cmp.Diff(bkd.DefaultCommand(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" { + if diff := cmp.Diff(parser.Commands(), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Command); diff != "" { return errors.New("command not right") } if playground.Spec.BackendRuntimeConfig != nil { if playground.Spec.BackendRuntimeConfig.Version != nil { - if bkd.Image(*playground.Spec.BackendRuntimeConfig.Version) != service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image { - return fmt.Errorf("expected container image %s, got %s", bkd.Image(*playground.Spec.BackendRuntimeConfig.Version), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image) + if parser.Image(*playground.Spec.BackendRuntimeConfig.Version) != service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image { + return fmt.Errorf("expected container image %s, got %s", parser.Image(*playground.Spec.BackendRuntimeConfig.Version), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image) } } else { - if bkd.Image(bkd.DefaultVersion()) != service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image { - return fmt.Errorf("expected container image %s, got %s", bkd.Image(bkd.DefaultVersion()), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image) + if parser.Image(parser.Version()) != service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image { + return fmt.Errorf("expected container image %s, got %s", parser.Image(parser.Version()), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image) } } for _, arg := range playground.Spec.BackendRuntimeConfig.Args { @@ -113,28 +118,28 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground if diff := cmp.Diff(service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Env, playground.Spec.BackendRuntimeConfig.Envs); diff != "" { return fmt.Errorf("unexpected envs") } - if playground.Spec.BackendRuntimeConfig.Resources != nil { - for k, v := range playground.Spec.BackendRuntimeConfig.Resources.Limits { - if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k].Equal(v) { - return fmt.Errorf("unexpected limit for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k]) - } + } + if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.Resources != nil { + for k, v := range playground.Spec.BackendRuntimeConfig.Resources.Limits { + if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k].Equal(v) { + return fmt.Errorf("unexpected limits for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k]) } - for k, v := range playground.Spec.BackendRuntimeConfig.Resources.Requests { - if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k].Equal(v) { - return fmt.Errorf("unexpected limit for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k]) - } + } + for k, v := range 
playground.Spec.BackendRuntimeConfig.Resources.Requests { + if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k].Equal(v) { + return fmt.Errorf("unexpected requests for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k]) } - } else { - // Validate default resources requirements. - for k, v := range bkd.DefaultResources().Limits { - if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k].Equal(v) { - return fmt.Errorf("unexpected limit for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k]) - } + } + } else { + // Validate default resources requirements. + for k, v := range parser.Resources().Limits { + if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k].Equal(v) { + return fmt.Errorf("unexpected limit for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Limits[k]) } - for k, v := range bkd.DefaultResources().Requests { - if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k].Equal(v) { - return fmt.Errorf("unexpected limit for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k]) - } + } + for k, v := range parser.Resources().Requests { + if !service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k].Equal(v) { + return fmt.Errorf("unexpected request for %s, want %v, got %v", k, v, service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Resources.Requests[k]) } } } diff --git a/test/util/wrapper/backend.go b/test/util/wrapper/backend.go index a6b727ec..66c4fabf 100644 --- a/test/util/wrapper/backend.go +++ b/test/util/wrapper/backend.go @@ -42,6 +42,11 @@ func (w *BackendRuntimeWrapper) Obj() *inferenceapi.BackendRuntime { return &w.BackendRuntime } +func (w *BackendRuntimeWrapper) Name(name string) *BackendRuntimeWrapper { + w.ObjectMeta.Name = name + return w +} + func (w *BackendRuntimeWrapper) Image(image string) *BackendRuntimeWrapper { w.Spec.Image = image return w } diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go index 3696d58e..f5ec5e20 100644 --- a/test/util/wrapper/playground.go +++ b/test/util/wrapper/playground.go @@ -91,7 +91,7 @@ func (w *PlaygroundWrapper) ModelClaims(modelNames []string, roles []string, fla return w } -func (w *PlaygroundWrapper) Backend(name string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntime(name string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { w.Spec.BackendRuntimeConfig = &inferenceapi.BackendRuntimeConfig{} } @@ -100,25 +100,25 @@ func (w *PlaygroundWrapper) Backend(name string) *PlaygroundWrapper { return w } -func (w *PlaygroundWrapper) BackendVersion(version string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeVersion(version string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { - w = w.Backend("vllm") + w = w.BackendRuntime("vllm") } w.Spec.BackendRuntimeConfig.Version = &version return w } -func (w *PlaygroundWrapper) BackendArgs(args []string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeArgs(args []string) 
*PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { - w = w.Backend("vllm") + w = w.BackendRuntime("vllm") } w.Spec.BackendRuntimeConfig.Args = args return w } -func (w *PlaygroundWrapper) BackendEnv(k, v string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeEnv(k, v string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { - w = w.Backend("vllm") + w = w.BackendRuntime("vllm") } w.Spec.BackendRuntimeConfig.Envs = append(w.Spec.BackendRuntimeConfig.Envs, v1.EnvVar{ Name: k, @@ -127,9 +127,9 @@ func (w *PlaygroundWrapper) BackendEnv(k, v string) *PlaygroundWrapper { return w } -func (w *PlaygroundWrapper) BackendRequest(r, v string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeRequest(r, v string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { - w = w.Backend("vllm") + w = w.BackendRuntime("vllm") } if w.Spec.BackendRuntimeConfig.Resources == nil { w.Spec.BackendRuntimeConfig.Resources = &inferenceapi.ResourceRequirements{} @@ -141,9 +141,9 @@ func (w *PlaygroundWrapper) BackendRequest(r, v string) *PlaygroundWrapper { return w } -func (w *PlaygroundWrapper) BackendLimit(r, v string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeLimit(r, v string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { - w = w.Backend("vllm") + w = w.BackendRuntime("vllm") } if w.Spec.BackendRuntimeConfig.Resources == nil { w.Spec.BackendRuntimeConfig.Resources = &inferenceapi.ResourceRequirements{}