Remove use of torch.set_default_tensor_type #674

Merged · 4 commits · May 24, 2022
9 changes: 7 additions & 2 deletions examples/transformers_tagger.py
@@ -132,7 +132,9 @@ def forward(
return TokensPlus(**token_data), lambda d_tokens: []

return Model(
"tokenizer", forward, attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
"tokenizer",
forward,
attrs={"tokenizer": AutoTokenizer.from_pretrained(name)},
)


@@ -166,11 +168,14 @@ def convert_transformer_outputs(model, inputs_outputs, is_train):

def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:
# Restore entries for bos and eos markers.
shim = model.shims[0]
row = model.ops.alloc2f(1, d_tokvecs[0].shape[1])
d_tokvecs = [model.ops.xp.vstack((row, arr, row)) for arr in d_tokvecs]
return ArgsKwargs(
args=(torch_tokvecs,),
kwargs={"grad_tensors": xp2torch(model.ops.pad(d_tokvecs))},
kwargs={
"grad_tensors": xp2torch(model.ops.pad(d_tokvecs, device=shim.device))
},
)

return tokvecs, backprop
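
The tagger change above resolves the target device from the wrapped shim and passes it to `xp2torch`, instead of relying on PyTorch's global default tensor type. A minimal sketch of the same pattern in isolation, assuming a thinc build that includes this PR and that PyTorch is installed (the array shapes are arbitrary):

```python
from thinc.api import get_current_ops
from thinc.util import get_torch_default_device, xp2torch

ops = get_current_ops()
device = get_torch_default_device()  # cuda:<id> under CupyOps, cpu otherwise

# Two variable-length gradient arrays, padded into one batch and converted
# to a torch tensor that is placed explicitly on the resolved device.
d_tokvecs = [ops.alloc2f(4, 8), ops.alloc2f(6, 8)]
grad_tensors = xp2torch(ops.pad(d_tokvecs), device=device)
print(grad_tensors.shape, grad_tensors.device)  # torch.Size([2, 6, 8]) cpu
```
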
3 changes: 1 addition & 2 deletions thinc/backends/__init__.py
@@ -10,7 +10,7 @@
from ._cupy_allocators import cupy_tensorflow_allocator, cupy_pytorch_allocator
from ._param_server import ParamServer
from ..util import assert_tensorflow_installed, assert_pytorch_installed
from ..util import is_cupy_array, set_torch_tensor_type_for_ops, require_cpu
from ..util import is_cupy_array, require_cpu
from .. import registry
from ..compat import cupy, has_cupy

@@ -134,7 +134,6 @@ def set_current_ops(ops: Ops) -> None:
"""Change the current backend object."""
context_ops.set(ops)
_get_thread_state().ops = ops
set_torch_tensor_type_for_ops(ops)


def contextvars_eq_thread_ops() -> bool:
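
With the `set_torch_tensor_type_for_ops` call removed, `set_current_ops` no longer mutates PyTorch's global default tensor type; placement is decided per call or per shim instead. A small sketch of the resulting behaviour, assuming thinc and PyTorch are installed:

```python
import torch
from thinc.api import get_ops, set_current_ops

# Switching the active Thinc ops no longer flips torch's global default
# tensor type, so plain torch factory calls keep their usual CPU placement.
set_current_ops(get_ops("cpu"))
x = torch.zeros(3)
print(x.device)  # cpu
```
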
7 changes: 5 additions & 2 deletions thinc/backends/_cupy_allocators.py
@@ -1,7 +1,7 @@
from typing import cast

from ..types import ArrayXd
from ..util import tensorflow2xp
from ..util import get_torch_default_device, tensorflow2xp
from ..compat import torch, cupy, tensorflow


@@ -23,6 +23,7 @@ def cupy_tensorflow_allocator(size_in_bytes: int):


def cupy_pytorch_allocator(size_in_bytes: int):
device = get_torch_default_device()
"""Function that can be passed into cupy.cuda.set_allocator, to have cupy
allocate memory via PyTorch. This is important when using the two libraries
together, as otherwise OOM errors can occur when there's available memory
@@ -34,7 +35,9 @@ def cupy_pytorch_allocator(size_in_bytes: int):
# creating a whole Tensor.
# This turns out to be way faster than making FloatStorage? Maybe
# a Python vs C++ thing I guess?
torch_tensor = torch.zeros((size_in_bytes // 4,), requires_grad=False)
torch_tensor = torch.zeros(
(size_in_bytes // 4,), requires_grad=False, device=device
)
# cupy has a neat class to help us here. Otherwise it will try to free.
# I think this is a private API? It's not in the types.
address = torch_tensor.data_ptr() # type: ignore
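
The allocator now creates its backing tensor on the resolved default device rather than depending on the global tensor type. Registering it is unchanged; a hedged usage sketch, which requires CuPy, PyTorch and a visible CUDA device:

```python
from thinc.api import require_gpu, set_gpu_allocator

require_gpu()                 # make CupyOps the current backend
set_gpu_allocator("pytorch")  # route cupy allocations through torch's memory pool
```
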
30 changes: 24 additions & 6 deletions thinc/layers/pytorchwrapper.py
@@ -1,9 +1,10 @@
from typing import Callable, Tuple, Optional, Any, cast

from ..compat import torch
from ..model import Model
from ..shims import PyTorchGradScaler, PyTorchShim
from ..config import registry
from ..util import is_xp_array, is_torch_array
from ..util import is_xp_array, is_torch_array, partial
from ..util import xp2torch, torch2xp, convert_recursive
from ..types import Floats3d, ArgsKwargs, Padded

@@ -76,6 +77,7 @@ def PyTorchWrapper_v2(
convert_outputs: Optional[Callable] = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
) -> Model[Any, Any]:
"""Wrap a PyTorch model, so that it has the same API as Thinc models.
To optimize the model, you'll need to create a PyTorch optimizer and call
@@ -105,6 +107,10 @@ def PyTorchWrapper_v2(
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""
if convert_inputs is None:
convert_inputs = convert_pytorch_default_inputs
@@ -116,7 +122,10 @@
attrs={"convert_inputs": convert_inputs, "convert_outputs": convert_outputs},
shims=[
PyTorchShim(
pytorch_model, mixed_precision=mixed_precision, grad_scaler=grad_scaler
pytorch_model,
mixed_precision=mixed_precision,
grad_scaler=grad_scaler,
device=device,
)
],
dims={"nI": None, "nO": None},
@@ -149,7 +158,8 @@ def backprop(dY: Any) -> Any:
def convert_pytorch_default_inputs(
model: Model, X: Any, is_train: bool
) -> Tuple[ArgsKwargs, Callable[[ArgsKwargs], Any]]:
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train)
shim = cast(PyTorchShim, model.shims[0])
xp2torch_ = lambda x: xp2torch(x, requires_grad=is_train, device=shim.device)
converted = convert_recursive(is_xp_array, xp2torch_, X)
if isinstance(converted, ArgsKwargs):

@@ -181,11 +191,14 @@ def reverse_conversion(dXtorch):


def convert_pytorch_default_outputs(model: Model, X_Ytorch: Any, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
X, Ytorch = X_Ytorch
Y = convert_recursive(is_torch_array, torch2xp, Ytorch)

def reverse_conversion(dY: Any) -> ArgsKwargs:
dYtorch = convert_recursive(is_xp_array, xp2torch, dY)
dYtorch = convert_recursive(
is_xp_array, partial(xp2torch, device=shim.device), dY
)
return ArgsKwargs(args=((Ytorch,),), kwargs={"grad_tensors": dYtorch})

return Y, reverse_conversion
@@ -195,6 +208,7 @@ def reverse_conversion(dY: Any) -> ArgsKwargs:


def convert_rnn_inputs(model: Model, Xp: Padded, is_train: bool):
shim = cast(PyTorchShim, model.shims[0])
size_at_t = Xp.size_at_t
lengths = Xp.lengths
indices = Xp.indices
@@ -203,15 +217,19 @@ def convert_from_torch_backward(d_inputs: ArgsKwargs) -> Padded:
dX = torch2xp(d_inputs.args[0])
return Padded(dX, size_at_t, lengths, indices) # type: ignore

output = ArgsKwargs(args=(xp2torch(Xp.data, requires_grad=True), None), kwargs={})
output = ArgsKwargs(
args=(xp2torch(Xp.data, requires_grad=True, device=shim.device), None),
kwargs={},
)
return output, convert_from_torch_backward


def convert_rnn_outputs(model: Model, inputs_outputs: Tuple, is_train):
shim = cast(PyTorchShim, model.shims[0])
Xp, (Ytorch, _) = inputs_outputs

def convert_for_torch_backward(dYp: Padded) -> ArgsKwargs:
dYtorch = xp2torch(dYp.data, requires_grad=True)
dYtorch = xp2torch(dYp.data, requires_grad=True, device=shim.device)
return ArgsKwargs(args=(Ytorch,), kwargs={"grad_tensors": dYtorch})

Y = cast(Floats3d, torch2xp(Ytorch))
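
`PyTorchWrapper_v2` now accepts an optional `device` argument that is forwarded to the shim; leaving it as `None` makes the device follow the currently active Thinc ops. A short usage sketch, assuming a thinc build that includes this change:

```python
import torch
from thinc.api import PyTorchWrapper_v2

# Pin the wrapped module to an explicit device; passing None (or omitting the
# argument) falls back to the default derived from the current Thinc ops.
model = PyTorchWrapper_v2(torch.nn.Linear(64, 32), device=torch.device("cpu"))
```
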
40 changes: 29 additions & 11 deletions thinc/shims/pytorch.py
@@ -5,6 +5,7 @@
import srsly

from ..util import torch2xp, xp2torch, convert_recursive, iterate_recursive
from ..util import get_torch_default_device
from ..compat import torch
from ..backends import get_current_ops, context_pools, CupyOps
from ..backends import set_gpu_allocator
@@ -25,6 +26,10 @@ class PyTorchShim(Shim):
The gradient scaler to use for mixed-precision training. If this
argument is set to "None" and mixed precision is enabled, a gradient
scaler with the default configuration is used.
device:
The PyTorch device to run the model on. When this argument is
set to "None", the default device for the currently active Thinc
ops is used.
"""

def __init__(
@@ -34,12 +39,20 @@ def __init__(
optimizer: Any = None,
mixed_precision: bool = False,
grad_scaler: Optional[PyTorchGradScaler] = None,
device: Optional["torch.device"] = None,
):
super().__init__(model, config, optimizer)

if device is None:
device = get_torch_default_device()
if model is not None:
model.to(device)

if grad_scaler is None:
grad_scaler = PyTorchGradScaler(mixed_precision)

grad_scaler.to_(device)

self._grad_scaler = grad_scaler

self._mixed_precision = mixed_precision
@@ -58,6 +71,14 @@ def __call__(self, inputs, is_train):
else:
return self.predict(inputs), lambda a: ...

@property
def device(self):
p = next(self._model.parameters(), None)
if p is None:
return get_torch_default_device()
else:
return p.device

def predict(self, inputs: ArgsKwargs) -> Any:
"""Pass inputs through to the underlying PyTorch model, and return the
output. No conversions are performed. The PyTorch model is set into
@@ -126,7 +147,9 @@ def finish_update(self, optimizer: Optimizer):
cast(FloatsXd, torch2xp(torch_data.data)),
cast(FloatsXd, torch2xp(torch_data.grad)),
)
torch_data.data = xp2torch(param, requires_grad=True)
torch_data.data = xp2torch(
param, requires_grad=True, device=torch_data.device
)
torch_data.grad.zero_()

self._grad_scaler.update()
@@ -137,7 +160,7 @@ def use_params(self, params):
state_dict = {}
for k, v in params.items():
if hasattr(k, "startswith") and k.startswith(key_prefix):
state_dict[k.replace(key_prefix, "")] = xp2torch(v)
state_dict[k.replace(key_prefix, "")] = xp2torch(v, device=self.device)
if state_dict:
backup = {k: v.clone() for k, v in self._model.state_dict().items()}
self._model.load_state_dict(state_dict)
@@ -164,17 +187,12 @@ def to_bytes(self):
return srsly.msgpack_dumps(msg)

def from_bytes(self, bytes_data):
ops = get_current_ops()
device = get_torch_default_device()
msg = srsly.msgpack_loads(bytes_data)
self.cfg = msg["config"]
filelike = BytesIO(msg["state"])
filelike.seek(0)
if ops.device_type == "cpu":
map_location = "cpu"
else: # pragma: no cover
device_id = torch.cuda.current_device()
map_location = "cuda:%d" % device_id
self._model.load_state_dict(torch.load(filelike, map_location=map_location))
self._model.to(map_location)
self._grad_scaler.to_(map_location)
self._model.load_state_dict(torch.load(filelike, map_location=device))
self._model.to(device)
self._grad_scaler.to_(device)
return self
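
The new `device` property reads the device from the wrapped module's first parameter and falls back to the Thinc-derived default for parameterless modules. A standalone sketch of that resolution logic (`resolve_device` is a hypothetical helper mirroring the property, not part of thinc):

```python
import torch
from thinc.util import get_torch_default_device

def resolve_device(module: torch.nn.Module) -> torch.device:
    # Mirrors PyTorchShim.device above: take the first parameter's device,
    # falling back to the Thinc-derived default for parameterless modules.
    param = next(module.parameters(), None)
    if param is None:
        return get_torch_default_device()
    return param.device

print(resolve_device(torch.nn.Linear(4, 4)))  # cpu, unless the module was moved
print(resolve_device(torch.nn.ReLU()))        # falls back to the default device
```
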
47 changes: 27 additions & 20 deletions thinc/util.py
@@ -27,6 +27,21 @@
from .api import Ops


def get_torch_default_device() -> "torch.device":
if torch is None:
raise ValueError("Cannot get default Torch device when Torch is not available.")

from .backends import get_current_ops
from .backends.cupy_ops import CupyOps

ops = get_current_ops()
if isinstance(ops, CupyOps):
device_id = torch.cuda.current_device()
return torch.device(f"cuda:{device_id}")

return torch.device("cpu")


def get_array_module(arr): # pragma: no cover
if is_cupy_array(arr):
return cupy
@@ -133,7 +148,6 @@ def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device": # pragma: no cover

if has_torch_gpu:
torch.cuda.set_device(gpu_id)
torch.set_default_tensor_type("torch.cuda.FloatTensor")

return device

@@ -144,7 +158,6 @@ def require_cpu() -> bool: # pragma: no cover

ops = get_ops("cpu")
set_current_ops(ops)
set_torch_tensor_type_for_ops(ops)

return True

@@ -309,17 +322,27 @@ def iterate_recursive(is_match: Callable[[Any], bool], obj: Any) -> Any:


def xp2torch(
xp_tensor: ArrayXd, requires_grad: bool = False
xp_tensor: ArrayXd,
requires_grad: bool = False,
device: Optional["torch.device"] = None,
) -> "torch.Tensor": # pragma: no cover
"""Convert a numpy or cupy tensor to a PyTorch tensor."""
assert_pytorch_installed()

if device is None:
device = get_torch_default_device()

if hasattr(xp_tensor, "toDlpack"):
dlpack_tensor = xp_tensor.toDlpack() # type: ignore
torch_tensor = torch.utils.dlpack.from_dlpack(dlpack_tensor)
else:
torch_tensor = torch.from_numpy(xp_tensor)

torch_tensor = torch_tensor.to(device)

if requires_grad:
torch_tensor.requires_grad_()

return torch_tensor


@@ -529,22 +552,6 @@ def use_nvtx_range(message: str, id_color: int = -1):
yield


def set_torch_tensor_type_for_ops(ops):
"""Set the PyTorch default tensor type for the given ops. This is a
no-op if PyTorch is not available."""
from .backends.cupy_ops import CupyOps

try:
import torch

if CupyOps.xp is not None and isinstance(ops, CupyOps):
torch.set_default_tensor_type("torch.cuda.FloatTensor")
else:
torch.set_default_tensor_type("torch.FloatTensor")
except ImportError:
pass


@dataclass
class ArrayInfo:
"""Container for info for checking array compatibility."""
@@ -569,6 +576,7 @@ def check_consistency(self, arr: ArrayXd):

__all__ = [
"get_array_module",
"get_torch_default_device",
"fix_random_seed",
"is_cupy_array",
"is_numpy_array",
@@ -586,6 +594,5 @@ def check_consistency(self, arr: ArrayXd):
"DataValidationError",
"make_tempfile",
"use_nvtx_range",
"set_torch_tensor_type_for_ops",
"ArrayInfo",
]
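
Together, `get_torch_default_device` and the `device` argument on `xp2torch` replace the removed `set_torch_tensor_type_for_ops` switch: the target device is derived from the current ops and passed explicitly at each conversion. A CPU-only end-to-end sketch, assuming PyTorch is installed:

```python
from thinc.api import get_current_ops, require_cpu
from thinc.util import get_torch_default_device, xp2torch

require_cpu()                          # current ops -> NumpyOps
device = get_torch_default_device()    # torch.device("cpu"); cuda:<id> under CupyOps

ops = get_current_ops()
t = xp2torch(ops.alloc2f(2, 3), requires_grad=True, device=device)
print(t.device, t.requires_grad)       # cpu True
```
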