This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do:
Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```
When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
"_TestOpaqueObject::queue_push",
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do:
Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```
When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
"_TestOpaqueObject::queue_push",
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
A big pain point ppl have with custom ops is that they do not accept arbitrary input/outputs. In this PR we create the concept of an "OpaqueObject" which allows users to pass arbitrary python objects into custom operators.
Some still slightly annoying parts with this implementation:
- The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of whatever python type
- `@torch.library.custom_op` doesn't work.. yet?
UX:
```python
from torch._library.opaque_object import make_opaque, get_payload
# your custom python class
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
queue = OpaqueQueue([], torch.zeros(3))
obj: torch._C.ScriptObject = make_opaque(queue)
# obj.payload stores a direct reference to this python queue object
self.assertEqual(get_payload(obj), queue)
# This is able to be passed through the dispatcher
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Authoring a custom op:
```python
lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
f"_TestOpaqueObject::queue_push",
"(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl(f"{libname}::queue_push", "CompositeExplicitAutograd", lib=lib)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
# We can get the payload directly by get_payload(q)
queue = get_payload(q)
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660
Approved by: https://github.com/zou3519
This took me a bit to figure out and I'm pretty sure I've looked at
this code before. Pybind uses
`return_value_policy::reference_internal` for `def_property`, which
[causes the owning object to be kept alive for the lifespan of the
return
value](https://pybind11.readthedocs.io/en/stable/advanced/functions.html),
allowing the getter to safely avoid copying the property
value. However, lambdas act like they return `auto`, not
`decltype(auto)`, so our lambdas themselves were forcing copies!
Testing: observed std::vector<Argument> copying disappear in Linux
perf profile of someOpInfo._schema.arguments/returns (in
_python_dispatch.correct_storage_aliasing).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161301
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/wconstab
Summary:
**TL;DR**: make DCE faster by replacing a Set<Value*> with a MemoryLocations sparse bitset (representing all the memory locations stored by the collection of all values in the set).
**Details**
The goal of this PR is to optimize this function from AliasDb:
```
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
const auto writtenTo = getWrites(n);
if (writtenTo.empty()) {
return false;
}
MemoryLocations locs;
for (const auto v : vs) {
auto it = elementMap_.find(v);
if (it != elementMap_.end()) {
const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
if (writtenTo.intersects(vlocs)) {
return true;
}
}
}
return false;
}
```
In the DCE use case, we have a ValueSet of live values, into which we insert `Value*`s; and sometimes need to check whether a node mutates any of the live values using `writesToAlias`.
Looping through all the values in the ValueSet and indexing into the elementMap_ is slow; so if we can pre-compute the MemoryLocations set, this speeds up the function. In some large model examples, I see ~15-25x speedups from this change.
**Implementation**: To avoid exposing too many details of AliasDb, I introduce a friend class `ValueAndMemoryLocationSet`, which is an insert-only set of Values, which also maintains the corresponding MemoryLocations.
Then in AliasDb, I use `ValueAndMemoryLocationSet` if we're using AliasDb for analysis, and otherwise use a `Set<Value*>` if we don't have AliasDb.
Test Plan: Rely on unit tests.
Differential Revision: D74827086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153645
Approved by: https://github.com/eellison
This PR creates two utils for generating a schema for hops from example inputs and use base hop as an exmaple.
1. HopArgumentInfoGen creates an argument or an output schema with mutation information.
2. CFuncitonSchemaGen piece together the argument info of inputs and outputs and produces torch._C.FunctionSchema.
is_write attribute of argument info can be computed. Note that the is_write annotation only works when the inputs are flattened (e.g. cannot support mutation inside tuple). We need special handling the case where we have tuple inputs like cond.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149688
Approved by: https://github.com/zou3519
## Background
This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. This option relies on the previous PR in this stack, where storage order was changed to non lexicographical. A `.format_version` entry was added to the zipfile and `calculate_storage_offsets` will only work on checkpoints with `.format_version`.
When this is turned on, for `torch.load(mmap=True)`, offsets of each storage record (other than the 0th storage will be calculated instead of relying on `miniz` APIs to determine this).
The existing APIs will issue multiple random reads (reading the end of central directory record, then reading the zipfile header for the record) to determine the storage offset where the record starts. This can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases.
6aaae9d78f/caffe2/serialize/inline_container.cc (L589-L605)
## How does this work
The format for the checkpoint is as such
```
archive_name/
|_ data.pkl
|_.format_version
|_byteorder
|_data/
|_ 0
|_ 1
|_ 2
|_ ...
|_
```
Each `data/i` record represents a storage, where storages are written in the order that the Pickler encounters them.
For each storage, our `persistent_load` logic saves the following metadata to the pickle file `dtype, numel, key, location` where `numel` is the number of bytes in the storage.
Note that we always use `miniz` writer in the zip64 mode per [here](7796e308d0/caffe2/serialize/inline_container.cc (L701)) A zipfile record written by miniz looks as such
```
---------------- ----------------- ------------------- ---------------- --------- ------------------------------
| 30 byte header | n byte filename | zip64_extra_data | m byte padding | storage | 16 or 24 byte local dir footer |
---------------- ----------------- ------------------- ---------------- --------- ------------------------------
```
- The header size (30) is given by [`MZ_ZIP_LOCAL_DIR_HEADER_SIZE`](https://github.com/pytorch/pytorch/blob/main/third_party/miniz-3.0.2/miniz.c?fbclid=IwZXh0bgNhZW0CMTEAAR2O8Vysd--UoSCxW70gabXIS1dbz733oHwuUQ5_Ff1hY2WU6PL2i6CSH4A_aem_J9oaU2HpDeWtJKOU9EnVqw#L3290)
- filename will be `"{archive_name}/{filepath}"`
- `zip64_extra_data` is determined by [`mz_zip_writer_create_zip64_extra_data`](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6202)). Note that [we only create zip64_extra_data if storage_size >= 0xFFFFFFFF or the offset of the start of the header >= 0xFFFFFFFF](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6519-L6524))
- `m` is determined by [`getPadding`](7796e308d0/caffe2/serialize/inline_container.cc (L254)), which accounts for filename, zip64_extra_data to determine `m` such that the start of `storage` is aligned to 64 bytes. The `m` bytes will always start with `F B padding_size" as the first 4 bytes
- The local dir footer size is determined based on [this snippet ](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6610-L6632)): if the buffer size is 0 it is skipped. If the zip64_extra_data was created, it is 24, otherwise it is 16.
When `torch.utils.serialization.config.load.calculate_storage_offsets` is set we do the following
- We keep track of where the "cursor" is in the file using `current_offset`, after each persistent_load call, it will be at the offset where the header for the next record starts
- for the 0th storage, "data/0", we use the regular get_record_offset to determine the start of the storage
- for any other storage, (where the storages will be in order encountered by the unpickler, 0, 1, 2, 3, ...) we use `get_record_offset_no_read`, which re-uses the `getPadding` logic to determine the offset of the storage
- Note that `load_tensor` will only ever be called again with the same key if the storage's `._data_ptr()` is 0 [[pointer1](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1917-L1918)][[pointer2](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1936-L1937)], so we cache the offsets for this edge case
- After each storage, if the storage is non-zero, we account for the local dir footer based on the logic described above
## Testing strategy
The agreed upon testing strategy was as follows:
- Add debug code gated by an environment flag `TORCH_SERIALIZATION_DEBUG` that will run this offset calculation logic and verify it against getRecordOffset for each storage (when mmap=False)
- This flag is set throughout CI, which means that every time `torch.load` is called, the offset calculation logic is implicitly being tested.
Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880
Approved by: https://github.com/albanD
ghstack dependencies: #143879
## Background
This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. This option relies on the previous PR in this stack, where storage order was changed to non lexicographical. A `.format_version` entry was added to the zipfile and `calculate_storage_offsets` will only work on checkpoints with `.format_version`.
When this is turned on, for `torch.load(mmap=True)`, offsets of each storage record (other than the 0th storage will be calculated instead of relying on `miniz` APIs to determine this).
The existing APIs will issue multiple random reads (reading the end of central directory record, then reading the zipfile header for the record) to determine the storage offset where the record starts. This can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases.
6aaae9d78f/caffe2/serialize/inline_container.cc (L589-L605)
## Testing strategy
The agreed upon testing strategy was as follows:
- Add debug code gated by an environment flag `TORCH_SERIALIZATION_DEBUG` that will run this offset calculation logic and verify it against getRecordOffset for each storage (when mmap=False)
- This flag is set throughout CI, which means that every time `torch.load` is called, the offset calculation logic is implicitly being tested.
Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880
Approved by: https://github.com/albanD
ghstack dependencies: #143879
When we see a custom op:
- check that its mutation annotations are correct
- check that its aliasing constraints matches our constraints for custom
ops.
Otherwise, there may be undefined behavior.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139212
Approved by: https://github.com/angelayi
Fixes#129403
Create a separate printing function to debug SymNode, since we can't easily change `__repr__` that is used by GraphModule.recompile() to create a pythonic version of a graph
This is my first contribution, please let me know if there is anything that I should look into in further details
Thank you for you guidance! 🙏 I hope to contribute more in the future!
@aorenste
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129925
Approved by: https://github.com/aorenste
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/
It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
Summary:
co-dev reland of https://github.com/pytorch/pytorch/pull/124520, which requires
the removal of some executorch tests.
Before this PR, we didn't check that types in a schema were valid. This
is because TorchScript treats unknown types as type variables.
This PR checks types in a schema for the TORCH_LIBRARY APIs. To do this,
we add an `allow_typevars` flag to parseSchema so that TorchScript can
use allow_typevars=True. We also add some error messages for common
mistakes (e.g. using int64_t or double in schema).
Test Plan: Wait for tests
Differential Revision: D57666659
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126861
Approved by: https://github.com/albanD