@@ -20,13 +20,23 @@ import (
2020 "context"
2121 "errors"
2222 "fmt"
23+ "io"
24+ "net/http"
25+ "os"
26+ "os/signal"
2327 "strconv"
28+ "strings"
29+ "syscall"
2430
2531 "github.com/onsi/gomega"
2632 corev1 "k8s.io/api/core/v1"
2733 apimeta "k8s.io/apimachinery/pkg/api/meta"
2834 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2935 "k8s.io/apimachinery/pkg/types"
36+ "k8s.io/client-go/kubernetes"
37+ "k8s.io/client-go/rest"
38+ "k8s.io/client-go/tools/portforward"
39+ "k8s.io/client-go/transport/spdy"
3040 "sigs.k8s.io/controller-runtime/pkg/client"
3141 lws "sigs.k8s.io/lws/api/leaderworkerset/v1"
3242
@@ -233,3 +243,108 @@ func ValidateServicePods(ctx context.Context, k8sClient client.Client, service *
233243 return nil
234244 }).Should (gomega .Succeed ())
235245}
246+
247+ type CheckServiceAvailableFunc func () error
248+
249+ func ValidateServiceAvaliable (ctx context.Context , k8sClient client.Client , cfg * rest.Config , service * inferenceapi.Service , check CheckServiceAvailableFunc ) error {
250+ pods := corev1.PodList {}
251+ podSelector := client .MatchingLabels (map [string ]string {
252+ lws .SetNameLabelKey : service .Name ,
253+ })
254+ if err := k8sClient .List (ctx , & pods , podSelector , client .InNamespace (service .Namespace )); err != nil {
255+ return err
256+ }
257+ if len (pods .Items ) != int (* service .Spec .Replicas )* int (* service .Spec .WorkloadTemplate .Size ) {
258+ return fmt .Errorf ("pods number not right, want: %d, got: %d" , int (* service .Spec .Replicas )* int (* service .Spec .WorkloadTemplate .Size ), len (pods .Items ))
259+ }
260+
261+ var targetPod * corev1.Pod
262+ for i := range pods .Items {
263+ if pods .Items [i ].Status .Phase == corev1 .PodRunning {
264+ targetPod = & pods .Items [i ]
265+ break
266+ }
267+ }
268+
269+ if targetPod == nil {
270+ return fmt .Errorf ("no running pods found for service %s" , service .Name )
271+ }
272+
273+ portForwardK8sClient , err := kubernetes .NewForConfig (cfg )
274+ if err != nil {
275+ return fmt .Errorf ("init port forward client failed: %w" , err )
276+ }
277+
278+ targetPort := targetPod .Spec .Containers [0 ].Ports [0 ].ContainerPort
279+ stopChan , readyChan := make (chan struct {}, 1 ), make (chan struct {}, 1 )
280+ req := portForwardK8sClient .CoreV1 ().RESTClient ().Post ().
281+ Resource ("pods" ).
282+ Namespace (service .Namespace ).
283+ Name (targetPod .Name ).
284+ SubResource ("portforward" )
285+
286+ transport , upgrader , err := spdy .RoundTripperFor (cfg )
287+ if err != nil {
288+ return fmt .Errorf ("creating round tripper failed: %v" , err )
289+ }
290+
291+ dialer := spdy .NewDialer (upgrader , & http.Client {Transport : transport }, "POST" , req .URL ())
292+ // create port forwarder
293+ fw , err := portforward .New (dialer , []string {fmt .Sprintf ("%d:%d" , modelSource .DEFAULT_BACKEND_PORT , targetPort )}, stopChan , readyChan , os .Stdout , os .Stderr )
294+ if err != nil {
295+ return fmt .Errorf ("creating port forwarder failed: %v" , err )
296+ }
297+ // stop port forward when done
298+ defer fw .Close ()
299+ signals := make (chan os.Signal , 1 )
300+ signal .Notify (signals , os .Interrupt , syscall .SIGTERM )
301+
302+ go func () {
303+ <- signals
304+ fmt .Println ("Received termination signal, shutting down port forward..." )
305+ close (stopChan )
306+ }()
307+
308+ // wait for port forward to be ready
309+ go func () {
310+ if err = fw .ForwardPorts (); err != nil {
311+ fmt .Printf ("Error forwarding ports: %v\n " , err )
312+ close (stopChan )
313+ }
314+ }()
315+ <- readyChan
316+ gomega .Eventually (check ()).Should (gomega .Succeed ())
317+ return nil
318+ }
319+
320+ func CheckServiceAvaliable () error {
321+ url := fmt .Sprintf ("http://localhost:%d/completions" , modelSource .DEFAULT_BACKEND_PORT )
322+ reqBody := `{"prompt":"What is the capital city of China?","stream":false}`
323+
324+ req , err := http .NewRequest ("POST" , url , strings .NewReader (reqBody ))
325+ if err != nil {
326+ return err
327+ }
328+ client := & http.Client {}
329+ resp , err := client .Do (req )
330+ if err != nil {
331+ return err
332+ }
333+ defer func () {
334+ _ = resp .Body .Close ()
335+ }()
336+
337+ if resp .StatusCode != http .StatusOK {
338+ return fmt .Errorf ("error HTTP status code %d" , resp .StatusCode )
339+ }
340+
341+ body , err := io .ReadAll (resp .Body )
342+ if err != nil {
343+ return fmt .Errorf ("error reading response: %v" , err )
344+ }
345+
346+ if ! strings .Contains (strings .ToLower (string (body )), "beijing" ) {
347+ return fmt .Errorf ("error response body: %s" , string (body ))
348+ }
349+ return nil
350+ }
0 commit comments