Skip to content

Commit ab6fb9b

Browse files
authored
bump ao to 0.5.0 to enable torchtune across platforms (#1136)
* update install_requirement.sh * bring torchao back to all platform * add essential comment * enable torchtune on macos * reformat
1 parent 713b64c commit ab6fb9b

File tree

3 files changed

+9
-41
lines changed

3 files changed

+9
-41
lines changed

install/install_requirements.sh

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ PYTORCH_NIGHTLY_VERSION=dev20240814
5252
# Nightly version for torchvision
5353
VISION_NIGHTLY_VERSION=dev20240814
5454

55-
# Nightly version for torchao
56-
AO_NIGHTLY_VERSION=dev20240905
57-
5855
# Nightly version for torchtune
5956
TUNE_NIGHTLY_VERSION=dev20240910
6057

@@ -79,10 +76,6 @@ fi
7976
REQUIREMENTS_TO_INSTALL=(
8077
torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
8178
torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
82-
)
83-
84-
LINUX_REQUIREMENTS_TO_INSTALL=(
85-
torchao=="0.5.0.${AO_NIGHTLY_VERSION}"
8679
torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
8780
)
8881

@@ -94,27 +87,10 @@ LINUX_REQUIREMENTS_TO_INSTALL=(
9487
"${REQUIREMENTS_TO_INSTALL[@]}"
9588
)
9689

97-
PLATFORM=$(uname -s)
98-
99-
# Install torchtune and torchao requirements for Linux systems using nightly.
100-
# For non-Linux systems (e.g., macOS), install torchao from GitHub since nightly
101-
# build doesn't have macOS build.
102-
# TODO: Remove this and install nightly build, once it supports macOS
103-
if [ "$PLATFORM" == "Linux" ];
104-
then
105-
(
106-
set -x
107-
$PIP_EXECUTABLE install --pre --extra-index-url "${TORCH_NIGHTLY_URL}" --no-cache-dir \
108-
"${LINUX_REQUIREMENTS_TO_INSTALL[@]}"
109-
)
110-
else
111-
# For torchao need to install from github since nightly build doesn't have macos build.
112-
# TODO: Remove this and install nightly build, once it supports macos
113-
(
114-
set -x
115-
$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@e11201a62669f582d81cdb33e031a07fb8dfc4f3
116-
)
117-
fi
90+
(
91+
set -x
92+
$PIP_EXECUTABLE install torchao=="0.5.0"
93+
)
11894

11995
if [[ -x "$(command -v nvidia-smi)" ]]; then
12096
(

torchchat/cli/builder.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,8 @@
3535
from torchchat.utils.measure_time import measure_time
3636
from torchchat.utils.quantize import quantize_model
3737

38-
# bypass the import issue before torchao is ready on macos
39-
try:
40-
from torchtune.models.convert_weights import meta_to_tune
41-
except:
42-
pass
38+
from torchtune.models.convert_weights import meta_to_tune
39+
4340

4441

4542

torchchat/model.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,9 @@
3030

3131
from torchchat.utils.build_utils import find_multiple, get_precision
3232

33-
# bypass the import issue, if any
34-
# TODO: remove this once the torchao is ready on macos
35-
try:
36-
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
37-
from torchtune.modules.model_fusion import DeepFusionModel
38-
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
39-
except:
40-
pass
33+
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
34+
from torchtune.modules.model_fusion import DeepFusionModel
35+
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
4136

4237
config_path = Path(f"{str(Path(__file__).parent)}/model_params")
4338

0 commit comments

Comments
 (0)