Skip to content

Commit 3e50c42

Browse files
mikekgfbmalfet
authored andcommitted
arg handling (pytorch#292)
* arg handling * phase ordering issue resolved
1 parent 57a7964 commit 3e50c42

File tree

2 files changed

+17
-37
lines changed

2 files changed

+17
-37
lines changed

build/builder.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ def __post_init__(self):
6969

7070
@classmethod
7171
def from_args(cls, args): # -> BuilderArgs:
72+
73+
# Handle disabled checkpoint_dir option
74+
checkpoint_dir = None
75+
if hasattr(args, "checkpoint_dir"):
76+
checkpoint_dir = args.checkpoint_dir
77+
7278
is_chat_model = False
7379
if args.is_chat_model:
7480
is_chat_model = True
7581
else:
7682
for path in [
7783
args.checkpoint_path,
78-
args.checkpoint_dir,
84+
checkpoint_dir,
7985
args.dso_path,
8086
args.pte_path,
8187
args.gguf_path,
@@ -89,7 +95,7 @@ def from_args(cls, args): # -> BuilderArgs:
8995

9096
return cls(
9197
checkpoint_path=args.checkpoint_path,
92-
checkpoint_dir=args.checkpoint_dir,
98+
checkpoint_dir=checkpoint_dir,
9399
params_path=args.params_path,
94100
params_table=args.params_table,
95101
gguf_path=args.gguf_path,

cli.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,10 @@
99

1010
import torch
1111

12+
default_device = "cpu"
1213

13-
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
14-
15-
strict = False
16-
17-
18-
def check_args(args, command_name: str):
19-
global strict
20-
21-
# chat and generate support the same options
22-
if command_name in ["generate", "chat", "gui"]:
23-
# examples, can add more. Note that attributes convert dash to _
24-
disallowed_args = ["output_pte_path", "output_dso_path"]
25-
elif command_name == "export":
26-
# examples, can add more. Note that attributes convert dash to _
27-
disallowed_args = ["pte_path", "dso_path"]
28-
elif command_name == "eval":
29-
# TBD
30-
disallowed_args = []
31-
else:
32-
raise RuntimeError(f"{command_name} is not a valid command")
33-
34-
for disallowed in disallowed_args:
35-
if hasattr(args, disallowed):
36-
text = f"command {command_name} does not support option {disallowed.replace('_', '-')}"
37-
if strict:
38-
raise RuntimeError(text)
39-
else:
40-
print(f"Warning: {text}")
41-
14+
def check_args(args, name: str) -> None:
15+
pass
4216

4317
def add_arguments_for_generate(parser):
4418
# Only generate specific options should be here
@@ -123,12 +97,12 @@ def _add_arguments_common(parser):
12397
default="not_specified",
12498
help="Model checkpoint path.",
12599
)
126-
parser.add_argument(
127-
"--checkpoint-dir",
128-
type=Path,
129-
default=None,
130-
help="Model checkpoint directory.",
131-
)
100+
# parser.add_argument(
101+
# "--checkpoint-dir",
102+
# type=Path,
103+
# default=None,
104+
# help="Model checkpoint directory.",
105+
# )
132106
parser.add_argument(
133107
"--params-path",
134108
type=Path,

0 commit comments

Comments
 (0)