|
6 | 6 |
|
7 | 7 | from setuptools import find_packages, setup
|
8 | 8 |
|
| 9 | +# Remove 2x content in "__init__.py" when only 3x is installed and recover it when 2x is installed |
| 10 | +content_position = ("neural_compressor/__init__.py", 20) # file path and line number |
| 11 | +backup_content = """from .config import ( |
| 12 | + DistillationConfig, |
| 13 | + PostTrainingQuantConfig, |
| 14 | + WeightPruningConfig, |
| 15 | + QuantizationAwareTrainingConfig, |
| 16 | + MixedPrecisionConfig, |
| 17 | +) |
| 18 | +from .contrib import * |
| 19 | +from .model import * |
| 20 | +from .metric import * |
| 21 | +from .utils import options |
| 22 | +from .utils.utility import set_random_seed, set_tensorboard, set_workspace, set_resume_from |
| 23 | +""" |
| 24 | + |
| 25 | + |
| 26 | +def delete_lines_from_file(file_path, start_line): |
| 27 | + """Deletes all lines from the specified start_line to the end of the file.""" |
| 28 | + with open(file_path, "r") as file: |
| 29 | + lines = file.readlines() |
| 30 | + |
| 31 | + # Keep only lines before the start_line |
| 32 | + lines = lines[: start_line - 1] |
| 33 | + |
| 34 | + with open(file_path, "w") as file: |
| 35 | + file.writelines(lines) |
| 36 | + |
| 37 | + |
| 38 | +def replace_lines_from_file(file_path, start_line, replacement_content): |
| 39 | + """Replaces all lines from the specified start_line to the end of the file with replacement_content.""" |
| 40 | + with open(file_path, "r") as file: |
| 41 | + lines = file.readlines() |
| 42 | + |
| 43 | + # Keep lines before the start_line and append replacement_content |
| 44 | + lines = lines[: start_line - 1] |
| 45 | + lines.append(replacement_content) |
| 46 | + |
| 47 | + with open(file_path, "w") as file: |
| 48 | + file.writelines(lines) |
| 49 | + |
9 | 50 |
|
10 | 51 | def fetch_requirements(path):
|
11 | 52 | with open(path, "r") as fd:
|
@@ -106,10 +147,13 @@ def get_build_version():
|
106 | 147 | if "pt" in sys.argv:
|
107 | 148 | sys.argv.remove("pt")
|
108 | 149 | cfg_key = "neural_compressor_pt"
|
109 |
| - |
110 |
| - if "tf" in sys.argv: |
| 150 | + delete_lines_from_file(*content_position) |
| 151 | + elif "tf" in sys.argv: |
111 | 152 | sys.argv.remove("tf")
|
112 | 153 | cfg_key = "neural_compressor_tf"
|
| 154 | + delete_lines_from_file(*content_position) |
| 155 | + else: |
| 156 | + replace_lines_from_file(*content_position, backup_content) |
113 | 157 |
|
114 | 158 | project_name = PKG_INSTALL_CFG[cfg_key].get("project_name")
|
115 | 159 | include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {}
|
|
0 commit comments