Skip to content

Commit 13fb0a9

Browse files
Naman Goyal authored and facebook-github-bot committed
installing numpy headers for cython
Summary: Pull Request resolved: fairinternal/fairseq-py#848 Differential Revision: D17060283 fbshipit-source-id: c7e61cae76a0566cc3e2ddc3ab4d48f8dec9d777
1 parent 6548239 commit 13fb0a9

3 files changed

Lines changed: 57 additions & 12 deletions

File tree

fairseq/data/data_utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,6 @@
1515
import sys
1616
import types
1717

18-
from fairseq.data.data_utils_fast import batch_by_size_fast
19-
2018

2119
def infer_language_pair(path):
2220
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
@@ -200,6 +198,12 @@ def batch_by_size(
200198
required_batch_size_multiple (int, optional): require batch size to
201199
be a multiple of N (default: 1).
202200
"""
201+
try:
202+
from fairseq.data.data_utils_fast import batch_by_size_fast
203+
except ImportError:
204+
raise ImportError(
205+
'Please build Cython components with: `pip install --editable .`'
206+
)
203207
max_tokens = max_tokens if max_tokens is not None else sys.maxsize
204208
max_sentences = max_sentences if max_sentences is not None else sys.maxsize
205209
bsz_mult = required_batch_size_multiple

fairseq/data/token_block_dataset.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,6 @@
66
import numpy as np
77
import torch
88

9-
from fairseq.data.token_block_utils_fast import (
10-
_get_slice_indices_fast,
11-
_get_block_to_dataset_index_fast,
12-
)
13-
149
from fairseq.data import FairseqDataset, plasma_utils
1510

1611

@@ -47,6 +42,16 @@ def __init__(
4742
include_targets=False,
4843
document_sep_len=1,
4944
):
45+
try:
46+
from fairseq.data.token_block_utils_fast import (
47+
_get_slice_indices_fast,
48+
_get_block_to_dataset_index_fast,
49+
)
50+
except ImportError:
51+
raise ImportError(
52+
'Please build Cython components with: `pip install --editable .`'
53+
)
54+
5055
super().__init__()
5156
self.dataset = dataset
5257
self.pad = pad

setup.py

Lines changed: 41 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,12 @@
1515
readme = f.read()
1616

1717
if sys.platform == 'darwin':
18-
extra_compile_args = ['-stdlib=libc++']
18+
extra_compile_args = ['-stdlib=libc++', '-O3']
19+
extra_link_args = ['-stdlib=libc++']
1920
else:
20-
extra_compile_args = ['-std=c++11']
21+
extra_compile_args = ['-std=c++11', '-O3']
22+
extra_link_args = ['-std=c++11']
23+
2124
bleu = Extension(
2225
'fairseq.libbleu',
2326
sources=[
@@ -27,8 +30,39 @@
2730
extra_compile_args=extra_compile_args,
2831
)
2932

30-
token_block_utils = [Extension("fairseq.data.token_block_utils_fast", ["fairseq/data/token_block_utils_fast.pyx"])]
31-
data_utils_fast = [Extension("fairseq.data.data_utils_fast", ["fairseq/data/data_utils_fast.pyx"], language="c++")]
33+
34+
def get_cython_modules():
35+
token_block_utils = Extension(
36+
"fairseq.data.token_block_utils_fast",
37+
["fairseq/data/token_block_utils_fast.pyx"],
38+
extra_compile_args=extra_compile_args,
39+
extra_link_args=extra_link_args,
40+
)
41+
data_utils_fast = Extension(
42+
"fairseq.data.data_utils_fast",
43+
["fairseq/data/data_utils_fast.pyx"],
44+
language="c++",
45+
extra_compile_args=extra_compile_args,
46+
extra_link_args=extra_link_args,
47+
)
48+
return [token_block_utils, data_utils_fast]
49+
50+
51+
def my_build_ext(pars):
    """
    Delay loading of numpy headers.

    numpy is imported only when the ``build_ext`` command finalizes its
    options, not when this file is first evaluated.

    More details: https://stackoverflow.com/questions/54117786/add-numpy-get-include-argument-to-setuptools-without-preinstalled-numpy

    Args:
        pars: forwarded unchanged to the ``build_ext`` constructor.

    Returns:
        A ``build_ext`` command instance whose ``finalize_options`` appends
        ``numpy.get_include()`` to ``include_dirs``.
    """
    from setuptools.command.build_ext import build_ext as _build_ext

    class build_ext(_build_ext):
        def finalize_options(self):
            super().finalize_options()
            # Set the flag through the ``builtins`` module. ``__builtins__``
            # is a CPython implementation detail: it is a module only in
            # ``__main__`` and a plain dict elsewhere (e.g. when this file
            # is exec'd by a build front-end), where attribute assignment
            # would raise AttributeError.
            import builtins
            builtins.__NUMPY_SETUP__ = False
            # Imported here, not at module level, so that numpy is only
            # required once the build step actually runs.
            import numpy
            self.include_dirs.append(numpy.get_include())

    return build_ext(pars)
65+
3266

3367
setup(
3468
name='fairseq',
@@ -45,6 +79,7 @@
4579
long_description=readme,
4680
long_description_content_type='text/markdown',
4781
setup_requires=[
82+
'numpy',
4883
'cython',
4984
'setuptools>=18.0',
5085
],
@@ -58,7 +93,7 @@
5893
'tqdm',
5994
],
6095
packages=find_packages(exclude=['scripts', 'tests']),
61-
ext_modules=token_block_utils + data_utils_fast + [bleu],
96+
ext_modules=get_cython_modules() + [bleu],
6297
test_suite='tests',
6398
entry_points={
6499
'console_scripts': [
@@ -71,5 +106,6 @@
71106
'fairseq-validate = fairseq_cli.validate:cli_main',
72107
],
73108
},
109+
cmdclass={'build_ext': my_build_ext},
74110
zip_safe=False,
75111
)

0 commit comments

Comments
 (0)