Skip to content

Commit f320844

Browse files
authored
Merge pull request #1 from pytorch/asuhan/initial_import
Initial import as a separate torch_xla extension
2 parents f1a5910 + d03f9b6 commit f320844

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+6879
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "third_party/tensorflow"]
2+
path = third_party/tensorflow
3+
url = https://github.com/tensorflow/tensorflow.git

LICENSE

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
Copyright (c) 2018 Google Inc.
2+
All rights reserved.
3+
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are met:
6+
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
10+
2. Redistributions in binary form must reproduce the above copyright
11+
notice, this list of conditions and the following disclaimer in the
12+
documentation and/or other materials provided with the distribution.
13+
14+
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
15+
and IDIAP Research Institute nor the names of its contributors may be
16+
used to endorse or promote products derived from this software without
17+
specific prior written permission.
18+
19+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
23+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29+
POSSIBILITY OF SUCH DAMAGE.

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# How To Build And Run PyTorch For TPU
2+
3+
To build:
4+
5+
* Build PyTorch from source, following the regular [instructions](https://github.com/pytorch/pytorch#from-source).
6+
* Clone this repository in the root folder of the PyTorch sources used for the previous step.
7+
Run `git submodule update --init` to get the third-party dependencies and `python setup.py install` to build and install the extension.
8+
9+
To run the tests, follow __one__ of the options below:
10+
11+
* Run on CPU using the local client:
12+
13+
`export XLA_USE_XRT=0 export XLA_GRPC_HOST="" XLA_PLATFORM="CPU"`
14+
15+
* Run on CPU using the XRT client:
16+
17+
`export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localhost/replica:0/task:0/device:XLA_CPU:0" XRT_WORKERS="localhost:0;"`
18+
19+
* Run on TPU using the XRT client:
20+
21+
`export XLA_USE_XRT=1 XRT_DEVICE_MAP="TPU:0;/job:tpu_worker/replica:0/task:0/device:TPU:0" XRT_WORKERS="tpu_worker:0;grpc://localhost:51000"`. Specify the TPU node by doing __one__ of the following:
22+
23+
- create a `$HOME/.pytorch_tpu.conf` file with the following content: `worker: tpu_worker <ip of the tpu node>:8470`
24+
25+
- set the `XRT_TPU_CONFIG` environment variable: `export XRT_TPU_CONFIG="tpu_worker;0;<ip of the tpu node>:8470"`.
26+
27+
28+
29+
Then run `python test/test_operations.py`. Some of the tests are currently skipped.

build_torch_xla_libs.sh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/usr/bin/env bash
2+
3+
set -ex
4+
5+
cd "$(dirname "$0")"
6+
PWD=`printf "%q\n" "$(pwd)"`
7+
BASE_DIR="$PWD"
8+
echo $BASE_DIR
9+
THIRD_PARTY_DIR="$BASE_DIR/third_party"
10+
11+
cp -r -f $THIRD_PARTY_DIR/xla_client $THIRD_PARTY_DIR/tensorflow/tensorflow/compiler/xla/
12+
13+
pushd $THIRD_PARTY_DIR/tensorflow
14+
git reset --hard
15+
git clean -f
16+
bazel build -c opt //tensorflow/compiler/xla/xla_client:libxla_computation_client.so
17+
popd
18+
19+
mkdir -p torch_xla/lib
20+
chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so
21+
cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/compiler/xla/xla_client/libxla_computation_client.so torch_xla/lib
22+
chmod 0644 $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so
23+
cp $THIRD_PARTY_DIR/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so torch_xla/lib

setup.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python
2+
3+
from setuptools import setup, find_packages
4+
from torch.utils.cpp_extension import BuildExtension, CppExtension
5+
import os
6+
import platform
7+
import subprocess
8+
import sys
9+
10+
11+
def _check_env_flag(name, default=''):
12+
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
13+
14+
torch_xla_sources = [
15+
'torch_xla/csrc/batch_norm.cpp',
16+
'torch_xla/csrc/convolution.cpp',
17+
'torch_xla/csrc/cross_replica_reduces.cpp',
18+
'torch_xla/csrc/data_ops.cpp',
19+
'torch_xla/csrc/elementwise.cpp',
20+
'torch_xla/csrc/graph_context.cpp',
21+
'torch_xla/csrc/helpers.cpp',
22+
'torch_xla/csrc/init_python_bindings.cpp',
23+
'torch_xla/csrc/log_softmax.cpp',
24+
'torch_xla/csrc/module.cpp',
25+
'torch_xla/csrc/nll_loss.cpp',
26+
'torch_xla/csrc/pooling.cpp',
27+
'torch_xla/csrc/reduction.cpp',
28+
'torch_xla/csrc/tensor.cpp',
29+
'torch_xla/csrc/torch_util.cpp',
30+
'torch_xla/csrc/translator.cpp',
31+
'torch_xla/csrc/passes/eval_static_size.cpp',
32+
'torch_xla/csrc/passes/remove_unused_forward_outputs.cpp',
33+
'torch_xla/csrc/passes/replace_untraced_operators.cpp',
34+
'torch_xla/csrc/passes/threshold_backward_peephole.cpp',
35+
]
36+
37+
build_libs_cmd = './build_torch_xla_libs.sh'
38+
39+
if subprocess.call(build_libs_cmd) != 0:
40+
print("Failed to run '{}'".format(build_libs_cmd))
41+
sys.exit(1)
42+
43+
# Constant known variables used throughout this file
44+
cwd = os.path.dirname(os.path.abspath(__file__))
45+
lib_path = os.path.join(cwd, 'torch_xla', 'lib')
46+
pytorch_source_path = os.getenv('PYTORCH_SOURCE_PATH', '..')
47+
third_party_path = os.path.join(cwd, 'third_party')
48+
49+
include_dirs = [
50+
third_party_path + '/tensorflow/bazel-tensorflow',
51+
third_party_path + '/tensorflow/bazel-genfiles',
52+
third_party_path + '/tensorflow/bazel-tensorflow/external/protobuf_archive/src',
53+
third_party_path + '/tensorflow/bazel-tensorflow/external/eigen_archive',
54+
third_party_path + '/tensorflow/bazel-tensorflow/external/com_google_absl',
55+
]
56+
include_dirs += [
57+
pytorch_source_path,
58+
os.path.join(pytorch_source_path, 'torch', 'csrc'),
59+
os.path.join(pytorch_source_path, 'torch', 'lib', 'tmp_install', 'include'),
60+
]
61+
62+
library_dirs = []
63+
library_dirs.append(lib_path)
64+
65+
extra_link_args = []
66+
67+
DEBUG = _check_env_flag('DEBUG')
68+
IS_WINDOWS = (platform.system() == 'Windows')
69+
IS_DARWIN = (platform.system() == 'Darwin')
70+
IS_LINUX = (platform.system() == 'Linux')
71+
72+
73+
def make_relative_rpath(path):
74+
if IS_DARWIN:
75+
return '-Wl,-rpath,@loader_path/' + path
76+
elif IS_WINDOWS:
77+
return ''
78+
else:
79+
return '-Wl,-rpath,$ORIGIN/' + path
80+
81+
extra_compile_args = []
82+
83+
if DEBUG:
84+
if IS_WINDOWS:
85+
extra_link_args.append('/DEBUG:FULL')
86+
else:
87+
extra_compile_args += ['-O0', '-g']
88+
extra_link_args += ['-O0', '-g']
89+
90+
extra_link_args += ['-lxla_computation_client']
91+
92+
setup(
93+
name='torch_xla',
94+
version='0.1',
95+
description='XLA bridge for PyTorch',
96+
url='https://github.com/pytorch/xla',
97+
author='Alex Suhan, Davide Libenzi',
98+
99+
# Exclude the build files.
100+
packages=find_packages(exclude=['build']),
101+
ext_modules=[
102+
CppExtension(
103+
'_C',
104+
torch_xla_sources,
105+
include_dirs=include_dirs,
106+
extra_compile_args=extra_compile_args,
107+
library_dirs=library_dirs,
108+
extra_link_args=extra_link_args + [make_relative_rpath('torch_xla/lib')],
109+
),
110+
],
111+
package_data={
112+
'torch_xla': [
113+
'lib/*.so*',
114+
]
115+
},
116+
cmdclass={'build_ext': BuildExtension})

0 commit comments

Comments
 (0)