-
Notifications
You must be signed in to change notification settings - Fork 543
Initial import as a separate torch_xla extension #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "third_party/tensorflow"] | ||
path = third_party/tensorflow | ||
url = https://github.com/tensorflow/tensorflow.git |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
Copyright (c) 2018 Google Inc. | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright | ||
notice, this list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright | ||
notice, this list of conditions and the following disclaimer in the | ||
documentation and/or other materials provided with the distribution. | ||
|
||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America | ||
and IDIAP Research Institute nor the names of its contributors may be | ||
used to endorse or promote products derived from this software without | ||
specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | ||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||
POSSIBILITY OF SUCH DAMAGE. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# How To Build And Run PyTorch For TPU | ||
|
||
To build: | ||
|
||
* Build PyTorch from source, following the regular [instructions](https://github.com/pytorch/pytorch#from-source). | ||
* Clone this repository in the root folder of the PyTorch sources used for the previous step. | ||
Run `git submodule update --init` to get the third-party dependencies and `python setup.py install` to build and install the extension. | ||
|
||
To run the tests, follow __one__ of the options below: | ||
|
||
* Run on CPU using the local client: | ||
|
||
`export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU"` | ||
|
||
* Run on CPU using the XRT client: | ||
|
||
`export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localhost/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localhost:0;"` | ||
|
||
* Run on TPU using the XRT client: | ||
|
||
`export XLA_USE_XRT=1 XRT_DEVICE_MAP="TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0" XRT_WORKERS="tpu_worker:0;grpc://localhost:51000"`. Specify the TPU node by doing __one__ of the following: | ||
|
||
- create a `$HOME/.pytorch_tpu.conf` file with the following content: `worker: tpu_worker <ip of the tpu node>:8470` | ||
|
||
- set the `XRT_TPU_CONFIG` environment variable: `export XRT_TPU_CONFIG="tpu_worker;0;<ip of the tpu node>:8470"`. | ||
|
||
|
||
|
||
Then run `python test/test_operations.py`. Some of the tests are currently skipped. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -ex | ||
|
||
cd "$(dirname "$0")" | ||
PWD=`printf "%q\n" "$(pwd)"` | ||
BASE_DIR="$PWD" | ||
echo $BASE_DIR | ||
THIRD_PARTY_DIR="$BASE_DIR/third_party" | ||
|
||
cp -r -f $THIRD_PARTY_DIR/xla_client $THIRD_PARTY_DIR/tensorflow/tensorflow/compiler/xla/ | ||
|
||
pushd $THIRD_PARTY_DIR/tensorflow | ||
git reset --hard | ||
git clean -f | ||
bazel build -c opt //tensorflow/compiler/xla/xla_client:libxla_computation_client.so | ||
popd | ||
|
||
mkdir -p torch_xla/lib | ||
chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so | ||
cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so torch_xla/lib | ||
chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so | ||
cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so torch_xla/lib |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#!/usr/bin/env python | ||
|
||
from setuptools import setup, find_packages | ||
from torch.utils.cpp_extension import BuildExtension, CppExtension | ||
import os | ||
import platform | ||
import subprocess | ||
import sys | ||
|
||
|
||
def _check_env_flag(name, default=''): | ||
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] | ||
|
||
torch_xla_sources = [ | ||
'torch_xla/csrc/batch_norm.cpp', | ||
'torch_xla/csrc/convolution.cpp', | ||
'torch_xla/csrc/cross_replica_reduces.cpp', | ||
'torch_xla/csrc/data_ops.cpp', | ||
'torch_xla/csrc/elementwise.cpp', | ||
'torch_xla/csrc/graph_context.cpp', | ||
'torch_xla/csrc/helpers.cpp', | ||
'torch_xla/csrc/init_python_bindings.cpp', | ||
'torch_xla/csrc/log_softmax.cpp', | ||
'torch_xla/csrc/module.cpp', | ||
'torch_xla/csrc/nll_loss.cpp', | ||
'torch_xla/csrc/pooling.cpp', | ||
'torch_xla/csrc/reduction.cpp', | ||
'torch_xla/csrc/tensor.cpp', | ||
'torch_xla/csrc/torch_util.cpp', | ||
'torch_xla/csrc/translator.cpp', | ||
'torch_xla/csrc/passes/eval_static_size.cpp', | ||
'torch_xla/csrc/passes/remove_unused_forward_outputs.cpp', | ||
'torch_xla/csrc/passes/replace_untraced_operators.cpp', | ||
'torch_xla/csrc/passes/threshold_backward_peephole.cpp', | ||
] | ||
|
||
build_libs_cmd = './build_torch_xla_libs.sh' | ||
|
||
if subprocess.call(build_libs_cmd) != 0: | ||
print("Failed to run '{}'".format(build_libs_cmd)) | ||
sys.exit(1) | ||
|
||
# Constant known variables used throughout this file | ||
cwd = os.path.dirname(os.path.abspath(__file__)) | ||
lib_path = os.path.join(cwd, 'torch_xla', 'lib') | ||
pytorch_source_path = os.getenv('PYTORCH_SOURCE_PATH', '..') | ||
third_party_path = os.path.join(cwd, 'third_party') | ||
|
||
include_dirs = [ | ||
third_party_path + '/tensorflow/bazel-tensorflow', | ||
third_party_path + '/tensorflow/bazel-genfiles', | ||
third_party_path + '/tensorflow/bazel-tensorflow/external/protobuf_archive/src', | ||
third_party_path + '/tensorflow/bazel-tensorflow/external/eigen_archive', | ||
third_party_path + '/tensorflow/bazel-tensorflow/external/com_google_absl', | ||
] | ||
include_dirs += [ | ||
pytorch_source_path, | ||
os.path.join(pytorch_source_path, 'torch', 'csrc'), | ||
os.path.join(pytorch_source_path, 'torch', 'lib', 'tmp_install', 'include'), | ||
] | ||
|
||
library_dirs = [] | ||
library_dirs.append(lib_path) | ||
|
||
extra_link_args = [] | ||
|
||
DEBUG = _check_env_flag('DEBUG') | ||
IS_WINDOWS = (platform.system() == 'Windows') | ||
IS_DARWIN = (platform.system() == 'Darwin') | ||
IS_LINUX = (platform.system() == 'Linux') | ||
|
||
|
||
def make_relative_rpath(path): | ||
if IS_DARWIN: | ||
return '-Wl,-rpath,@loader_path/' + path | ||
elif IS_WINDOWS: | ||
return '' | ||
else: | ||
return '-Wl,-rpath,$ORIGIN/' + path | ||
asuhan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
extra_compile_args = [] | ||
|
||
if DEBUG: | ||
if IS_WINDOWS: | ||
extra_link_args.append('/DEBUG:FULL') | ||
else: | ||
extra_compile_args += ['-O0', '-g'] | ||
extra_link_args += ['-O0', '-g'] | ||
|
||
extra_link_args += ['-lxla_computation_client'] | ||
|
||
setup( | ||
name='torch_xla', | ||
version='0.1', | ||
description='XLA bridge for PyTorch', | ||
url='https://github.com/pytorch/xla', | ||
author='Alex Suhan, Davide Libenzi', | ||
author_email='[email protected], [email protected]' | ||
# Exclude the build files. | ||
packages=find_packages(exclude=['build']), | ||
ext_modules=[ | ||
CppExtension( | ||
'_C', | ||
torch_xla_sources, | ||
include_dirs=include_dirs, | ||
extra_compile_args=extra_compile_args, | ||
library_dirs=library_dirs, | ||
extra_link_args=extra_link_args + [make_relative_rpath('torch_xla/lib')], | ||
), | ||
], | ||
package_data={ | ||
'torch_xla': [ | ||
'lib/*.so*', | ||
] | ||
}, | ||
cmdclass={'build_ext': BuildExtension}) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.