Summary:
The implementation adds the ability to:
Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()
Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.
Differential Revision: D84654933
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
Summary:
The implementation adds the ability to:
Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()
Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.
Differential Revision: D84654933
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert
This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665
Disclaimer that I used a lot of AI since I don't how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now)
The first one type is
Truncated:
```
default_pg, _ = _new_process_group_helper(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
backend_class = creator_fn(dist_backend_opts, backend_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```
and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead. However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR
The other one type is
```
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
self.assertEqual(out, "")
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
raise error_metas.pop()[0].to_error( # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
- if is_available() and not torch._C._c10d_init():
To execute this test, run the following from the base repo dir:
python test/test_testing.py TestImports.test_no_warning_on_import
```
which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do:
Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```
When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
"_TestOpaqueObject::queue_push",
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
This is a cleaner implementation of opaque objects (https://github.com/pytorch/pytorch/pull/162660). Instead now we just need to do:
Call `register_opaque_type` to register the type as being "opaque" and allowed by custom ops. You also need to pass a unique name that maps to the type.
```python
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue")
```
When creating the custom op, the schema will then use the unique name:
```python
self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
"_TestOpaqueObject::queue_push",
"(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=self.lib,
)
@torch.library.impl(
"_TestOpaqueObject::queue_push", "CompositeExplicitAutograd", lib=self.lib
)
def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None:
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Using the custom op:
```python
queue = OpaqueQueue([], torch.zeros(3))
torch.ops._TestOpaqueObject.queue_push(queue, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165004
Approved by: https://github.com/albanD
## Summary
- add a CuBLASReductionOption enum so the CUDA context can track reduced-precision and split-K options
- extend the Python bindings, backend helpers, and docs to accept an optional allow_splitk argument for fp16/bf16 matmul controls
- update cuBLAS/cuBLASLt call sites plus dynamo guards and tests to respect the new combinations
## Testing
- python test/test_cuda.py TestCuda.test_cublas_allow_fp16_reduced_precision_reduction_get_set -v *(fails: ModuleNotFoundError: No module named 'psutil')*
------
https://chatgpt.com/codex/tasks/task_e_68e404623178832f8a3e1d34e1e175da
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164766
Approved by: https://github.com/malfet, https://github.com/albanD
`grad_dtype` is a new attribute on Tensor to control gradient dtype:
- Access/setting is leaf-only.
- grad_dtype is respected when (1) when assigning to .grad, and (2) in the engine after the previous node produces incoming gradients for AccumulateGrad. (See table below for details)
- Not setting grad_dtype preserves the current behavior. Accessing it returns `t.dtype`
- `grad_dtype` cannot be set when there is already a `.grad` present and the dtypes conflict.
| `grad_dtype` setting | Setting `.grad` manually | Incoming gradient from autograd engine |
|-----------------------|--------------------------|-----------------------------------------|
| **Default (tensor’s dtype)** | `.grad` must match tensor’s dtype | Engine casts incoming grad to tensor’s dtype |
| **Set to specific dtype** | `.grad` must match that dtype | Engine casts incoming grad to the specified dtype |
| **Set to `None`** | `.grad` may be any dtype | Engine does not cast; accepts incoming grad dtype as-is |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162815
Approved by: https://github.com/albanD
Fixes#162129. Added validation in _rank_not_in_group() to check if ```FakeProcessGroup``` is properly initialized before use, raising a clear error message if ```torch.distributed.init_process_group(backend='fake')``` hasn't been called first.
This prevents silent failures and ensures proper dispatch system integration for all distributed operations.
Added test case test_fake_process_group_direct_usage_error() that validates the error is raised for ```all_reduce``` and ```all_to_all_single``` operations.
Please let me know if additional distributed operators should be tested or if any other updates are needed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163665
Approved by: https://github.com/ezyang
Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
Fixes#156052 and #156444.
This PR setup the privateuseone key in Python to be used as a python backend for pytorch.
Meaning that, after calling `setup_privateuseone_for_python_backend('npy')`, one can use a subclass to with that device to hold arbitrary python data as "device data" and use `torch.library` to register ops that takes that Tensor.
Changes done in this PR:
1. Register an vanilla Device Guard: I extended NoOpDeviceGuard to have allow device index of 0 and to not raise errors when event related functions are accessed. If I don't do those, when calling backward I would get errors. (CPU backend uses NoOpDeviceGuard just fine, although there seems to be special treatment of CPU in the autograd engine.
2. Tensor subclass allows not having `__torch_dispatch__` if the device is not CUDA or CPU. The comment of the check suggests it was to avoid segfault when calling into ops that expects a storage. Here we have a different device so will not call into those ops.
3. python function that invokes the other incantations to setup the privateusekey backend.
This took inspiration of https://github.com/bdhirsh/pytorch_open_registration_example and https://github.com/tinygrad/tinygrad/blob/master/extra/torch_backend/wrapped_tensor.cpp; great thanks to @bdhirsh and @geohot.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157859
Approved by: https://github.com/albanD
## Overview
This PR allows the profiler users to access `Kineto` and `TorchOp` metadata in JSON string format through a new `metadata_json` attribute in `FunctionEvent` objects, which is triggered through a new `expose_kineto_event_metadata` flag in `ExperimentalConfig`.
## Testing
A unit test was added to validate functionality.
## Documentation
Added/updated function doc strings where appropriate.
## Example output
```python
import torch
from torch.profiler import profile
with profile(experimental_config=torch._C._profiler._ExperimentalConfig(expose_kineto_event_metadata=True)) as prof:
res = torch.mm(torch.rand(1024, 1024), torch.rand(1024, 1024))
for event in prof.events():
print(f'name: {event.key}, metadata: {event.metadata_json}')
```
```
name: aten::rand, metadata: "Ev Idx": 0
name: aten::empty, metadata: "Ev Idx": 1
name: aten::uniform_, metadata: "Ev Idx": 2
name: aten::rand, metadata: "Ev Idx": 3
name: aten::empty, metadata: "Ev Idx": 4
name: aten::uniform_, metadata: "Ev Idx": 5
name: aten::mm, metadata: "Ev Idx": 6
name: aten::resolve_conj, metadata: "Ev Idx": 7
name: aten::resolve_conj, metadata: "Ev Idx": 8
name: aten::resolve_conj, metadata: "Ev Idx": 9
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161624
Approved by: https://github.com/sraikund16
In various benchmarks scattered across the repo, the limits for flops/second and memory bandwidth are usually hardcoded for a single device. This utility could help in providing a more structured way to query the device capabilities. If this is approved, we can use it when reporting flops efficiency and bandwidth relative to peak in the benchmarks and tests. The intent is to add more devices, more parameters (e.g. L2 cache bandwidth, NVLink, etc.) for both CPUs and accelerators.
Testing:
```
import torch
if torch.cuda.is_available():
device = torch.cuda.current_device()
mod = torch.get_device_module('cuda')
hw = mod._device_limits.GPULimits(device)
print(hw.get_tflops_per_second(torch.float16))
print(hw.get_tflops_per_second(torch.float32))
print(hw.get_tflops_per_second(torch.float64))
print(hw.get_tflops_per_second(torch.bfloat16))
print(hw.get_tflops_per_second(torch.int8))
print(hw.get_memory_bandwidth_Bps() / 1e9)
print(hw.get_shared_memory_bandwidth_Bps() / 1e9)
# Output on an H100 GPU
1070.53056
535.26528
66.90816
1070.53056
2141.06112
4893.696
33454.08
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162942
Approved by: https://github.com/ngimel, https://github.com/albanD
A big pain point ppl have with custom ops is that they do not accept arbitrary input/outputs. In this PR we create the concept of an "OpaqueObject" which allows users to pass arbitrary python objects into custom operators.
Some still slightly annoying parts with this implementation:
- The schema of the operator is `__torch__.torch.classes.aten.OpaqueObject` instead of whatever python type
- `@torch.library.custom_op` doesn't work.. yet?
UX:
```python
from torch._library.opaque_object import make_opaque, get_payload
# your custom python class
class OpaqueQueue:
def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> None:
super().__init__()
self.queue = queue
self.init_tensor_ = init_tensor_
def push(self, tensor: torch.Tensor) -> None:
self.queue.append(tensor)
def pop(self) -> torch.Tensor:
if len(self.queue) > 0:
return self.queue.pop(0)
return self.init_tensor_
def size(self) -> int:
return len(self.queue)
queue = OpaqueQueue([], torch.zeros(3))
obj: torch._C.ScriptObject = make_opaque(queue)
# obj.payload stores a direct reference to this python queue object
self.assertEqual(get_payload(obj), queue)
# This is able to be passed through the dispatcher
torch.ops._TestOpaqueObject.queue_push(obj, torch.ones(3))
self.assertTrue(queue.size(), 1)
```
Authoring a custom op:
```python
lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT")
torch.library.define(
f"_TestOpaqueObject::queue_push",
"(__torch__.torch.classes.aten.OpaqueObject a, Tensor b) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
@torch.library.impl(f"{libname}::queue_push", "CompositeExplicitAutograd", lib=lib)
def push_impl(q: torch._C.ScriptObject, b: torch.Tensor) -> None:
# We can get the payload directly by get_payload(q)
queue = get_payload(q)
assert isinstance(queue, OpaqueQueue)
queue.push(b)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162660
Approved by: https://github.com/zou3519
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__` and we see if a special FCD implementation needs to be dispatch to first, and if there is nothing we fallback to the standard level strategy.
Because there is no longer C binding equivalent of classes, we've condensed _C.Dim and Dim together, and similar for Tensor. This resulted in some bugs as the Python API is sometimes different from the C API. I've attempted to disambiguate these but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful as Dim must abide by Tensor equality semantics, but is pointer equality in C (DimEntry doesn't have this problem). Another difference between C/Python that is subtle is we no longer get implicit conversions from Dim to DimEntry, this also caused some bugs.
Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at it so I haven't done it in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything, some noticeably missing "small" things are patched_dim_method. I am still in progress of carefully doing a side-by-side review of ports; "simplifications" from claude code were also a major source of bugs.
There are two major feature gaps in the implementation:
- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy, just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of being eligible for `__torch_function__`. I think I might need to hard code a special case for this or something?
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
In various benchmarks scattered across the repo, the limits for flops/second and memory bandwidth are usually hardcoded for a single device. This utility could help in providing a more structured way to query the device capabilities. If this is approved, we can use it when reporting flops efficiency and bandwidth relative to peak in the benchmarks and tests. The intent is to add more devices, more parameters (e.g. L2 cache bandwidth, NVLink, etc.) for both CPUs and accelerators.
Testing:
```
import torch
if torch.cuda.is_available():
device = torch.cuda.current_device()
mod = torch.get_device_module('cuda')
hw = mod._device_limits.GPULimits(device)
print(hw.get_tflops_per_second(torch.float16))
print(hw.get_tflops_per_second(torch.float32))
print(hw.get_tflops_per_second(torch.float64))
print(hw.get_tflops_per_second(torch.bfloat16))
print(hw.get_tflops_per_second(torch.int8))
print(hw.get_memory_bandwidth_Bps() / 1e9)
print(hw.get_shared_memory_bandwidth_Bps() / 1e9)
# Output on an H100 GPU
1070.53056
535.26528
66.90816
1070.53056
2141.06112
4893.696
33454.08
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162942
Approved by: https://github.com/ngimel
Reland of #160532
Summary:
To support exporting a cuda model on a CPU-only machine under fake tensor mode. User commonly need to move sample inputs to the cuda device with .to("cuda:0") or .to("cuda") call. This diff supports this.
I expect the following pattern to work
```
with FakeTensorMode(allow_non_fake_inputs=True):
cuda_module = module.to("cuda:0")
cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
with torch.no_grad():
ep = torch.export.export(cuda_module, cuda_sample_inputs)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163016
Approved by: https://github.com/huydhn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163187
Approved by: https://github.com/angelayi