|
6 | 6 |
|
7 | 7 | from setuptools import find_packages, setup |
8 | 8 |
|
9 | | -# Remove 2x content in "__init__.py" when only 3x is installed and recover it when 2x is installed |
10 | | -content_position = ("neural_compressor/__init__.py", 20) # file path and line number |
11 | | -backup_content = """from .config import ( |
12 | | - DistillationConfig, |
13 | | - PostTrainingQuantConfig, |
14 | | - WeightPruningConfig, |
15 | | - QuantizationAwareTrainingConfig, |
16 | | - MixedPrecisionConfig, |
17 | | -) |
18 | | -from .contrib import * |
19 | | -from .model import * |
20 | | -from .metric import * |
21 | | -from .utils import options |
22 | | -from .utils.utility import set_random_seed, set_tensorboard, set_workspace, set_resume_from |
23 | | -""" |
24 | | - |
25 | | - |
26 | | -def delete_lines_from_file(file_path, start_line): |
27 | | - """Deletes all lines from the specified start_line to the end of the file.""" |
28 | | - with open(file_path, "r") as file: |
29 | | - lines = file.readlines() |
30 | | - |
31 | | - # Keep only lines before the start_line |
32 | | - lines = lines[: start_line - 1] |
33 | | - |
34 | | - with open(file_path, "w") as file: |
35 | | - file.writelines(lines) |
36 | | - |
37 | | - |
38 | | -def replace_lines_from_file(file_path, start_line, replacement_content): |
39 | | - """Replaces all lines from the specified start_line to the end of the file with replacement_content.""" |
40 | | - with open(file_path, "r") as file: |
41 | | - lines = file.readlines() |
42 | | - |
43 | | - # Keep lines before the start_line and append replacement_content |
44 | | - lines = lines[: start_line - 1] |
45 | | - lines.append(replacement_content) |
46 | | - |
47 | | - with open(file_path, "w") as file: |
48 | | - file.writelines(lines) |
49 | | - |
50 | 9 |
|
51 | 10 | def fetch_requirements(path): |
52 | 11 | with open(path, "r") as fd: |
@@ -147,13 +106,10 @@ def get_build_version(): |
147 | 106 | if "pt" in sys.argv: |
148 | 107 | sys.argv.remove("pt") |
149 | 108 | cfg_key = "neural_compressor_pt" |
150 | | - delete_lines_from_file(*content_position) |
151 | | - elif "tf" in sys.argv: |
| 109 | + |
| 110 | + if "tf" in sys.argv: |
152 | 111 | sys.argv.remove("tf") |
153 | 112 | cfg_key = "neural_compressor_tf" |
154 | | - delete_lines_from_file(*content_position) |
155 | | - else: |
156 | | - replace_lines_from_file(*content_position, backup_content) |
157 | 113 |
|
158 | 114 | project_name = PKG_INSTALL_CFG[cfg_key].get("project_name") |
159 | 115 | include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {} |
|
0 commit comments