Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion neural_compressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""Intel® Neural Compressor: An open-source Python library supporting popular model compression techniques."""
from .version import __version__

# we need to set a global 'NA' backend, or Model can't be used
from .config import (
DistillationConfig,
PostTrainingQuantConfig,
Expand Down
48 changes: 46 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,47 @@

from setuptools import find_packages, setup

# Remove 2x content in "__init__.py" when only 3x is installed and recover it when 2x is installed
content_position = ("neural_compressor/__init__.py", 20) # file path and line number
backup_content = """from .config import (
DistillationConfig,
PostTrainingQuantConfig,
WeightPruningConfig,
QuantizationAwareTrainingConfig,
MixedPrecisionConfig,
)
from .contrib import *
from .model import *
from .metric import *
from .utils import options
from .utils.utility import set_random_seed, set_tensorboard, set_workspace, set_resume_from
"""


def delete_lines_from_file(file_path, start_line):
"""Deletes all lines from the specified start_line to the end of the file."""
with open(file_path, "r") as file:
lines = file.readlines()

# Keep only lines before the start_line
lines = lines[: start_line - 1]

with open(file_path, "w") as file:
file.writelines(lines)


def replace_lines_from_file(file_path, start_line, replacement_content):
"""Replaces all lines from the specified start_line to the end of the file with replacement_content."""
with open(file_path, "r") as file:
lines = file.readlines()

# Keep lines before the start_line and append replacement_content
lines = lines[: start_line - 1]
lines.append(replacement_content)

with open(file_path, "w") as file:
file.writelines(lines)


def fetch_requirements(path):
with open(path, "r") as fd:
Expand Down Expand Up @@ -106,10 +147,13 @@ def get_build_version():
if "pt" in sys.argv:
sys.argv.remove("pt")
cfg_key = "neural_compressor_pt"

if "tf" in sys.argv:
delete_lines_from_file(*content_position)
elif "tf" in sys.argv:
sys.argv.remove("tf")
cfg_key = "neural_compressor_tf"
delete_lines_from_file(*content_position)
else:
replace_lines_from_file(*content_position, backup_content)

project_name = PKG_INSTALL_CFG[cfg_key].get("project_name")
include_packages = PKG_INSTALL_CFG[cfg_key].get("include_packages") or {}
Expand Down
Loading