Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54034Fixes#53544
I had to touch a bunch of lines but the refactoring was fairly
mechanical. Here's how it works.
The basic concept behind this PR is that tensor_new.cpp was previously
abusing DispatchKey when it actually meant TensorOptions. The provided
DispatchKey argument to most of the constructor functions typically
comes from torch::tensors::get_default_dispatch_key(); it doesn't
really make sense for people to set the default dispatch key, but
this got grandfathered in due to the old API set_default_tensor_type
(where the "Type" concept got refactored into "DispatchKey" concept
over time). See also #53124. But the upshot is that, semantically,
what we refer to as the default dispatch key really is more like
torch.set_default_tensor_type(torch.Tensor) versus
torch.set_default_tensor_type(torch.cuda.Tensor): clearly the user
wants to do something about *construction* of the tensor, and
TensorOptions captures that exactly.
So, how exactly to translate from one to the other?
- Sources (things that used to PRODUCE DispatchKey)
- Most top level functions take a DispatchKey as their argument. I
use the new function dispatchKeyToTensorOptions to convert it into
a TensorOptions
- typeIdWithDefault now produces a TensorOptions (probably could do
with a rename, though I didn't)
- Sinks (things that used to CONSUME DispatchKey)
- Previously, the function options() was typically used to convert the
DispatchKey into a TensorOptions. Now its replacement build_options
just takes a TensorOptions and sets some extra fields on it.
Irritatingly, I can't just replace
`build_options(options, scalar_type, device)` with
`options.dtype(scalar_type).device(device)` because the semantics
are slightly different: if device is nullopt, we should preserve
the usage of the device specified in options (what options.device()
does is overwrite the device unconditionally; e.g., if device is
nullopt, unset device from options)
- The other major sink for DispatchKey was `internal_new_from_data`,
but it turns out it only really extracts the device type from
the dispatch key. Now it just pulls out the device from
TensorOptions.
- To actually do the translation of DispatchKey to TensorOptions, I
introduce new functions dispatchKeyToLayout (replicating
layout_from_backend--there are still a few uses of this function
so I couldn't delete it) and dispatchKeyToDeviceType (replacing
computeDeviceType)
- In all internal functions, whenever DispatchKey is taken as an argument,
I instead take TensorOptions as an argument, and pass it along.
- Anywhere `legacyExtractDispatchKey(other.key_set())` equality was
previously used, I now do `other.options().type_equal()`, which
is the intended BC for doing "backend to backend" comparisons
- There are a few places in the sparse constructors where we allocated
a tensor for values, and then read out the dispatch key from the
result to allocate the keys. As best as I can tell, this is totally
equivalent to just passing in the options to both values and indices
(the only difference is dtype, which is captured via a separate
argument)
This refactor doesn't really go far enough: for example, there are now
functions that take both TensorOptions and ScalarType, when really
the TensorOptions can capture this all. I kept it solely just
s/DispatchKey/TensorOptions/ to reduce the number of possible bugs;
also, a lot of this will be mooted by a proper fix to #53124.
Even with this limited refactor, the payoff is sweet. I can delete:
- backendToCPU
- backendToXPU
- backendToCUDA
- backendToHIP
- backendToBackendOfDeviceType
The reason I can do this is because I can simply overwrite layout in TensorOptions
to do the conversion, rather than having to type out each backend case
explicitly.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27109509
Pulled By: ezyang
fbshipit-source-id: 91d16cfbc390127770362ac04fb43f7e070077e9
Summary:
According to pytorch/rfcs#3
From the goals in the RFC:
1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor`
subclasses (done in https://github.com/pytorch/pytorch/issues/30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here)
5. Propagating subclass instances correctly also with operators, using
views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing. (done here)
7. A way to insert code that operates on both functions and methods uniformly
(so we can write a single function that overrides all operators). (done here)
8. The ability to give external libraries a way to also define
functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR)
This PR makes the following changes:
1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.
TODO:
- [x] Sequence Methods
- [x] Docs
- [x] Tests
Closes https://github.com/pytorch/pytorch/issues/28361
Benchmarks in https://github.com/pytorch/pytorch/pull/37091#issuecomment-633657778
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37091
Reviewed By: ngimel
Differential Revision: D22765678
Pulled By: ezyang
fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
Summary:
The init-list form of `at::indexing::Slice` (i.e. `tensor.index({{1, None, 2}, ...})` instead of `tensor.index({Slice(1, None, 2), ...})`) in C++ API can be easily confused with the list-form indexing in Python API (e.g. `tensor[[1, 3, 2], ...]`), which is not good from readability perspective. This PR removes the init-list form of `at::indexing::Slice` to make the API less confusing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34255
Test Plan: Imported from GitHub, without a `Test Plan:` line.
Differential Revision: D20290166
Pulled By: yf225
fbshipit-source-id: abbcbeca0b179219e5e1f196a33ef8aec87ebb76
Summary:
This PR adds the following items:
- **1st item**: `ArrayRef<TensorIndex>` and `std::initializer_list<TensorIndex>` overloads for `Tensor::index` and `Tensor::index_put_`, to be used specifically for multi-dim indexing purpose.
Design rationale:
* C++ `Tensor::index` and `Tensor::index_put_` are both existing tensor APIs, and they currently (before this PR) only accept a list of tensors (i.e. `ArrayRef<Tensor>`) as indices. If we change their signatures to also accept non-tensors as indices (i.e. `ArrayRef<TensorIndex>`, and `TensorIndex` is convertible from `Tensor` / `Slice` / `None` / `Ellipsis`), it would slow down the original code path (since now it has to go through more steps), which is undesirable.
To get around this problem, the proposed solution is to keep the original `ArrayRef<Tensor>` overload, and add `ArrayRef<TensorIndex>` and `std::initializer_list<TensorIndex>` overloads to `Tensor::index` and `Tensor::index_put_`. This way, the original code path won’t be affected, and the tensor multi-dim indexing API is only used when the user explicitly pass an `ArrayRef<TensorIndex>` or a braced-init-list of `TensorIndex`-convertible types to `Tensor::index` and `Tensor::index_put_` .
Note that the above proposed solution would still affect perf for the user’s original `Tensor::index` or `Tensor::index_put_` call sites that use a braced-init-list of tensors as input, e.g. `tensor.index({...})` or `tensor.index_put_({...}, value)`, since now such function calls would take the multi-dim indexing path instead of the original advanced indexing path. However, there are only two instances of this in our codebase (one in ATen cpp test, one in a C++ API nn init function), and they can be easily changed to explicitly use `ArrayRef<Tensor>` as input (I changed them in this PR). For external user’s code, since this is part of the C++ frontend which is still considered experimental, we will only talk about this change in the release note, and ask users to switch to using `ArrayRef<Tensor>` explicitly if they want to keep using the original advanced indexing code path.
- **2nd item**: Mechanisms for parsing `ArrayRef<TensorIndex>` indices and performing indexing operations (mirroring the functions in `torch/csrc/autograd/python_variable_indexing.cpp`).
- **3rd item**: Simple tests to demonstrate that the `Tensor::index()` and `Tensor::index_put_()` APIs work. I will add more tests after the first few PRs are reviewed.
- **4th item**: Merge Python/C++ indexing code paths, for code simplicity. I tested locally and found that there is no perf regression resulting from the merge. I will get more concrete numbers for common use cases when we settle on the overall design.
This PR supersedes https://github.com/pytorch/pytorch/pull/30425.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32841
Differential Revision: D19919692
Pulled By: yf225
fbshipit-source-id: 7467e64f97fc0e407624809dd183c95ea16b1482
Summary:
Fixes https://github.com/pytorch/pytorch/issues/29161.
I looked a bit at the code changes related to this and think I have all of the use cases of `DeprecatedTypeProperties` covered in the message, but suggestions from someone with more context on this would be very much appreciated :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30281
Differential Revision: D18830818
Pulled By: ezyang
fbshipit-source-id: 1a7fcee15354ae09e6644577e7fa33bd26acfe20
Summary:
Given that pybind11 implements these gil functions, I don't think it makes sense for Pytorch to have its own bespoke versions.
Fixes https://github.com/pytorch/pytorch/issues/29065
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29095
Differential Revision: D18301806
Pulled By: ezyang
fbshipit-source-id: 03da6a26c41ee65aaadf7b67b9f0b14d2def2a5a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29203
There is no more Variable/Tensor distinction, so fix the misleading name.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D18353505
Pulled By: ezyang
fbshipit-source-id: dadc394d533ab7746f70bc186c6645441a784518
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28620
All Tensors are Variables now, they just happen to have requires_grad=False. Tensors ALWAYS have `VariableTensorId` in their type set.
When constructing this patch, I had to make decisions about what I would fix in this patch, and what I would leave for follow up PRs. Here is the cleanup that happens in this patch:
- The `is_variable` property is removed from TensorOptions. I removed this immediately because unlike Tensor::is_variable, TensorOptions::is_variable doesn't respect our VariableTensorId thread-local state. This means that there were a bunch of places where TensorOptions::is_variable was false, which is obviously bogus in the world when tensor and variable are merged. Instead of keeping the method as a function that always returns true, I just opted to remove it entirely (it's not public API.) All places we set `is_variable` are deleted.
- Knock on effect: there is no longer a separate DeprecatedTypeProperties for the variable and non-variable versions of type.
- Knock on effect: instead of asserting on TensorOptions::is_variable, instead we just test `at::impl::variable_is_excluded()`
- There is now only one copy of the cuDNN RNN dropout cache, not two (I'm not sure why we had two to begin with)
Some cleanup that doesn't happen in this patch:
- Eliminating unnecessary uses of `make_variable`
- Eliminating `Tensor::is_variable`
The most subtle part of this patch is retaining tracing behavior: the fact that everything is a Variable means that more code gets routed to VariableType than before; this can change traces. I identified two places where we didn't appropriately turn off VariableType, mostly factory functions:
- `torch.tensor` must turn off VariableType before invoking `at::empty` to construct the tensor, as it subsequently does direct data access
- `tensor_slow` (invoked when you pass a Python scalar to a tensor argument) must turn off VariableType before calling `scalar_to_tensor` so the scalar gets traced as constant, rather than as a call to `scalar_to_tensor`.
Honestly, these are all giant hacks, and should be replaced with a more specialized guard that just toggles tracing.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: dreiss
Differential Revision: D18171156
Pulled By: ezyang
fbshipit-source-id: 5b6a045beba37492647e350190f495114e86504d
Summary:
Fix Slice/Select trace arguments. This PR stashes arguments to functions in order to avoid tracing them as constants.
This PR depends on a fix for select op in PR:
https://github.com/pytorch/pytorch/pull/25273
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26549
Reviewed By: hl475
Differential Revision: D17623851
Pulled By: houseroad
fbshipit-source-id: ae314004266688d2c25c5bada2dcedbfc4f39c5b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25308
Instead of storing a single TensorTypeId in a Tensor, we store a bitset of tensor type IDs in a Tensor, TensorTypeSet. This class comes with some unit tests. This is in preparation for making Variable a TensorTypeId. In order to help flush out places where this makes a semantic difference, we rename `Tensor::type_id()` to `Tensor::type_set()` and smoke out all of the locations where this was semantically meaningful.
Because the new tensor type set is 64-bits, this increases the size of Tensor by a word.
Listing of semantic changes:
* Many TensorImpl related constructors just propagate TensorTypeId to a parent constructor. These are pretty simple to adjust.
* Backend extensions are now in the business of explicitly constructing a TensorTypeSet and then passing it in. This is probably OK for now but when Variable drops, these dispatch IDs may get immediately overwritten to have Variable set.
* `sparseTensorSetToDeviceType` and similar functions previously did an equality test with TensorTypeId, to determine what an appropriate device type is. This equality is now replaced with a set inclusion test. This is valid, under the assumption that we don't ever have weird sets like "this tensor is simultaneously a sparse CPU tensor and a sparse CUDA tensor", which will be true in the short term plan of adding Variable to the dispatch ID.
* `impl::dispatchTypeId` was generally introduced for cases where we legitimately need to convert from `TensorTypeSet -> TensorTypeId` in a dispatch related manner. At the moment, the implementation is trivial, but they will soon be adjusted to handle TLS. I've tried to make these call sites as forwards compatible as possible:
* `checked_tensor_unwrap` and co now use `dispatchTypeId`. When Variable is added to the type set, these will always be called in a context where the Variable type ID is disabled, so we will get the correct underlying tensor type ID.
* Uses of `Backend` in dispatch are now replaced with `TensorTypeSet`. The general heuristic here for whether or not to accept a `TensorTypeId` or `TensorTypeSet` is that we want to make the generated code as simple as possible. It is easier to retrieve a `TensorTypeSet`, so that's a more appropriate API in these cases.
* In some cases, I could not conveniently switch an implementation to the new semantics, because it was blocked on some other refactor. In this case, I introduced `legacyExtractTypeId`, which gives what would be a BC-compatible `TensorTypeSet` to `TensorTypeId` implementation that will continue to report the same values it would have prior to this change. This is **different** from `dispatchTypeId`, because this function does NOT respect TLS; it always ignores Variable type IDs.
* c10 dispatcher tests, which are oblivious to Variable dispatch, use this BC function (actually, they use `extractTypeId`, an overload for Tensor.
* The implementation of `new_*` methods heavily relies on tensor type ID, I chose not to unwind this. PR to refactor this at https://github.com/pytorch/pytorch/pull/25475
* Slicing also relies on tensor type ID, see `torch/csrc/autograd/python_variable_indexing.cpp` (though in some cases in this file, I was able to replace use of tensor type ID with TensorOptions)
* In some cases, there is an equality test on tensor type ID which would be better done by testing "tensor axes". In those cases, I replaced those equality tests with more equality tests.
* Example: `torch/csrc/nn/type_checks.h`
* There is a total punt in `torch/csrc/tensor/python_tensor.cpp` where "instance of" checking is done via dispatch ids. In general, the Variable-ness of a tensor doesn't participate in instanceof testing. It's not entirely clear what to do here.
* Instead of storing `Backend` in `VariableInfo`, we now just store Layout.
c10 dispatcher test updates were done with:
```
:%s/\([^ ]\+\)\.type_id()/extractTypeId(\1)/g
:%s/\([^( ]\+\)->type_id()/extractTypeId(*\1)/g
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25308
Differential Revision: D17092791
Test Plan: sandcastle and ossci
Reviewed By: bwasti
Pulled By: ezyang
fbshipit-source-id: 22207d14fe62dd31ee19cc5011af22e3d9aabb5b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25252
Our model going forward for extensions will be that you will have to
get an allocation of an ID in our system. This is how things work
in practice today; we're just simplifying our underlying registration
since there is no need to have distributed registration.
There are some codemods in this diff:
```
codemod --extensions cpp,h,cc,cuh,py,in --exclude-paths=c10/core/TensorTypeId.h '([A-Za-z]+?)TensorId\(\)' 'TensorTypeId::\1TensorId'
codemod --extensions cpp,h,cc,cuh,py,in 'TensorTypeIds::undefined\(\)' 'TensorTypeId::UndefinedTensorId'
codemod --extensions cpp 'TensorType1\(\)' 'TensorTypeId::CPUTensorId'
codemod --extensions cpp 'TensorType2\(\)' 'TensorTypeId::CUDATensorId'
codemod --extensions cpp 'TensorType3\(\)' 'TensorTypeId::XLATensorId'
codemod --extensions cpp 'TensorType1' 'CPUTensorId'
codemod --extensions cpp 'TensorType2' 'CUDATensorId'
codemod --extensions cpp 'TensorType3' 'XLATensorId'
```
The main hand-written changes are in c10/core/TensorTypeId.h
Other manual fixes:
- aten/src/ATen/core/op_registration/op_registration.cpp - stop using
std::string operator+
- aten/src/ATen/function_wrapper.py - handle a hardcoded TypeId() that
wasn't caught by codemod
- torch/csrc/tensor/python_tensor.h - fix now incorrect forward declaration
of TensorTypeId
- aten/src/ATen/core/op_registration/ - remove out-of-line registration
Differential Revision: D17072001
Test Plan: ossci and sandcastle
Pulled By: ezyang
fbshipit-source-id: c641515fd0604c045c54fbb1d6b1b950f45e89d1
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19676
Make copy work with QTensor, enable assignment of QTensor in pytorch frontend.
Differential Revision: D15064710
fbshipit-source-id: 04f2dc02a825695d41fa1114bfca49e92108fef3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19530
Make copy work with QTensor, enable assignment of QTensor in pytorch frontend.
Differential Revision: D15008160
fbshipit-source-id: 5f1166246d768b23f009cde1fa03e8952368a332
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17991
changes:
-Breaks bc: Tensor::type() now returns DeprecatedTypeProperties& rather than Type&.
-Added DeprecatedTypeProperties, it serves as a temporary replacement for Type as the return value of Tensor::type(). This contributes to making Type just for dispatch purposes so that we can make it dtype agnostic.
-Tensor::dispatch_type() now returns Type& like Tensor::type() used to do.
-Changed callsites of Tensor::type() appropriately.
Reviewed By: ezyang
Differential Revision: D14443117
fbshipit-source-id: 239ccb7a09626279a71d1a37f8f82e7f57bf7d9e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16751
This was made more complicated by the fact that ivalue::IntList
is a thing. So I had to fix all of the sites where we referring
to IValue post facto.
The following codemods were run, in this order:
```
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in IntList IntArrayRef
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in IntArrayRef::create IntList::create
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in ivalue::IntArrayRef ivalue::IntList
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in Tag::IntArrayRef Tag::IntList
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in isIntArrayRef isIntList
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in toIntArrayRef toIntList
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in 'Shared<IntArrayRef>' 'Shared<IntList>'
codemod -m -d . --extensions cc,cpp,cu,cuh,h,hpp,py,cwrap,yaml,in 'intrusive_ptr<IntArrayRef>' 'intrusive_ptr<IntList>'
```
Some manual fixups were done afterwards; they can be reviewed separately
at https://github.com/pytorch/pytorch/pull/16752
Reviewed By: dzhulgakov
Differential Revision: D13954363
fbshipit-source-id: b5c40aacba042402155a2f5a229fa6db7992ac64
Summary:
applySelect does modify the tensor and removes the top most dimension which makes it complicated to track just using dim and need to use another parameter as real_dim to signify original dimension
fixes#16192
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16495
Differential Revision: D13897182
Pulled By: gchanan
fbshipit-source-id: 105581dbbff6b431cc8e2539a07e0058161e53a1
Summary:
…_tensor.
This is part of a long series of paring down the Type interface.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15074
Differential Revision: D13421482
Pulled By: gchanan
fbshipit-source-id: 84010ee71fef2cb74d32d5de7858d8ed9f36b885
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
This would solve the tracing problems of #13969.
Fixes: #14732
I would appreciate if this got good scrutiny before applied.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14097
Differential Revision: D13323181
Pulled By: ezyang
fbshipit-source-id: dcd104b497c0bfddb751923c6166a3824b7a3702
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14421
Last time I looked this, I bailed because it seemed like there were
a lot of sites to fix. Well, I need this to work properly for out-of-place
HIPify, so I took another whack at it. Changes should be pretty self-explanatory.
Reviewed By: gchanan
Differential Revision: D13221302
fbshipit-source-id: ed21e2668a1a629898a47358baf368fe680263a0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13342
This PR introduces a few new concepts:
- DeviceGuardImplInterface, and implementations for CPU and CUDA, which
provide a generic interface for interfacing with device and stream state,
without requiring a direct dependency on the code in question.
- InlineDeviceGuard, a general template for generating both specialized
and dynamically dispatched device guard implementations. Dynamic
dispatch is done by specializing it on a VirtualGuardImpl.
- Provide a device-independent DeviceGuard class, which can be used even
from CPU code. It uses the aforementioned dynamic dispatch.
- CUDA-specialized CUDAGuard class, which doesn't have a dynamic dispatch
but can only be used from CUDA.
- StreamGuard, which is the same as above, but for streams rather than
devices.
- Optional variants of all the aforementioned guards, which are a no-op if
no device/stream is specified
- CUDAMultiStreamGuard, specifically for the case when we want to set
a device on every guard.
There are some subtle semantic changes, which have been thoroughly documented
in the class definition.
BC-breaking changes:
- Move constructor/assignment have been removed from all device guard
implementations.
- In some cases where you previously wrote 'set_device' (or 'set_stream'), you now must write
'reset_device', because if you switch devices/device types, the stream/device on the
previous device is unset. This is different from previous behavior.
- CUDAGuard no longer handles streams, or multiple streams. Use CUDAStreamGuard
or CUDAMultiStreamGuard as appropriate for your use case.
Reviewed By: dzhulgakov
Differential Revision: D12849620
fbshipit-source-id: f61956256f0b12be754b3234fcc73c2abc1be04e
Summary:
This is a redo of the previous move which broke OS X and Windows tests -- RTTI seemed to be broken
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13455
Differential Revision: D12883775
Pulled By: bwasti
fbshipit-source-id: 2b6c65e8150e6f89624c6ee99c389335c6fb4bb8