Compare commits

..

8 Commits

Author SHA1 Message Date
8f920b4799 Update on "bf16 support for per_channel bwd"
Follow-up to #165098: adds bf16 support for the backward pass. To avoid BC-breaking changes and losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.

For testing, we upcast to fp32 before calling the reference function (see the sketch after the commit list for an illustration of this pattern).

[ghstack-poisoned]
2025-10-14 11:51:45 -07:00
7721d5a806 Update on "bf16 support for per_channel bwd"
Follow-up to #165098: adds bf16 support for the backward pass. To avoid BC-breaking changes and losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.

For testing, we upcast to fp32 before calling the reference function.

[ghstack-poisoned]
2025-10-14 11:51:28 -07:00
e51552344e Update on "bf16 support for per_channel bwd"
Follow-up to #165098: adds bf16 support for the backward pass. To avoid BC-breaking changes and losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.

For testing, we upcast to fp32 before calling the reference function.

[ghstack-poisoned]
2025-10-14 11:51:02 -07:00
a12e2d1296 Update on "bf16 support for per_channel bwd"
[ghstack-poisoned]
2025-10-14 11:49:14 -07:00
04b5566455 Update on "bf16 support for per_channel bwd"
[ghstack-poisoned]
2025-10-14 11:49:11 -07:00
b014c1a894 bf16 support for per_channel bwd
[ghstack-poisoned]
2025-10-14 11:48:04 -07:00
3401665110 Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)
The initial fix for inspect.signature did not take the right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885), we can for now simply get rid of the `inspect.signature` call in flex_attention to resolve this high-priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). In this PR I did exactly that: I limited the scope of the fix to computing `num_positional_args` in `flex_attention._get_mod_type` based on properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing, so I added them). See the sketch after the `_get_mod_type` diff at the end of this comparison for the counting logic in isolation.

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
2025-10-14 18:29:15 +00:00
8c60f4ae08 [Distributed] update table in docs (#165009)
Fixes #162248

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165009
Approved by: https://github.com/ezyang
2025-10-14 18:17:22 +00:00
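
The "bf16 support for per_channel bwd" commits above describe an upcast-compute-downcast pattern. Below is a minimal, hypothetical sketch of that pattern; `bf16_backward` and `reference_grad_fn` are illustrative stand-ins for the wrapped fp32 backward, not PyTorch APIs.

import torch

# Hypothetical helper illustrating the pattern from the commit message:
# upcast bf16 inputs to fp32, run the gradient computation in fp32, then
# downcast the result back to bf16 before returning.
def bf16_backward(grad_output, param, reference_grad_fn):
    grad_fp32 = grad_output.to(torch.float32)
    param_fp32 = param.to(torch.float32)
    grad_param_fp32 = reference_grad_fn(grad_fp32, param_fp32)
    # Downcast so callers get gradients in the dtype they passed in.
    return grad_param_fp32.to(param.dtype)

# Usage with a trivial stand-in gradient function.
g = torch.randn(4, dtype=torch.bfloat16)
p = torch.randn(4, dtype=torch.bfloat16)
out = bf16_backward(g, p, lambda go, w: go * w)
assert out.dtype == torch.bfloat16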
12 changed files with 112 additions and 65 deletions

View File

@@ -105,16 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void retain_wrapper() const override {
PyObject *obj = pyobj_slot_._unchecked_untagged_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void release_wrapper() const override {
PyObject *obj = pyobj_slot_._unchecked_untagged_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj, false);
}
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);

View File

@@ -24,6 +24,10 @@ void PyObjectSlot::maybe_destroy_pyobj() {
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(

View File

@@ -33,9 +33,7 @@ struct C10_API PyObjectSlot {
// Query the PyObject interpreter. This may return null if there is no
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;

View File

@@ -33,7 +33,6 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
constexpr uint64_t kHasWrapper = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@@ -56,11 +55,7 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>((combined_refcount & ~kHasWrapper) >> 32);
}
inline bool has_wrapper(uint64_t combined_refcount) {
return (combined_refcount & kHasWrapper) != 0;
return static_cast<uint32_t>(combined_refcount >> 32);
}
// The only requirement for refcount increment is that it happens-before
@@ -71,6 +66,12 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@@ -254,9 +255,6 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
virtual void retain_wrapper() const {}
virtual void release_wrapper() const {}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@@ -316,15 +314,11 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if (C10_UNLIKELY(detail::has_wrapper(combined) && new_refcount == 2)) {
target_->retain_wrapper();
}
}
}
@@ -343,9 +337,9 @@
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@@ -362,8 +356,6 @@
if (should_delete) {
delete target_;
}
} else if (detail::has_wrapper(combined_refcount) && new_refcount == 1) {
target_->release_wrapper();
}
}
}
@@ -1068,12 +1060,7 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
uint64_t new_refcount = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
if (C10_UNLIKELY(detail::has_wrapper(new_refcount) &&
detail::refcount(new_refcount) == 2)) {
self->retain_wrapper();
}
detail::atomic_refcount_increment(self->combined_refcount_);
}
}
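
For context on the hunks above: the strong refcount and weak refcount share one 64-bit word (strong count in the low 32 bits, weak count in the high 32 bits), and this change removes the kHasWrapper bit that previously occupied bit 63. The small Python illustration below is written from the constants visible in the diff; it is not PyTorch code.

# Illustrative only: mirrors kReferenceCountOne, kWeakReferenceCountOne and
# the refcount()/weakcount() helpers shown in the diff above.
REFCOUNT_ONE = 1
WEAKCOUNT_ONE = 1 << 32

def refcount(combined):
    # low 32 bits hold the strong refcount
    return combined & 0xFFFFFFFF

def weakcount(combined):
    # high 32 bits hold the weak refcount; with kHasWrapper gone,
    # a plain shift suffices (no masking of bit 63 needed)
    return combined >> 32

combined = REFCOUNT_ONE | WEAKCOUNT_ONE   # kUniqueRef: one strong, one weak
combined += REFCOUNT_ONE                  # retain_(): bump the strong count only
assert refcount(combined) == 2 and weakcount(combined) == 1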

View File

@@ -51,7 +51,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| all_to_all | | | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
| all_to_all | | | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+
| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ |
+----------------+-----+-----+-----+-----+-----+-----+-----+-----+

View File

@@ -46,6 +46,7 @@ from torch._dynamo.backends.debugging import ExplainWithBackend
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
rand_strided,
same,
@@ -54,6 +55,7 @@ from torch._dynamo.testing import (
)
from torch._inductor.utils import fresh_cache
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.profiler import profile, ProfilerActivity
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -7369,6 +7371,67 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
)
self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
@parametrize("backend", ["eager", "inductor"])
def test_issue164247(self, backend: str):
if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
raise unittest.SkipTest(
"Skip only in dynamic-shapes wrapper (known issue #157612)"
)
class MixedFakeModeModel(nn.Module):
def __init__(self, dim=64):
super().__init__()
self.dim = dim
self.lin = torch.nn.Linear(64, 64)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# Process input first - this creates fake tensors in export's fake mode
processed = self.lin(x)
# Create some computation that depends on processed tensor
intermediate = processed.sum(dim=-1).detach() # Shape: (batch, seq_len)
def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
threshold = intermediate[
batch_idx, q_idx % seq_len
] # Access the captured tensor
return (kv_idx <= q_idx) & (threshold > 0)
block_mask = create_block_mask(
mask_mod=dynamic_mask_function,
B=batch_size,
H=None,
Q_LEN=seq_len,
KV_LEN=seq_len,
device=x.device,
_compile=False,
)
q = processed.view(batch_size, 1, seq_len, self.dim)
k = processed.view(batch_size, 1, seq_len, self.dim)
v = processed.view(batch_size, 1, seq_len, self.dim)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
out = flex_attention(q, k, v, block_mask=block_mask)
return out
backend_counter = CompileCounterWithBackend(backend)
model = MixedFakeModeModel()
compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
if backend == "inductor":
# A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
with self.assertRaises(RuntimeError):
compiled(torch.randn(2, 128, 64))
else:
compiled(torch.randn(2, 128, 64))
# One graph, so no graph breaks
self.assertEqual(backend_counter.frame_count, 1)
self.assertEqual(len(backend_counter.graphs), 1)
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):

View File

@@ -1320,9 +1320,21 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable):
def const_getattr(self, tx, name):
if name == "__name__":
return self.fn_name.as_python_constant()
return self.get_name()
if name == "__code__":
return self.get_code()
if name == "__defaults__":
d = getattr(self, "defaults", None)
return d.as_python_constant() if d else None
return super().const_getattr(tx, name)
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
if name == "__code__":
return variables.ConstantVariable.create(hasattr(self, "code"))
if name == "__defaults__":
return variables.ConstantVariable.create(hasattr(self, "defaults"))
return super().call_obj_hasattr(tx, name)
def has_self(self):
return False

View File

@@ -248,7 +248,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
if (!Py_IsInitialized())
return;
printf("decref %p (%s)\n", pyobj, Py_TYPE(pyobj)->tp_name);
pybind11::gil_scoped_acquire gil;
// Two possibilities:
// 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
@@ -285,7 +284,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
if (!Py_IsInitialized())
return;
printf("incref %p (%s)\n", pyobj, Py_TYPE(pyobj)->tp_name);
pybind11::gil_scoped_acquire gil;
Py_INCREF(pyobj);
}

View File

@@ -176,13 +176,6 @@ static bool THPStorage_tryPreserve(THPStorage* self) {
return true;
}
static void THPStorage_dealloc(PyObject* self) {
THPStorage* _self = (THPStorage*)self;
printf("dealloc %p\n", _self);
_self->cdata.~MaybeOwned<c10::Storage>();
Py_TYPE(_self)->tp_free(self);
}
static void THPStorage_subclass_dealloc(PyObject* self) {
THPStorage* _self = (THPStorage*)self;
@@ -211,7 +204,7 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
}
// base test is unnecessary as THPStorage does not set this
// base test is unnecessary as THPStorae does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
@@ -617,7 +610,7 @@ PyTypeObject THPStorageType = {
"torch._C.StorageBase", /* tp_name */
sizeof(THPStorage), /* tp_basicsize */
0, /* tp_itemsize */
THPStorage_dealloc, /* tp_dealloc */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
@@ -660,7 +653,7 @@ int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
// ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
return 0;
}

View File

@@ -17,7 +17,3 @@ void clear_slots(PyTypeObject* type, PyObject* self) {
}
}
}
PyObject* wrap(c10::impl::PyObjectSlot* slot) {
return nullptr;
}

View File

@@ -1,10 +1,7 @@
#pragma once
#include <torch/csrc/python_headers.h>
#include <c10/core/impl/PyObjectSlot.h>
// This file contains utilities used for handling PyObject preservation
void clear_slots(PyTypeObject* type, PyObject* self);
PyObject* wrap(c10::impl::PyObjectSlot* slot);

View File

@@ -266,11 +266,20 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
considered as a score_mod function. If the function has 4 positional arguments, it is
considered as a mask function.
"""
num_positional_args = sum(
1
for param in inspect.signature(fn).parameters.values()
if param.default is inspect.Parameter.empty
)
if hasattr(fn, "__code__"):
code = fn.__code__
num_positional_total = code.co_argcount
defaults = ()
if hasattr(fn, "__defaults__"):
defaults = fn.__defaults__ or ()
num_defaults = len(defaults)
num_positional_args = num_positional_total - num_defaults
else:
num_positional_args = sum(
1
for param in inspect.signature(fn).parameters.values()
if param.default is inspect.Parameter.empty
)
assert num_positional_args == 5 or num_positional_args == 4
if num_positional_args == 5:
return _ModificationType.SCORE_MOD
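
Read together with the flex_attention commit message above, here is a minimal sketch of the idea: count positional arguments from `__code__`/`__defaults__` rather than `inspect.signature`, and classify a callable with 4 positional arguments as a mask function and one with 5 as a score_mod. This illustrates the approach under those assumptions; it is not the exact PyTorch code.

def count_positional_args(fn):
    # Mirrors the patched logic: co_argcount minus the number of defaults,
    # falling back to inspect.signature only when __code__ is unavailable.
    if hasattr(fn, "__code__"):
        defaults = getattr(fn, "__defaults__", None) or ()
        return fn.__code__.co_argcount - len(defaults)
    import inspect
    return sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.default is inspect.Parameter.empty
    )

def causal_mask(b, h, q_idx, kv_idx):          # 4 positional args -> mask function
    return q_idx >= kv_idx

def rel_bias(score, b, h, q_idx, kv_idx):      # 5 positional args -> score_mod
    return score + (q_idx - kv_idx)

assert count_positional_args(causal_mask) == 4
assert count_positional_args(rel_bias) == 5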