Compare commits


4 Commits

SHA1  Message  Date

30a97721b0  Apply suggestions from code review  2025-11-10 11:37:44 -08:00

35135d834e  check in  2025-11-07 16:33:06 +00:00

192034c41b  [easy][dynamo][pytree] simplify pytree polyfill module by moving out the guard-if (#167221)  2025-11-07 15:23:03 +00:00
    Move the guard-if in `polyfills.pytree` to `polyfills.loader` and dedent the code in the if-branch.
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/167221
    Approved by: https://github.com/Lucaskabela

5bfce8f345  Unit test for torch.compile bmm dtype (#167140)  2025-11-07 14:59:00 +00:00
    Pull Request resolved: https://github.com/pytorch/pytorch/pull/167140
    Approved by: https://github.com/atalman, https://github.com/mlazos
10 changed files with 903 additions and 1028 deletions

View File

@@ -267,15 +267,15 @@ void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, con
  * outer dimensions, which contains several "inner rows").
  * Each thread processes a single inner row at a time.
  */
-template<typename scalar_t, class BinaryOp>
+template<typename scalar_t, typename index_t, class BinaryOp>
 __global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, const scalar_t *src_,
                                              const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                              const scalar_t init, BinaryOp binary_op)
 {
   for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
     for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
-      const scalar_t *src = src_ + orow * row_size * num_irows + irow;
-      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
+      const scalar_t *src = src_ + static_cast<index_t>(orow) * row_size * num_irows + irow;
+      scalar_t *tgt = tgt_ + (index_t) orow * row_size * num_irows + irow;
       scalar_t acc = init;
       for (uint32_t col = 0; col < row_size; ++col) {
@@ -409,10 +409,15 @@ __host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
   check_fits_in_unsigned(num_irows, "num_irows");
   check_fits_in_unsigned(num_orows, "num_orows");
   check_fits_in_unsigned(row_size, "row_size");
-  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+  if (static_cast<size_t>(num_irows) * num_orows * row_size <= UINT_MAX) {
+    tensor_kernel_scan_outer_dim<scalar_t, uint32_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
       result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
       num_orows, num_irows, row_size, init, binary_op);
+  } else {
+    tensor_kernel_scan_outer_dim<scalar_t, size_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+      result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),
+      num_orows, num_irows, row_size, init, binary_op);
+  }
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
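The extra index_t template parameter and the UINT_MAX dispatch above exist because the flattened offset orow * row_size * num_irows + irow can exceed 32 bits. A minimal sketch of that overflow in plain Python (not part of the diff), using the shape from the new cumsum test further down; how that shape maps onto num_orows / row_size / num_irows is an assumption:

    UINT_MAX = 2**32 - 1
    num_orows, row_size, num_irows = 309504, 1, 16384   # assumed mapping for a scan over dim=1

    numel = num_orows * row_size * num_irows
    print(numel, numel <= UINT_MAX)   # 5070913536 False -> the size_t kernel is launched

    # Offset of the last element, as the kernel computes it:
    orow, irow = num_orows - 1, num_irows - 1
    offset = orow * row_size * num_irows + irow
    print(offset)          # 5070913535, the correct element offset
    print(offset % 2**32)  # 775946239, what 32-bit unsigned arithmetic would produce instead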

View File

@@ -1913,6 +1913,29 @@ class TestMaxAutotune(TestCase):
         # Check that contiguous transform was used
         FileCheck().check("contiguous_mm").run(code[0])
 
+    @unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI")
+    @unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA")
+    def test_bmm_out_dtype(self):
+        def f(a, b):
+            return torch.bmm(a, b, out_dtype=torch.float32)
+
+        a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16)
+        b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16)
+        with config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+        ):
+            compiled_f = torch.compile(f)
+            with self.assertRaisesRegex(
+                torch._inductor.exc.InductorError,
+                r"LoweringException: NoValidChoicesError: No choices to select",
+            ):
+                out, code = run_and_get_code(compiled_f, a, b)
+
+        compiled_f = torch.compile(f)
+        out, code = run_and_get_code(compiled_f, a, b)
+        FileCheck().check("extern_kernels.bmm_dtype").run(code[0])
+
     def test_triton_template_generated_code_cache_key(self):
         generate_and_load_args = len(
             inspect.signature(
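As context for the assertions above (a hedged sketch, not part of the diff): torch.bmm accepts an out_dtype override on CUDA, letting float16 inputs produce a float32 result. With max_autotune_gemm_backends="TRITON" the lowering finds no template that honors the override and raises NoValidChoicesError; with the default backend set it falls back to the extern bmm_dtype kernel, which is what the FileCheck call looks for. The sketch below assumes a plain CUDA device standing in for GPU_TYPE:

    import torch

    # Eager-mode behavior the compiled test relies on: fp16 inputs, fp32 output override.
    a = torch.randn(2, 3, 4, device="cuda", dtype=torch.float16)
    b = torch.randn(2, 4, 5, device="cuda", dtype=torch.float16)
    out = torch.bmm(a, b, out_dtype=torch.float32)
    print(out.dtype)  # torch.float32, while a and b remain float16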

View File

@@ -1781,6 +1781,13 @@ class TestTorchDeviceType(TestCase):
         self.assertEqual(b[0, :], d[0, :], atol=3e-5, rtol=3e-5)
         self.assertEqual(b[-1, :], d[-1, :], atol=3e-5, rtol=3e-5)
 
+    @onlyCUDA
+    def test_cumsum_outer_dim_64bit_indexing(self, device):
+        x = torch.zeros(309504, 1, 16384, device=device)
+        torch.exp(x)
+        cumsum = torch.cumsum(x, dim=1)
+        self.assertEqual(cumsum.max().item(), 0., atol=0., rtol=0.)
+
     @expectedFailureMeta  # expected a non-deterministic error, but it was not raised
     @onlyNativeDeviceTypes
     def test_nondeterministic_alert_put(self, device):
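Why this shape exercises 64-bit indexing (a sketch, not part of the diff): the tensor holds more elements than a 32-bit offset can address, and a cumsum along a size-1 dimension is the identity, so any nonzero in the all-zeros result would point at misaddressed reads or writes:

    # Element count and memory footprint of the test tensor (float32).
    numel = 309504 * 1 * 16384      # 5_070_913_536 elements, more than 2**32 - 1
    print(numel > 2**32 - 1)        # True -> the size_t kernel instantiation is taken
    print(numel * 4 / 2**30)        # ~18.9 GiB for the input tensor alone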

View File

@@ -4,6 +4,8 @@
 import importlib
 from typing import TYPE_CHECKING
 
+import torch.utils._pytree as python_pytree
+
 from .. import polyfills, trace_rules
@@ -19,12 +21,14 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = (
     "itertools",
     "operator",
     "os",
-    "pytree",
     "struct",
     "sys",
     "fx",
     "tensor",
 )
+if python_pytree._cxx_pytree_dynamo_traceable:
+    POLYFILLED_MODULE_NAMES += ("pytree",)
+
 POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple(
     importlib.import_module(f".{submodule}", package=polyfills.__name__)
     for submodule in POLYFILLED_MODULE_NAMES
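A condensed, runnable sketch of the loader-side guard above (names are taken from the diff; the module-name tuple is abbreviated, and polyfills.__name__ is assumed to resolve to torch._dynamo.polyfills):

    import importlib

    import torch.utils._pytree as python_pytree

    # The availability check now lives in the loader, so polyfills/pytree.py no longer
    # wraps its whole body in an if-branch.
    names: tuple[str, ...] = ("builtins", "operator", "itertools")  # abbreviated
    if python_pytree._cxx_pytree_dynamo_traceable:
        names += ("pytree",)

    modules = tuple(importlib.import_module(f".{n}", package="torch._dynamo.polyfills") for n in names)
    print([m.__name__ for m in modules])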

File diff suppressed because it is too large.

View File

@@ -1991,7 +1991,7 @@ class BuiltinVariable(VariableTracker):
         # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__()
         # with an integer argument starting at 0, until __getitem__ raises IndexError
         ret = variables.UserFunctionVariable(
-            polyfills.builtins.iter_  # type: ignore[arg-type]
+            polyfills.builtins.iter_
         ).call_function(tx, [obj, *args], {})
 
         if args:

File diff suppressed because it is too large.

View File

@@ -590,7 +590,7 @@ class FilterVariable(IteratorVariable):
             else:
                 res = self.fn.call_function(tx, [item], {})
             pred_res = variables.UserFunctionVariable(
-                polyfills.predicate  # type: ignore[arg-type]
+                polyfills.predicate
             ).call_function(tx, [res], {})
             if pred_res.as_python_constant():
                 return item

View File

@@ -1498,7 +1498,6 @@ class NamedTupleVariable(TupleVariable):
                 variables.UserDefinedClassVariable(self.tuple_cls),
             )
         elif isinstance(method, staticmethod):
-            # pyrefly: ignore[bad-argument-type]
             return UserFunctionVariable(method.__func__)
         elif inspect.isfunction(method):
             return UserMethodVariable(method, self)

View File

@@ -472,12 +472,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             )
         elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__:  # type: ignore[attr-defined]
             name_to_arg_map = bind_args_cached(
-                # pyrefly: ignore[bad-argument-type]
-                self.value,
-                tx,
-                self.source,
-                args,
-                kwargs,
+                self.value, tx, self.source, args, kwargs
             )
             backends = name_to_arg_map["backends"].as_python_constant()
             set_priority = name_to_arg_map["set_priority"].as_python_constant()
@@ -1434,7 +1429,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             packed_input_vt = TupleVariable.build(
                 tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs))
             )
-            out_vt = variables.UserFunctionVariable(tree_flatten).call_function(  # type: ignore[arg-type]
+            out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
                 tx, [packed_input_vt], {}
             )
             assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2