Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions ai/worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/system"
docker "github.com/docker/docker/client"
"github.com/docker/docker/errdefs"
"github.com/docker/docker/pkg/jsonmessage"
Expand Down Expand Up @@ -96,6 +97,7 @@ type DockerClient interface {
ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error)
ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error)
Info(ctx context.Context) (system.Info, error)
}

// Compile-time assertion to ensure docker.Client implements DockerClient.
Expand All @@ -104,11 +106,29 @@ var _ DockerClient = (*docker.Client)(nil)
// Create global references to functions to allow for mocking in tests.
var dockerWaitUntilRunningFunc = dockerWaitUntilRunning

// checkRuntimeAvailable reports whether the Docker daemon reached through
// client lists a container runtime under the key runtimeName.
//
// It returns (false, err) when the daemon info cannot be retrieved, and
// (false, nil) when the info was retrieved but contains no such runtime.
func checkRuntimeAvailable(ctx context.Context, client DockerClient, runtimeName string) (bool, error) {
	daemonInfo, err := client.Info(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get Docker daemon info: %w", err)
	}

	// Indexing a nil map is safe in Go and simply reports the key as absent,
	// so a missing runtimes map needs no separate nil check.
	_, registered := daemonInfo.Runtimes[runtimeName]
	return registered, nil
}

type DockerManager struct {
gpus []string
modelDir string
overrides ImageOverrides
verboseLogs bool
gpus []string
modelDir string
overrides ImageOverrides
verboseLogs bool
dockerRuntime string

dockerClient DockerClient
// gpu ID => container
Expand All @@ -118,19 +138,36 @@ type DockerManager struct {
mu *sync.Mutex
}

func NewDockerManager(overrides ImageOverrides, verboseLogs bool, gpus []string, modelDir string, client DockerClient) (*DockerManager, error) {
func NewDockerManager(overrides ImageOverrides, verboseLogs bool, gpus []string, modelDir string, client DockerClient, dockerRuntime string) (*DockerManager, error) {
ctx, cancel := context.WithTimeout(context.Background(), containerTimeout)
if err := removeExistingContainers(ctx, client); err != nil {
cancel()
return nil, err
}
cancel()

// Check runtime availability if a specific runtime is requested.
if dockerRuntime != "" {
runtimeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

if available, err := checkRuntimeAvailable(runtimeCtx, client, dockerRuntime); err != nil {
slog.Warn("Docker runtime check failed, using default runtime", slog.String("error", err.Error()))
} else {
if available {
slog.Info("Docker runtime detected and will be used", slog.String("runtime", dockerRuntime))
} else {
slog.Warn("Docker runtime not available, using default runtime", slog.String("runtime", dockerRuntime))
}
}
}

manager := &DockerManager{
gpus: gpus,
modelDir: modelDir,
overrides: overrides,
verboseLogs: verboseLogs,
dockerRuntime: dockerRuntime,
dockerClient: client,
gpuContainers: make(map[string]*RunnerContainer),
containers: make(map[string]*RunnerContainer),
Expand Down Expand Up @@ -424,6 +461,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
RestartPolicy: restartPolicy,
}

// Use custom Docker runtime if specified.
if m.dockerRuntime != "" {
hostConfig.Runtime = m.dockerRuntime
slog.Info("Using custom Docker runtime for container", slog.String("container", containerName), slog.String("runtime", m.dockerRuntime))
}

resp, err := m.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions ai/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ type Worker struct {
mu *sync.Mutex
}

func NewWorker(imageOverrides ImageOverrides, verboseLogs bool, gpus []string, modelDir string) (*Worker, error) {
func NewWorker(imageOverrides ImageOverrides, verboseLogs bool, gpus []string, modelDir string, dockerRuntime string) (*Worker, error) {
dockerClient, err := docker.NewClientWithOpts(docker.FromEnv, docker.WithAPIVersionNegotiation())
if err != nil {
return nil, fmt.Errorf("error creating docker client: %w", err)
}

manager, err := NewDockerManager(imageOverrides, verboseLogs, gpus, modelDir, dockerClient)
manager, err := NewDockerManager(imageOverrides, verboseLogs, gpus, modelDir, dockerClient, dockerRuntime)
if err != nil {
return nil, fmt.Errorf("error creating docker manager: %w", err)
}
Expand Down
1 change: 1 addition & 0 deletions cmd/livepeer/starter/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func NewLivepeerConfig(fs *flag.FlagSet) LivepeerConfig {
// AI:
cfg.AIServiceRegistry = fs.Bool("aiServiceRegistry", *cfg.AIServiceRegistry, "Set to true to use an AI ServiceRegistry contract address")
cfg.AIWorker = fs.Bool("aiWorker", *cfg.AIWorker, "Set to true to run an AI worker")
cfg.DockerRuntime = fs.String("dockerRuntime", *cfg.DockerRuntime, "Docker container runtime for enhanced security isolation (e.g., 'kata-runtime', 'runsc', 'kata-fc', 'kata-qemu'). If empty, uses Docker default runtime.")
cfg.AIModels = fs.String("aiModels", *cfg.AIModels, "Set models (pipeline:model_id) for AI worker to load upon initialization")
cfg.AIModelsDir = fs.String("aiModelsDir", *cfg.AIModelsDir, "Set directory where AI model weights are stored")
cfg.AIRunnerImage = fs.String("aiRunnerImage", *cfg.AIRunnerImage, "[Deprecated] Specify the base Docker image for the AI runner. Example: livepeer/ai-runner:0.0.1. Use -aiRunnerImageOverrides instead.")
Expand Down
5 changes: 4 additions & 1 deletion cmd/livepeer/starter/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ type LivepeerConfig struct {
Transcoder *bool
AIServiceRegistry *bool
AIWorker *bool
DockerRuntime *string
Gateway *bool
Broadcaster *bool
OrchSecret *string
Expand Down Expand Up @@ -228,6 +229,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
// AI:
defaultAIServiceRegistry := false
defaultAIWorker := false
defaultDockerRuntime := ""
defaultAIModels := ""
defaultAIModelsDir := ""
defaultAIRunnerImage := "livepeer/ai-runner:latest"
Expand Down Expand Up @@ -346,6 +348,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
// AI:
AIServiceRegistry: &defaultAIServiceRegistry,
AIWorker: &defaultAIWorker,
DockerRuntime: &defaultDockerRuntime,
AIModels: &defaultAIModels,
AIModelsDir: &defaultAIModelsDir,
AIRunnerImage: &defaultAIRunnerImage,
Expand Down Expand Up @@ -1308,7 +1311,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
}
}

n.AIWorker, err = worker.NewWorker(imageOverrides, *cfg.AIVerboseLogs, gpus, modelsDir)
n.AIWorker, err = worker.NewWorker(imageOverrides, *cfg.AIVerboseLogs, gpus, modelsDir, *cfg.DockerRuntime)
if err != nil {
glog.Errorf("Error starting AI worker: %v", err)
return
Expand Down
Loading