diff --git a/go.mod b/go.mod
index 945d75fb..65a33d1b 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/onsi/ginkgo/v2 v2.22.2
 	github.com/onsi/gomega v1.36.2
 	github.com/open-policy-agent/cert-controller v0.12.0
+	github.com/stretchr/testify v1.9.0
 	k8s.io/api v0.32.2
 	k8s.io/apiextensions-apiserver v0.32.2
 	k8s.io/apimachinery v0.32.2
@@ -49,6 +50,7 @@ require (
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/prometheus/client_golang v1.20.2 // indirect
 	github.com/prometheus/client_model v0.6.1 // indirect
 	github.com/prometheus/common v0.55.0 // indirect
diff --git a/pkg/controller_helper/modelsource/modelhub.go b/pkg/controller_helper/modelsource/modelhub.go
index 3d15489e..91db583e 100644
--- a/pkg/controller_helper/modelsource/modelhub.go
+++ b/pkg/controller_helper/modelsource/modelhub.go
@@ -75,6 +75,13 @@ func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec, i
 		},
 	}
 
+	// Copy the first container's env vars into the init container. The template
+	// is expected to hold exactly one container; the length check only keeps a
+	// malformed template from triggering an index-out-of-range panic here.
+	if len(template.Spec.Containers) > 0 {
+		spreadEnvToInitContainer(template.Spec.Containers[0].Env, initContainer)
+	}
+
 	// This is related to the model loader logics which will read the environment when loading models weights.
 	initContainer.Env = append(
 		initContainer.Env,
@@ -157,3 +164,9 @@ func (p *ModelHubProvider) InjectModelLoader(template *corev1.PodTemplateSpec, i
 		},
 	})
 }
+
+// spreadEnvToInitContainer appends every entry of containerEnv to
+// initContainer.Env, after any entries that are already present.
+func spreadEnvToInitContainer(containerEnv []corev1.EnvVar, initContainer *corev1.Container) {
+	initContainer.Env = append(initContainer.Env, containerEnv...)
+}
diff --git a/pkg/controller_helper/modelsource/modelsource_test.go b/pkg/controller_helper/modelsource/modelsource_test.go
index a088be96..d8b5df71 100644
--- a/pkg/controller_helper/modelsource/modelsource_test.go
+++ b/pkg/controller_helper/modelsource/modelsource_test.go
@@ -19,6 +19,9 @@ package modelSource
 import (
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+	corev1 "k8s.io/api/core/v1"
+
 	coreapi "github.com/inftyai/llmaz/api/core/v1alpha1"
 	"github.com/inftyai/llmaz/test/util"
 	"github.com/inftyai/llmaz/test/util/wrapper"
@@ -69,3 +72,67 @@ func Test_ModelSourceProvider(t *testing.T) {
 		})
 	}
 }
+
+// TestEnvInjectModelLoader checks that InjectModelLoader copies the env vars
+// of the first container in the template into the injected init container.
+func TestEnvInjectModelLoader(t *testing.T) {
+	tests := []struct {
+		name     string
+		provider ModelSourceProvider
+		template *corev1.PodTemplateSpec
+	}{
+		{
+			name: "Spread container env to initContainer using modelhub",
+			provider: &ModelHubProvider{
+				modelName: "test-model",
+			},
+			template: &corev1.PodTemplateSpec{
+				Spec: corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "model-runner",
+							Image: "vllm:test",
+							Env: []corev1.EnvVar{
+								{Name: "http_proxy", Value: "1.1.1.1:1234"},
+								{Name: "https_proxy", Value: "1.1.1.1:1234"},
+							},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "Spread container env to initContainer using objstores",
+			provider: &URIProvider{
+				modelName: "test-model",
+			},
+			template: &corev1.PodTemplateSpec{
+				Spec: corev1.PodSpec{
+					Containers: []corev1.Container{
+						{
+							Name:  "model-runner",
+							Image: "vllm:test",
+							Env: []corev1.EnvVar{
+								{Name: "http_proxy", Value: "1.1.1.1:1234"},
+								{Name: "https_proxy", Value: "1.1.1.1:1234"},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			tt.provider.InjectModelLoader(tt.template, 0)
+			// Fail this subtest cleanly, rather than panic on the index
+			// below, when no init container was injected.
+			if !assert.NotEmpty(t, tt.template.Spec.InitContainers) {
+				return
+			}
+			initContainer := tt.template.Spec.InitContainers[0]
+			assert.Subset(t, initContainer.Env, tt.template.Spec.Containers[0].Env)
+		})
+	}
+}
diff --git a/pkg/controller_helper/modelsource/uri.go b/pkg/controller_helper/modelsource/uri.go
index 707b0da2..6790dd54 100644
--- a/pkg/controller_helper/modelsource/uri.go
+++ b/pkg/controller_helper/modelsource/uri.go
@@ -120,6 +120,13 @@ func (p *URIProvider) InjectModelLoader(template *corev1.PodTemplateSpec, index
 		},
 	}
 
+	// Copy the first container's env vars into the init container. The template
+	// is expected to hold exactly one container; the length check only keeps a
+	// malformed template from triggering an index-out-of-range panic here.
+	if len(template.Spec.Containers) > 0 {
+		spreadEnvToInitContainer(template.Spec.Containers[0].Env, initContainer)
+	}
+
 	switch p.protocol {
 	case OSS:
 		initContainer.Env = append(