Skip to content

Commit 6ca13a0

Browse files
committed
Add flags for disabling pre-pull memory estimation
Signed-off-by: Piotr Stankiewicz <[email protected]>
1 parent ebb4723 commit 6ca13a0

File tree

18 files changed

+348
-41
lines changed

18 files changed

+348
-41
lines changed

commands/compose.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7-
"github.com/docker/model-cli/pkg/types"
8-
"github.com/spf13/pflag"
97
"slices"
108
"strings"
119

10+
"github.com/docker/model-cli/pkg/types"
11+
"github.com/spf13/pflag"
12+
1213
"github.com/docker/model-cli/desktop"
1314
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
14-
"github.com/docker/model-runner/pkg/inference/scheduling"
1515
dmrm "github.com/docker/model-runner/pkg/inference/models"
16+
"github.com/docker/model-runner/pkg/inference/scheduling"
1617
"github.com/spf13/cobra"
1718
)
1819

@@ -155,7 +156,7 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
155156
}
156157
return false
157158
}) {
158-
_, _, err = desktopClient.Pull(model, func(s string) {
159+
_, _, err = desktopClient.Pull(model, false, func(s string) {
159160
_ = sendInfo(s)
160161
})
161162
if err != nil {

commands/pull.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ import (
1111
)
1212

1313
func newPullCmd() *cobra.Command {
14+
var ignoreRuntimeMemoryCheck bool
15+
1416
c := &cobra.Command{
1517
Use: "pull MODEL",
1618
Short: "Pull a model from Docker Hub or HuggingFace to your local environment",
1719
Args: func(cmd *cobra.Command, args []string) error {
1820
if len(args) != 1 {
1921
return fmt.Errorf(
20-
"'docker model run' requires 1 argument.\n\n" +
22+
"'docker model pull' requires 1 argument.\n\n" +
2123
"Usage: docker model pull MODEL\n\n" +
2224
"See 'docker model pull --help' for more information",
2325
)
@@ -28,21 +30,24 @@ func newPullCmd() *cobra.Command {
2830
if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil {
2931
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
3032
}
31-
return pullModel(cmd, desktopClient, args[0])
33+
return pullModel(cmd, desktopClient, args[0], ignoreRuntimeMemoryCheck)
3234
},
3335
ValidArgsFunction: completion.NoComplete,
3436
}
37+
38+
c.Flags().BoolVar(&ignoreRuntimeMemoryCheck, "ignore-runtime-memory-check", false, "Do not block pull if estimated runtime memory for model exceeds system resources.")
39+
3540
return c
3641
}
3742

38-
func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error {
43+
func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error {
3944
var progress func(string)
4045
if isatty.IsTerminal(os.Stdout.Fd()) {
4146
progress = TUIProgress
4247
} else {
4348
progress = RawProgress
4449
}
45-
response, progressShown, err := desktopClient.Pull(model, progress)
50+
response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress)
4651

4752
// Add a newline before any output (success or error) if progress was shown.
4853
if progressShown {

commands/run.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err
8080
func newRunCmd() *cobra.Command {
8181
var debug bool
8282
var backend string
83+
var ignoreRuntimeMemoryCheck bool
8384

8485
const cmdArgs = "MODEL [PROMPT]"
8586
c := &cobra.Command{
@@ -124,7 +125,7 @@ func newRunCmd() *cobra.Command {
124125
return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
125126
}
126127
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
127-
if err := pullModel(cmd, desktopClient, model); err != nil {
128+
if err := pullModel(cmd, desktopClient, model, ignoreRuntimeMemoryCheck); err != nil {
128129
return err
129130
}
130131
}
@@ -188,6 +189,7 @@ func newRunCmd() *cobra.Command {
188189
c.Flags().BoolVar(&debug, "debug", false, "Enable debug logging")
189190
c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
190191
c.Flags().MarkHidden("backend")
192+
c.Flags().BoolVar(&ignoreRuntimeMemoryCheck, "ignore-runtime-memory-check", false, "Do not block pull if estimated runtime memory for model exceeds system resources.")
191193

192194
return c
193195
}

desktop/desktop.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ func (c *Client) Status() Status {
106106
}
107107
}
108108

109-
func (c *Client) Pull(model string, progress func(string)) (string, bool, error) {
109+
func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) {
110110
model = normalizeHuggingFaceModelName(model)
111-
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model})
111+
jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck})
112112
if err != nil {
113113
return "", false, fmt.Errorf("error marshaling request: %w", err)
114114
}

desktop/desktop_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func TestPullHuggingFaceModel(t *testing.T) {
3636
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
3737
}, nil)
3838

39-
_, _, err := client.Pull(modelName, func(s string) {})
39+
_, _, err := client.Pull(modelName, false, func(s string) {})
4040
assert.NoError(t, err)
4141
}
4242

@@ -122,7 +122,7 @@ func TestNonHuggingFaceModel(t *testing.T) {
122122
Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)),
123123
}, nil)
124124

125-
_, _, err := client.Pull(modelName, func(s string) {})
125+
_, _, err := client.Pull(modelName, false, func(s string) {})
126126
assert.NoError(t, err)
127127
}
128128

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ require (
1111
github.com/docker/docker v28.2.2+incompatible
1212
github.com/docker/go-connections v0.5.0
1313
github.com/docker/go-units v0.5.0
14-
github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c
15-
github.com/docker/model-runner v0.0.0-20250724122432-ecfa5e7e6807
14+
github.com/docker/model-distribution v0.0.0-20250811072316-8ae9665e6889
15+
github.com/docker/model-runner v0.0.0-20250819142513-e761a7751875
1616
github.com/fatih/color v1.15.0
1717
github.com/google/go-containerregistry v0.20.6
1818
github.com/mattn/go-isatty v0.0.20

go.sum

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ github.com/docker/go-metrics v0.0.1/go.mod h1:cG1hvH2utMXtqgqqYE9plW6lDxS3/5ayHz
7878
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
7979
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
8080
github.com/docker/libtrust v0.0.0-20160708172513-aabc10ec26b7/go.mod h1:cyGadeNEkKy96OOhEzfZl+yxihPEzKnqJwvfuSUqbZE=
81-
github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c h1:w9MekYamXmWLe9ZWXWgNXJ7BLDDemXwB8WcF7wzHF5Q=
82-
github.com/docker/model-distribution v0.0.0-20250724114133-a11d745e582c/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
83-
github.com/docker/model-runner v0.0.0-20250724122432-ecfa5e7e6807 h1:02vImD8wqUDv6VJ2cBLbqzbjn17IMYEi4ileCEjXMQ8=
84-
github.com/docker/model-runner v0.0.0-20250724122432-ecfa5e7e6807/go.mod h1:rCzRjRXJ42E8JVIA69E9hErJVV5mnUpWdJ2POsktfRs=
81+
github.com/docker/model-distribution v0.0.0-20250811072316-8ae9665e6889 h1:O1m0yG1N2t6qV8MWT/pe9Z9ukaV0+BX27gg8fsFXDKk=
82+
github.com/docker/model-distribution v0.0.0-20250811072316-8ae9665e6889/go.mod h1:dThpO9JoG5Px3i+rTluAeZcqLGw8C0qepuEL4gL2o/c=
83+
github.com/docker/model-runner v0.0.0-20250819142513-e761a7751875 h1:ERaUmjdswQZ0rNHhuusvlXY+ueKsFIdOsBdqXeiEtY0=
84+
github.com/docker/model-runner v0.0.0-20250819142513-e761a7751875/go.mod h1:XtmhY0MoCQ+YLBjwTBxRWxq3kil/VXD+r5xCdWQUIkY=
8585
github.com/dvsekhvalnov/jose2go v0.0.0-20170216131308-f21a8cedbbae/go.mod h1:7BvyPhdbLxMXIYTFPLsyJRFMsKmOZnQmzh6Gb+uquuM=
8686
github.com/elastic/go-sysinfo v1.15.3 h1:W+RnmhKFkqPTCRoFq2VCTmsT4p/fwpo+3gKNQsn1XU0=
8787
github.com/elastic/go-sysinfo v1.15.3/go.mod h1:K/cNrqYTDrSoMh2oDkYEMS2+a72GRxMvNP+GC+vRIlo=

vendor/github.com/docker/model-distribution/distribution/client.go

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/docker/model-distribution/registry/client.go

Lines changed: 46 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/docker/model-runner/pkg/inference/backend.go

Lines changed: 6 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)