Initial prototype for dynamic int inputs, allows users to run with `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.
Current behavior:
- Also works in eager (mostly by subclassing int), as scalar input to torch functions, or numpy/math/etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3.
- Behavior for arithmetic ops is to return new DynamicInts rather than static ints; `DynamicInt(3) * 2 = DynamicInt(6)`. This is via SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid int casting. This is not necessarily the case for non-magic method ops (e.g. `math.cos(x)`). The alternative here is to int cast on all operations, but I opted for this for dynamism propagation in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but I guess this is fine, because they can be casted to ints when needed?
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, and a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style) - but an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?)
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3)) # same as f(3)
```
Follow-up work:
- Specifying shape constraints, either at the int-level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"]
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.
Differential Revision: D81698719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
Fixes#148208. There are solutions for exposing symbols implicitly from inline functions (i.e., inline function A calls non-inline function B in foo.h. Code includes foo.h has to see the symbol B in DLL).
Solution 1: tag the entire struct where the inline functions are defined as member functions with TORCH_PYTHON_API --- this PR does this for python_arg_parser.h. An alternative solution exists but will slow down dispatching a lot --- drop inline keyword and move implementation to .cc file.
Solution 2: tag individual functions with TORCH_PYTHON_API. This PR does this for python_tensor.h.
Related discussion about hiding torch_python symbols: https://github.com/pytorch/pytorch/pull/142214
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148213
Approved by: https://github.com/malfet
These two APIs are being used internally for some projects and need to be exposed as the build for this is done using OSS toolchain.
af8789c056 - this change hid most apis in torch python barring the ones explicitly specified breaking the build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143380
Approved by: https://github.com/suo
# Motivation
Before this PR, device construction was `cuda` type when only a device index was given. It also returns the `PrivateUser1` type if a `PrivateUser1` type is registered.
```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
>>> b
tensor([1, 2], device='cuda:0')
```
It works well on CUDA GPU. But it will raise unexpected information and error running on XPU.
```bash
>>> import torch
>>> device = torch.device(0)
>>> device.type
'cuda'
>>> a = torch.tensor([1, 2])
>>> b = a.to(0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/xxx/pytorch/torch/cuda/__init__.py", line 302, in _lazy_init
raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
```
With this PR, refine the logic to use the currently available device type instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129119
Approved by: https://github.com/albanD, https://github.com/gujinghui, https://github.com/EikanWang
ghstack dependencies: #129463, #129205, #129363
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.
This:
```
try {
...
} catch (exception& e) {
// no use of e
}
```
should instead be written as
```
} catch (exception&) {
```
If the code compiles, this is safe to land.
Test Plan: Sandcastle
Reviewed By: palmje
Differential Revision: D55548497
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123056
Approved by: https://github.com/Skylion007
This PR proposes to use std::optional<Generator>& for underlying functions to avoid unnecessary copy and move operations. The torchgen code was changed to generate the new type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120076
Approved by: https://github.com/malfet
This PR proposes to use std::optional<Generator>& for underlying functions to avoid unnecessary copy and move operations. The torchgen code was changed to generate the new type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120076
Approved by: https://github.com/malfet
This PR proposes to use std::optional<Generator>& for underlying functions to avoid unnecessary copy and move operations. The torchgen code was changed to generate the new type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120076
Approved by: https://github.com/malfet
Fixes#115331.
This PR increases the number of valid GPU devices to 512 (from 64) in order to future-proof PyTorch for providers that offer [single nodes with a large device count](https://www.tensorwave.com/). Until now, `DeviceIndex` was an `int8_t`, thus multiple changes were necessary:
- `DeviceIndex` changed to `int16_t`. Updated consumers that assume it to be an `int8_t`.
- Updated bounds checking for `torch.device()` in the Python frontend. Right now, we allow funny things like `torch.device('cpu', 200).index == -56`, which is undefined behavior. I inserted some checks to only allow values between 0 and `c10::Device::MAX_NUM_DEVICES - 1`.
- Updated the `ArgumentInfo` struct as it hardcodes the device index as 8 bit field [^1]. Might be a breaking change, not sure if users rely on this.
- Introduced `c10::Device::MAX_NUM_DEVICES` as a replacement for the old `C10_COMPILE_TIME_MAX_GPUS`
[^1]: This field was unsigned, so I guess this has also been undef behavior the whole time? Our default device index is -1, so this always wrapped around to 255 when written to the `ArgumentInfo` struct. When I switched the `DeviceIndex` to `int16_t`, it actually stayed 255 after unpacking from `ArgumentInfo` again, as the `DeviceIndex` was now wide enough that it didn't wrap back to -1.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119639
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/huydhn
Fixes#115331.
This PR increases the number of valid GPU devices to 512 (from 64) in order to future-proof PyTorch for providers that offer [single nodes with a large device count](https://www.tensorwave.com/). Until now, `DeviceIndex` was an `int8_t`, thus multiple changes were necessary:
- `DeviceIndex` changed to `int16_t`. Updated consumers that assume it to be an `int8_t`.
- Updated bounds checking for `torch.device()` in the Python frontend. Right now, we allow funny things like `torch.device('cpu', 200).index == -56`, which is undefined behavior. I inserted some checks to only allow values between 0 and `c10::Device::MAX_NUM_DEVICES - 1`.
- Updated the `ArgumentInfo` struct as it hardcodes the device index as 8 bit field [^1]. Might be a breaking change, not sure if users rely on this.
- Introduced `c10::Device::MAX_NUM_DEVICES` as a replacement for the old `C10_COMPILE_TIME_MAX_GPUS`
[^1]: This field was unsigned, so I guess this has also been undef behavior the whole time? Our default device index is -1, so this always wrapped around to 255 when written to the `ArgumentInfo` struct. When I switched the `DeviceIndex` to `int16_t`, it actually stayed 255 after unpacking from `ArgumentInfo` again, as the `DeviceIndex` was now wide enough that it didn't wrap back to -1.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119639
Approved by: https://github.com/cyyever, https://github.com/albanD
We can use `__index__` to do this conversion because that will trigger a
guard on data dependent SymInt if the tensor is a fake tensor, but if
we fetch item directly and put it in the Scalar, we may still be able to
make it work out.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117454
Approved by: https://github.com/yanboliang
ghstack dependencies: #117451, #117452
This fastpath is unnecessary because in the logic below we
do the same thing:
```
auto& var = THPVariable_Unpack(obj);
if (var.numel() != 1 ||
!at::isIntegralType(
var.dtype().toScalarType(), /*include_bool*/ true)) {
throw_intlist_exception(this, i, obj, idx);
}
auto scalar = var.item();
TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
res.push_back(scalar.toSymInt())
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117452
Approved by: https://github.com/yanboliang
ghstack dependencies: #117451