diff --git a/.golangci.yaml b/.golangci.yaml index af3aa62b..9a2138e5 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -32,7 +32,6 @@ linters: - errcheck - exportloopref - goconst - - gocyclo - gofmt - goimports - gosimple diff --git a/api/core/v1alpha1/model_types.go b/api/core/v1alpha1/model_types.go index abb72c29..8b2daf58 100644 --- a/api/core/v1alpha1/model_types.go +++ b/api/core/v1alpha1/model_types.go @@ -139,12 +139,14 @@ type ModelClaim struct { type ModelRole string const ( - // Main represents the main model, if only one model is required, + // MainRole represents the main model, if only one model is required, // it must be the main model. Only one main model is allowed. MainRole ModelRole = "main" - // Draft represents the draft model in speculative decoding, + // DraftRole represents the draft model in speculative decoding, // the main model is the target model then. DraftRole ModelRole = "draft" + // LoraRole represents the lora model. + LoraRole ModelRole = "lora" ) // ModelRefer refers to a created Model with it's role. diff --git a/api/inference/v1alpha1/config_types.go b/api/inference/v1alpha1/config_types.go index ffab7aea..1aa56889 100644 --- a/api/inference/v1alpha1/config_types.go +++ b/api/inference/v1alpha1/config_types.go @@ -33,10 +33,17 @@ type BackendRuntimeConfig struct { // from the default version. // +optional Version *string `json:"version,omitempty"` - // Args represents the arguments appended to the backend. - // You can add new args or overwrite the default args. + // ArgName represents the argument name set in the backendRuntimeArg. + // If not set, will be derived by the model role, e.g. if one model's role + // is <draft>, the argName will be set to <speculative-decoding>. Better to + // set the argName explicitly. + // By default, the argName will be treated as <default> in runtime. // +optional - Args []string `json:"args,omitempty"` + ArgName *string `json:"argName,omitempty"` + // ArgFlags represents the argument flags appended to the backend. 
+ // You can add new flags or overwrite the default flags. + // +optional + ArgFlags []string `json:"argFlags,omitempty"` // Envs represents the environments set to the container. // +optional Envs []corev1.EnvVar `json:"envs,omitempty"` diff --git a/api/inference/v1alpha1/zz_generated.deepcopy.go b/api/inference/v1alpha1/zz_generated.deepcopy.go index 41263081..cad051d8 100644 --- a/api/inference/v1alpha1/zz_generated.deepcopy.go +++ b/api/inference/v1alpha1/zz_generated.deepcopy.go @@ -87,8 +87,13 @@ func (in *BackendRuntimeConfig) DeepCopyInto(out *BackendRuntimeConfig) { *out = new(string) **out = **in } - if in.Args != nil { - in, out := &in.Args, &out.Args + if in.ArgName != nil { + in, out := &in.ArgName, &out.ArgName + *out = new(string) + **out = **in + } + if in.ArgFlags != nil { + in, out := &in.ArgFlags, &out.ArgFlags *out = make([]string, len(*in)) copy(*out, *in) } diff --git a/client-go/applyconfiguration/inference/v1alpha1/backendruntimeconfig.go b/client-go/applyconfiguration/inference/v1alpha1/backendruntimeconfig.go index 68cabc91..9624854c 100644 --- a/client-go/applyconfiguration/inference/v1alpha1/backendruntimeconfig.go +++ b/client-go/applyconfiguration/inference/v1alpha1/backendruntimeconfig.go @@ -27,7 +27,8 @@ import ( type BackendRuntimeConfigApplyConfiguration struct { Name *v1alpha1.BackendName `json:"name,omitempty"` Version *string `json:"version,omitempty"` - Args []string `json:"args,omitempty"` + ArgName *string `json:"argName,omitempty"` + ArgFlags []string `json:"argFlags,omitempty"` Envs []v1.EnvVar `json:"envs,omitempty"` Resources *ResourceRequirementsApplyConfiguration `json:"resources,omitempty"` } @@ -54,12 +55,20 @@ func (b *BackendRuntimeConfigApplyConfiguration) WithVersion(value string) *Back return b } -// WithArgs adds the given value to the Args field in the declarative configuration +// WithArgName sets the ArgName field in the declarative configuration to the given value +// and returns the receiver, so that 
objects can be built by chaining "With" function invocations. +// If called multiple times, the ArgName field is set to the value of the last call. +func (b *BackendRuntimeConfigApplyConfiguration) WithArgName(value string) *BackendRuntimeConfigApplyConfiguration { + b.ArgName = &value + return b +} + +// WithArgFlags adds the given value to the ArgFlags field in the declarative configuration // and returns the receiver, so that objects can be build by chaining "With" function invocations. -// If called multiple times, values provided by each call will be appended to the Args field. -func (b *BackendRuntimeConfigApplyConfiguration) WithArgs(values ...string) *BackendRuntimeConfigApplyConfiguration { +// If called multiple times, values provided by each call will be appended to the ArgFlags field. +func (b *BackendRuntimeConfigApplyConfiguration) WithArgFlags(values ...string) *BackendRuntimeConfigApplyConfiguration { for i := range values { - b.Args = append(b.Args, values[i]) + b.ArgFlags = append(b.ArgFlags, values[i]) } return b } diff --git a/config/crd/bases/inference.llmaz.io_playgrounds.yaml b/config/crd/bases/inference.llmaz.io_playgrounds.yaml index d4d6f480..ba52db66 100644 --- a/config/crd/bases/inference.llmaz.io_playgrounds.yaml +++ b/config/crd/bases/inference.llmaz.io_playgrounds.yaml @@ -46,13 +46,21 @@ spec: BackendRuntimeConfig represents the inference backendRuntime configuration under the hood, e.g. vLLM, which is the default backendRuntime. properties: - args: + argFlags: description: |- - Args represents the arguments appended to the backend. - You can add new args or overwrite the default args. + ArgFlags represents the argument flags appended to the backend. + You can add new flags or overwrite the default flags. items: type: string type: array + argName: + description: |- + ArgName represents the argument name set in the backendRuntimeArg. + If not set, will be derived by the model role, e.g. 
if one model's role + is <draft>, the argName will be set to <speculative-decoding>. Better to + set the argName explicitly. + By default, the argName will be treated as <default> in runtime. + type: string envs: description: Envs represents the environments set to the container. items: diff --git a/docs/examples/llamacpp/playground.yaml b/docs/examples/llamacpp/playground.yaml index bf62f9a4..c2b94901 100644 --- a/docs/examples/llamacpp/playground.yaml +++ b/docs/examples/llamacpp/playground.yaml @@ -8,5 +8,5 @@ spec: modelName: qwen2-0--5b-gguf backendRuntimeConfig: name: llamacpp - args: + argFlags: - -fa # use flash attention diff --git a/docs/examples/speculative-decoding/llamacpp/playground.yaml b/docs/examples/speculative-decoding/llamacpp/playground.yaml index 405b3577..daec5e67 100644 --- a/docs/examples/speculative-decoding/llamacpp/playground.yaml +++ b/docs/examples/speculative-decoding/llamacpp/playground.yaml @@ -15,7 +15,7 @@ spec: role: draft backendRuntimeConfig: name: llamacpp - args: + argFlags: - -fa # use flash attention resources: requests: diff --git a/pkg/controller/inference/playground_controller.go b/pkg/controller/inference/playground_controller.go index 53f9adbf..fe81d9ec 100644 --- a/pkg/controller/inference/playground_controller.go +++ b/pkg/controller/inference/playground_controller.go @@ -257,14 +257,14 @@ func buildWorkloadTemplate(models []*coreapi.OpenModel, playground *inferenceapi func buildWorkerTemplate(models []*coreapi.OpenModel, playground *inferenceapi.Playground, backendRuntime *inferenceapi.BackendRuntime) (corev1.PodTemplateSpec, error) { parser := helper.NewBackendRuntimeParser(backendRuntime) - args, err := parser.Args(helper.PlaygroundInferenceMode(playground), models) + args, err := parser.Args(playground, models) if err != nil { return corev1.PodTemplateSpec{}, err } envs := parser.Envs() if playground.Spec.BackendRuntimeConfig != nil { - args = append(args, playground.Spec.BackendRuntimeConfig.Args...) 
+ args = append(args, playground.Spec.BackendRuntimeConfig.ArgFlags...) envs = append(envs, playground.Spec.BackendRuntimeConfig.Envs...) } diff --git a/pkg/controller_helper/backendruntime.go b/pkg/controller_helper/backendruntime.go index 4f9f9486..03e2386a 100644 --- a/pkg/controller_helper/backendruntime.go +++ b/pkg/controller_helper/backendruntime.go @@ -45,34 +45,29 @@ func (p *BackendRuntimeParser) Envs() []corev1.EnvVar { return p.backendRuntime.Spec.Envs } -func (p *BackendRuntimeParser) Args(mode InferenceMode, models []*coreapi.OpenModel) ([]string, error) { - // TODO: add validation in webhook. - if mode == SpeculativeDecodingInferenceMode && len(models) != 2 { - return nil, fmt.Errorf("models number not right, want 2, got %d", len(models)) +func (p *BackendRuntimeParser) Args(playground *inferenceapi.Playground, models []*coreapi.OpenModel) ([]string, error) { + var argName string + if playground.Spec.BackendRuntimeConfig != nil && playground.Spec.BackendRuntimeConfig.ArgName != nil { + argName = *playground.Spec.BackendRuntimeConfig.ArgName + } else { + // Auto detect the args from model roles. 
+ argName = DetectArgFrom(playground) } - modelInfo := map[string]string{} - - if mode == DefaultInferenceMode { - source := modelSource.NewModelSourceProvider(models[0]) - modelInfo = map[string]string{ - "ModelPath": source.ModelPath(), - "ModelName": source.ModelName(), - } + source := modelSource.NewModelSourceProvider(models[0]) + modelInfo := map[string]string{ + "ModelPath": source.ModelPath(), + "ModelName": source.ModelName(), } - if mode == SpeculativeDecodingInferenceMode { - targetSource := modelSource.NewModelSourceProvider(models[0]) - draftSource := modelSource.NewModelSourceProvider(models[1]) - modelInfo = map[string]string{ - "ModelPath": targetSource.ModelPath(), - "ModelName": targetSource.ModelName(), - "DraftModelPath": draftSource.ModelPath(), - } + // TODO: This is not that reliable because two models doesn't always means speculative-decoding. + // Revisit this later. + if len(models) > 1 { + modelInfo["DraftModelPath"] = modelSource.NewModelSourceProvider(models[1]).ModelPath() } for _, arg := range p.backendRuntime.Spec.Args { - if InferenceMode(arg.Name) == mode { + if arg.Name == argName { return renderFlags(arg.Flags, modelInfo) } } diff --git a/pkg/controller_helper/helper.go b/pkg/controller_helper/helper.go index 7405e2c6..bd8f751b 100644 --- a/pkg/controller_helper/helper.go +++ b/pkg/controller_helper/helper.go @@ -25,31 +25,28 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" ) -type InferenceMode string - // These two modes are preset. const ( - DefaultInferenceMode InferenceMode = "default" - SpeculativeDecodingInferenceMode InferenceMode = "speculative-decoding" + DefaultArg string = "default" + SpeculativeDecodingArg string = "speculative-decoding" ) -// PlaygroundInferenceMode gets the mode of inference process, supports default -// or speculative-decoding for now, which is aligned with backendRuntime. 
-func PlaygroundInferenceMode(playground *inferenceapi.Playground) InferenceMode { +// DetectArgFrom will auto detect the arg from model roles if not set explicitly. +func DetectArgFrom(playground *inferenceapi.Playground) string { if playground.Spec.ModelClaim != nil { - return DefaultInferenceMode + return DefaultArg } if playground.Spec.ModelClaims != nil { for _, mr := range playground.Spec.ModelClaims.Models { if *mr.Role == coreapi.DraftRole { - return SpeculativeDecodingInferenceMode + return SpeculativeDecodingArg } } } // We should not reach here. - return DefaultInferenceMode + return DefaultArg } func FetchModelsByService(ctx context.Context, k8sClient client.Client, service *inferenceapi.Service) (models []*coreapi.OpenModel, err error) { diff --git a/pkg/webhook/playground_webhook.go b/pkg/webhook/playground_webhook.go index 094e4bb1..983ab678 100644 --- a/pkg/webhook/playground_webhook.go +++ b/pkg/webhook/playground_webhook.go @@ -112,8 +112,8 @@ func (w *PlaygroundWebhook) generateValidate(obj runtime.Object) field.ErrorList } } - mode := helper.PlaygroundInferenceMode(playground) - if mode == helper.SpeculativeDecodingInferenceMode { + arg := helper.DetectArgFrom(playground) + if arg == helper.SpeculativeDecodingArg { if len(playground.Spec.ModelClaims.Models) != 2 { allErrs = append(allErrs, field.Forbidden(specPath.Child("modelClaims", "models"), "only two models are allowed in speculativeDecoding mode")) } diff --git a/test/config/backends/fake_backend.yaml b/test/config/backends/fake_backend.yaml new file mode 100644 index 00000000..18374277 --- /dev/null +++ b/test/config/backends/fake_backend.yaml @@ -0,0 +1,35 @@ +apiVersion: inference.llmaz.io/v1alpha1 +kind: BackendRuntime +metadata: + labels: + app.kubernetes.io/name: backendruntime + app.kubernetes.io/part-of: llmaz + app.kubernetes.io/created-by: llmaz + name: fake-backend +spec: + commands: + - sh + - -c + - echo "hello" + image: busybox + version: latest + args: + - name: default + 
flags: + - mode + - "default" + - name: speculative-decoding + flags: + - mode + - "speculative-decoding" + - name: fuz + flags: + - mode + - "fuz" + resources: + requests: + cpu: 4 + memory: 8Gi + limits: + cpu: 4 + memory: 8Gi diff --git a/test/integration/controller/inference/playground_test.go b/test/integration/controller/inference/playground_test.go index 8f47e03e..586b6715 100644 --- a/test/integration/controller/inference/playground_test.go +++ b/test/integration/controller/inference/playground_test.go @@ -183,7 +183,7 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with sglang", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - BackendRuntime("sglang").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntime("sglang").BackendRuntimeVersion("main").BackendRuntimeArgFlags([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). Obj() }, @@ -211,7 +211,7 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with llamacpp", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - BackendRuntime("llamacpp").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntime("llamacpp").BackendRuntimeVersion("main").BackendRuntimeArgFlags([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). 
Obj() }, @@ -239,7 +239,7 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with tgi", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - BackendRuntime("tgi").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--model-id", "Qwen/Qwen2-0.5B-Instruct"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntime("tgi").BackendRuntimeVersion("main").BackendRuntimeArgFlags([]string{"--model-id", "Qwen/Qwen2-0.5B-Instruct"}).BackendRuntimeEnv("FOO", "BAR"). BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). Obj() }, @@ -267,7 +267,7 @@ var _ = ginkgo.Describe("playground controller test", func() { ginkgo.Entry("advance configured Playground with ollama", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). - BackendRuntime("ollama").BackendRuntimeVersion("main").BackendRuntimeArgs([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). + BackendRuntime("ollama").BackendRuntimeVersion("main").BackendRuntimeArgFlags([]string{"--foo", "bar"}).BackendRuntimeEnv("FOO", "BAR"). BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). Obj() }, @@ -292,6 +292,24 @@ var _ = ginkgo.Describe("playground controller test", func() { }, }, }), + ginkgo.Entry("advance configured Playground with argName set", &testValidatingCase{ + makePlayground: func() *inferenceapi.Playground { + return wrapper.MakePlayground("playground", ns.Name).ModelClaim(model.Name).Label(coreapi.ModelNameLabelKey, model.Name). + BackendRuntime("fake-backend").BackendRuntimeVersion("main").BackendRuntimeArgName("fuz").BackendRuntimeArgFlags([]string{"--model-id", "Qwen/Qwen2-0.5B-Instruct"}).BackendRuntimeEnv("FOO", "BAR"). 
+ BackendRuntimeRequest("cpu", "1").BackendRuntimeLimit("cpu", "10"). + Obj() + }, + updates: []*update{ + { + updateFunc: func(playground *inferenceapi.Playground) { + gomega.Expect(k8sClient.Create(ctx, playground)).To(gomega.Succeed()) + }, + checkFunc: func(ctx context.Context, k8sClient client.Client, playground *inferenceapi.Playground) { + validation.ValidatePlayground(ctx, k8sClient, playground) + }, + }, + }, + }), ginkgo.Entry("playground is created when service exists with the same name", &testValidatingCase{ makePlayground: func() *inferenceapi.Playground { return util.MockASamplePlayground(ns.Name) diff --git a/test/util/validation/validate_playground.go b/test/util/validation/validate_playground.go index 91236f89..a793f918 100644 --- a/test/util/validation/validate_playground.go +++ b/test/util/validation/validate_playground.go @@ -22,6 +22,7 @@ import ( "fmt" "os" "slices" + "strings" "github.com/google/go-cmp/cmp" "github.com/onsi/gomega" @@ -110,11 +111,26 @@ func ValidatePlayground(ctx context.Context, k8sClient client.Client, playground return fmt.Errorf("expected container image %s, got %s", parser.Image(parser.Version()), service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Image) } } - for _, arg := range playground.Spec.BackendRuntimeConfig.Args { + + // We assumed the 0-index arg is the default one. + argFlags := backendRuntime.Spec.Args[0].Flags + if playground.Spec.BackendRuntimeConfig.ArgName != nil { + for _, arg := range backendRuntime.Spec.Args { + if arg.Name == *playground.Spec.BackendRuntimeConfig.ArgName { + argFlags = arg.Flags + } + } + } + argFlags = append(argFlags, playground.Spec.BackendRuntimeConfig.ArgFlags...) 
+ for _, arg := range argFlags { + if strings.Contains(arg, "{{") && strings.Contains(arg, "}}") { + continue + } if !slices.Contains(service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Args, arg) { return fmt.Errorf("didn't contain arg: %s", arg) } } + if diff := cmp.Diff(service.Spec.WorkloadTemplate.LeaderWorkerTemplate.WorkerTemplate.Spec.Containers[0].Env, playground.Spec.BackendRuntimeConfig.Envs); diff != "" { return fmt.Errorf("unexpected envs") } diff --git a/test/util/wrapper/playground.go b/test/util/wrapper/playground.go index b053e076..fc4f2627 100644 --- a/test/util/wrapper/playground.go +++ b/test/util/wrapper/playground.go @@ -109,11 +109,19 @@ func (w *PlaygroundWrapper) BackendRuntimeVersion(version string) *PlaygroundWra return w } -func (w *PlaygroundWrapper) BackendRuntimeArgs(args []string) *PlaygroundWrapper { +func (w *PlaygroundWrapper) BackendRuntimeArgName(name string) *PlaygroundWrapper { if w.Spec.BackendRuntimeConfig == nil { w = w.BackendRuntime("vllm") } - w.Spec.BackendRuntimeConfig.Args = args + w.Spec.BackendRuntimeConfig.ArgName = &name + return w +} + +func (w *PlaygroundWrapper) BackendRuntimeArgFlags(args []string) *PlaygroundWrapper { + if w.Spec.BackendRuntimeConfig == nil { + w = w.BackendRuntime("vllm") + } + w.Spec.BackendRuntimeConfig.ArgFlags = args return w }