Compare commits


5 Commits

Author SHA1 Message Date
6a4bb0f11e Update operator benchmarks README 2025-11-19 07:58:11 +00:00
f2e6f94081 deprecate check_is_size and guard_size_oblivious (#167198)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167198
Approved by: https://github.com/bobrenjc93
2025-11-17 05:47:40 +00:00
aa504d4d2a [audio hash update] update the pinned audio hash (#167914)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167914
Approved by: https://github.com/pytorchbot
2025-11-17 05:21:29 +00:00
d8ce6f8df9 Enable PyTorch OSS numerics changes, inductor heuristics (#167799)
Test Plan: CI

Differential Revision: D86211542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167799
Approved by: https://github.com/njriasan, https://github.com/eellison
2025-11-17 04:31:44 +00:00
4322354770 [Inductor] optimize scalar welford_reduce (#162709)
**Summary:**
Optimize the scalar welford_reduce implementation by combining the Welford algorithm with cascade summation to improve numerical stability. Specifically:

1. Use the Welford algorithm to compute the mean and variance.
2. Use cascade summation when accumulating the sums over the input for both the mean and the variance (see the sketch below).
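
For intuition, here is a minimal Python sketch of the combined scheme (illustrative only; the names and chunk size are not Inductor's): each fixed-size chunk is Welford-reduced on its own, and the per-chunk partials are then merged pairwise, cascade-style, so rounding error grows roughly logarithmically in the number of elements instead of linearly.

```
import torch

def welford_merge(a, b):
    # Chan et al. combination of two (mean, m2, count) partials.
    mean_a, m2_a, n_a = a
    mean_b, m2_b, n_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + delta * delta * (n_a * n_b / n)
    return (mean, m2, n)

def cascade_welford(x, chunk=4096):
    # Welford-reduce each fixed-size chunk independently ...
    parts = []
    for s in range(0, x.numel(), chunk):
        c = x[s:s + chunk]
        m = c.mean().item()
        parts.append((m, ((c - m) ** 2).sum().item(), c.numel()))
    # ... then combine the partials pairwise (the cascade step).
    while len(parts) > 1:
        parts = [
            welford_merge(parts[i], parts[i + 1]) if i + 1 < len(parts) else parts[i]
            for i in range(0, len(parts), 2)
        ]
    mean, m2, n = parts[0]
    return mean, m2 / n  # mean and (biased) variance

x = torch.randn(2_097_152)
print(cascade_welford(x))
```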

**Example:**
Take https://github.com/pytorch/pytorch/issues/141541 as an example:
```
import torch
import torch.nn as nn
torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=32, num_channels=32)

    def forward(self, x):
        return self.gn(x)

model = Model().eval()
x = torch.randn(1, 32, 128, 128, 128)

with torch.no_grad():
    output = model(x)
    with torch._inductor.config.patch({"cpp.simdlen": 0}):
        c_model = torch.compile(model)
        c_output = c_model(x)

print(torch.max(torch.abs(output - c_output)))
print(torch.allclose(output, c_output, 1.3e-6, 1e-5))
```
**Logs:**

- Before
```
tensor(0.0005)
False
```
- After
```
tensor(1.4305e-06)
True
```

**Generated code:**
- Before
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0);
                            }
                        }
                    }
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

- After
```
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['float*', 'float*', 'const float*', 'const float*', 'const float*', 'float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       float* in_out_ptr1,
                       const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr2)
{
    auto out_ptr1 = in_out_ptr0;
    auto out_ptr0 = in_out_ptr1;
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                Welford<float> tmp_acc0 = Welford<float>();
                Welford<float> tmp_acc0_arr[4];
                for (int i = 0; i < 4; i++)
                {
                    tmp_acc0_arr[i] = Welford<float>();
                }
                #pragma omp parallel num_threads(4)
                {
                    int tid = omp_get_thread_num();
                    WelfordHelper<float, float, 4096> scalar_welford_helper0(static_cast<int64_t>(524288L));
                    Welford<float> tmp_acc0_local = Welford<float>();
                    #pragma omp for
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            {
                                auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                                tmp_acc0_local = welford_combine(tmp_acc0_local, tmp0, &scalar_welford_helper0);
                            }
                        }
                    }
                    tmp_acc0_local = welford_combine(tmp_acc0_local, &scalar_welford_helper0);
                    tmp_acc0_arr[tid] = tmp_acc0_local;
                }
                for (int tid = 0; tid < 4; tid++)
                {
                    tmp_acc0 = welford_combine(tmp_acc0, tmp_acc0_arr[tid]);
                }
                in_out_ptr1[static_cast<int64_t>(x0)] = tmp_acc0.mean;
                in_out_ptr0[static_cast<int64_t>(x0)] = tmp_acc0.m2;
            }
        }
    }
    {
        #pragma GCC ivdep
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
        {
            {
                {
                    auto tmp0 = out_ptr1[static_cast<int64_t>(x0)];
                    auto tmp6 = in_ptr1[static_cast<int64_t>(x0)];
                    auto tmp8 = out_ptr0[static_cast<int64_t>(x0)];
                    auto tmp11 = in_ptr2[static_cast<int64_t>(x0)];
                    auto tmp1 = static_cast<float>(2097152.0);
                    auto tmp2 = tmp0 / tmp1;
                    auto tmp3 = static_cast<float>(1e-05);
                    auto tmp4 = float(tmp2 + tmp3);
                    auto tmp5 = 1 / std::sqrt(tmp4);
                    auto tmp7 = float(tmp5 * tmp6);
                    auto tmp9 = decltype(tmp8)(-tmp8);
                    auto tmp10 = float(tmp9 * tmp7);
                    auto tmp12 = float(tmp10 + tmp11);
                    in_out_ptr0[static_cast<int64_t>(x0)] = tmp7;
                    in_out_ptr1[static_cast<int64_t>(x0)] = tmp12;
                }
            }
        }
    }
    #pragma omp parallel num_threads(4)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(32L); x0+=static_cast<int64_t>(1L))
            {
                #pragma GCC ivdep
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(2097152L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x1 + 2097152L*x0)];
                            auto tmp1 = in_out_ptr0[static_cast<int64_t>(x0)];
                            auto tmp3 = in_out_ptr1[static_cast<int64_t>(x0)];
                            auto tmp2 = float(tmp0 * tmp1);
                            auto tmp4 = float(tmp2 + tmp3);
                            out_ptr2[static_cast<int64_t>(x1 + 2097152L*x0)] = tmp4;
                        }
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, arg1_1, arg2_1 = args
        args.clear()
        assert_size_stride(arg0_1, (32, ), (1, ))
        assert_size_stride(arg1_1, (32, ), (1, ))
        assert_size_stride(arg2_1, (1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1))
        buf0 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf1 = empty_strided_cpu((1, 32, 1, 1), (32, 1, 32, 32), torch.float32)
        buf3 = reinterpret_tensor(buf1, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf1  # reuse
        buf4 = reinterpret_tensor(buf0, (1, 32, 1, 1), (32, 1, 1, 1), 0); del buf0  # reuse
        buf5 = empty_strided_cpu((1, 32, 128, 128, 128), (67108864, 2097152, 16384, 128, 1), torch.float32)
        # [Provenance debug handles] cpp_fused_native_group_norm_0:1
        cpp_fused_native_group_norm_0(buf3, buf4, arg2_1, arg0_1, arg1_1, buf5)
        del arg0_1
        del arg1_1
        del arg2_1
        return (buf5, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162709
Approved by: https://github.com/CaoE, https://github.com/jansel
2025-11-17 02:52:33 +00:00
15 changed files with 201 additions and 139 deletions

View File

@@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
ee1a1350eb37804b94334768f328144f058f14e9

View File

@@ -145,6 +145,64 @@ Run torch.add benchmark with tag 'long':
python -m pt.add_test --tag-filter long
```
## CI Regression Tracking
The operator benchmarks are continuously monitored in CI to track performance regressions across a diverse set of CPU and GPU devices. Two GitHub Actions workflows run these benchmarks on a regular schedule:
### CPU Benchmarks
The [operator_benchmark.yml](../../.github/workflows/operator_benchmark.yml) workflow runs operator benchmarks on CPU devices:
**Devices:**
- x86_64: `linux.12xlarge` (Intel/AMD CPUs)
- aarch64: `linux.arm64.m8g.4xlarge` (ARM64 CPUs)
**Operators Tracked:** All operators in the `pt/` directory with the tag `short`
**Schedule:** Weekly on Sundays at 07:00 UTC
**Test Modes:** `short`, `long`, or `all` (default: `short`)
**Triggers:**
- Scheduled runs (weekly)
- Manual workflow dispatch with configurable test mode
- Push to `ciflow/op-benchmark/*` tags
- Pull requests that modify benchmark files
### GPU Microbenchmarks
The [operator_microbenchmark.yml](../../.github/workflows/operator_microbenchmark.yml) workflow runs operator microbenchmarks on GPU devices:
**CUDA Devices:**
- H100 GPUs (`linux.aws.h100`) - CUDA 12.8, sm_80
- A100 GPUs (`linux.aws.a100`) - CUDA 12.8, sm_80
- B200 GPUs (`linux.dgx.b200`) - CUDA 12.8, sm_100
**ROCm Devices:**
- MI300X GPUs (`linux.rocm.gpu.gfx942.1`) - gfx942
**Operators Tracked in CI:** `matmul`, `mm`, `addmm`, `bmm`, `conv` (with tag `long`)
- Other operators in the `pt/` directory can be run ad hoc via workflow dispatch
**Schedule:** Daily at 06:00 UTC
**Performance Dashboard:** [PyTorch Operator Microbenchmark Dashboard](https://hud.pytorch.org/benchmark/v3/dashboard/pytorch_operator_microbenchmark)
**Triggers:**
- Scheduled runs (daily)
- Manual workflow dispatch
- Push to `ciflow/op-benchmark/*` tags
### Running Manual Benchmarks
To trigger a manual run of the benchmarks:
1. Navigate to the [GitHub Actions workflows](https://github.com/pytorch/pytorch/actions)
2. Select either `operator_benchmark` or `operator_microbenchmark`
3. Click "Run workflow" in the top right
4. For CPU benchmarks, optionally select a test mode (`short`, `long`, or `all`)
5. Click "Run workflow" to start the benchmark run
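
Runs can also be started without the web UI through GitHub's standard workflow-dispatch REST endpoint. Below is a minimal sketch, assuming a token with `workflow` scope; the `test_mode` input name is an assumption and should be checked against the workflow file:

```
import requests

# Trigger a workflow_dispatch run of operator_benchmark.yml.
# "test_mode" is an assumed input name; verify it in the workflow file.
resp = requests.post(
    "https://api.github.com/repos/pytorch/pytorch/actions/workflows/"
    "operator_benchmark.yml/dispatches",
    headers={
        "Authorization": "Bearer <YOUR_GITHUB_TOKEN>",
        "Accept": "application/vnd.github+json",
    },
    json={"ref": "main", "inputs": {"test_mode": "short"}},
)
resp.raise_for_status()  # GitHub returns 204 No Content on success
```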
## Adding New Operators to the Benchmark Suite
The previous sections gave several examples of running the operators already available in the benchmark suite. The following sections step through the complete flow of adding new PyTorch operators to the suite. Existing operator benchmarks live in the `pt` directory, and we highly recommend putting your new operators there as well; a minimal example module is sketched below.
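
To make that flow concrete, here is a minimal sketch of a new benchmark module in the style of the existing `pt/` benchmarks (the `op_bench` calls mirror the suite's public API; sizes and names are illustrative):

```
import operator_benchmark as op_bench
import torch

# Configs: cross product of input sizes, tagged "short" so CI picks them up.
add_configs = op_bench.cross_product_configs(
    M=[8, 64],
    N=[8, 64],
    tags=["short"],
)

class AddBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N):
        # Inputs are passed to forward() by keyword.
        self.inputs = {
            "a": torch.rand(M, N),
            "b": torch.rand(M, N),
        }

    def forward(self, a, b):
        return torch.add(a, b)

op_bench.generate_pt_test(add_configs, AddBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
```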

View File

@@ -90,12 +90,12 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
ge: "Sym(u0 >= 0)" = primals_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = primals_2 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
ge_2: "Sym(u2 >= 0)" = primals_3 >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
floordiv: "Sym((u0//2))" = primals_1 // 2

View File

@@ -727,7 +727,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@@ -747,7 +747,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -784,7 +783,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_1); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -883,7 +881,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3)
arg_count = ifdynstaticdefault(4, 5)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0
expected_op_count = ifdynstaticdefault(10, 8)
expected_op_count = ifdynstaticdefault(9, 7)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x,)),
@@ -905,7 +903,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -956,7 +953,7 @@ class GraphModule(torch.nn.Module):
y = torch.randn(3)
arg_count = ifdynstaticdefault(5, 6)
# when compiled with dynamic, we don't have upper bound runtime assertions for u0 and u1
expected_op_count = ifdynstaticdefault(17, 13)
expected_op_count = ifdynstaticdefault(15, 11)
out_graph = self._test_wrap_simple(
f,
default_args_generator((x, y)),
@@ -977,7 +974,6 @@ class GraphModule(torch.nn.Module):
c: "i64[u0, 1]" = l_x_.nonzero()
sym_size_int_2: "Sym(u0)" = torch.ops.aten.sym_size.int(c, 0)
_check_is_size = torch._check_is_size(sym_size_int_2); _check_is_size = None
ge: "Sym(u0 >= 0)" = sym_size_int_2 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
@@ -987,7 +983,6 @@ class GraphModule(torch.nn.Module):
d: "i64[u1, 1]" = l_y_.nonzero(); l_y_ = None
sym_size_int_3: "Sym(u1)" = torch.ops.aten.sym_size.int(d, 0)
_check_is_size_1 = torch._check_is_size(sym_size_int_3); _check_is_size_1 = None
ge_1: "Sym(u1 >= 0)" = sym_size_int_3 >= 0
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_2 = None

View File

@@ -3081,15 +3081,12 @@ def forward(self, x, y):
foo = torch.ops.export.foo.default(x, y); x = None
sym_size_int = torch.ops.aten.sym_size.int(foo, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(foo, 1)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int); sym_constrain_range_for_size_default = None
ge = sym_size_int >= 0; sym_size_int = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
sym_constrain_range_for_size_default_1 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default_1 = None
ge_1 = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None
bar = torch.ops.export.bar.default(y); y = None
sym_size_int_2 = torch.ops.aten.sym_size.int(bar, 0)
sym_constrain_range_for_size_default_2 = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_2); sym_constrain_range_for_size_default_2 = None
ge_2 = sym_size_int_2 >= 0; sym_size_int_2 = None
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_default_2 = None
return (foo, bar)""",
@@ -17743,7 +17740,6 @@ class TestExportCustomClass(TorchTestCase):
def forward(self, x, mask):
masked_select = torch.ops.aten.masked_select.default(x, mask); x = mask = None
sym_size_int_1 = torch.ops.aten.sym_size.int(masked_select, 0)
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
ge = sym_size_int_1 >= 0
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None
le = sym_size_int_1 <= 1188864

View File

@@ -1492,8 +1492,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
clone: "f32[s77][1]cpu" = torch.ops.aten.clone.default(arg1_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None
getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]
@@ -1513,8 +1513,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1)
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(clone); clone = None
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
_to_copy: "f32[u0, 1][1, u0]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None
@@ -1538,8 +1538,8 @@ def forward(self, arg0_1: "f32[2][1]cpu"):
def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg1_1)
sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None
alias_default: "f32[s77][1]cpu" = torch.ops.aten.alias.default(arg1_1)
alias_default_1: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.alias.default(convert_element_type)
@@ -1557,8 +1557,8 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"):
def forward(self, arg0_1: "f32[2][1]cpu"):
nonzero: "i64[u0, 1][1, u0]cpu" = torch.ops.aten.nonzero.default(arg0_1)
sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0)
ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u0 >= 0)" = sym_size_int >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None
convert_element_type: "f32[u0, 1][1, u0]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None

View File

@@ -4449,16 +4449,17 @@ class CPUReproTests(TestCase):
def forward(self, x):
return self.gn(x)
for dynamic in [True, False]:
torch._dynamo.reset()
metrics.reset()
mod = M().eval()
x = torch.randn(1, 32, 128, 128, 128)
with torch.no_grad():
expected = mod(x)
compiled_m = torch.compile(mod, dynamic=dynamic)
actual = compiled_m(x)
self.assertEqual(expected, actual)
for simdlen, dynamic in itertools.product([None, 0], [True, False]):
with config.patch({"cpp.simdlen": simdlen}):
torch._dynamo.reset()
metrics.reset()
mod = M().eval()
x = torch.randn(1, 32, 128, 128, 128)
with torch.no_grad():
expected = mod(x)
compiled_m = torch.compile(mod, dynamic=dynamic)
actual = compiled_m(x)
self.assertEqual(expected, actual)
@torch._dynamo.config.patch(
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True

View File

@@ -3532,11 +3532,11 @@ class TestUbackedOps(TestCase):
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -3573,11 +3573,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -3632,21 +3632,21 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
aot_graphs,
"""\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_2 = _assert_scalar_2 = None
sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None
ge_3: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_3 = _assert_scalar_4 = None
sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
_assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
@@ -4068,10 +4068,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
@@ -4097,10 +4097,10 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
ge: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar = None
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
@@ -4283,11 +4283,11 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
@@ -4319,11 +4319,11 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)",
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
ge_1: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None

View File

@@ -121,7 +121,7 @@ class TestOpaqueObject(TestCase):
def size_impl_fake(q: OpaqueQueue) -> int:
ctx = torch._custom_op.impl.get_ctx()
u0 = ctx.new_dynamic_size()
torch._check_is_size(u0)
torch._check(u0 >= 0)
return u0
torch.library.define(
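
The same replacement applies wherever `_check_is_size` marked an unbacked value as a size. A minimal end-to-end sketch of the new idiom, assuming a build where the deprecation is in effect:

```
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(idx):
    n = idx.item()        # n becomes an unbacked SymInt (u0)
    torch._check(n >= 0)  # preferred over the deprecated torch._check_is_size(n)
    return torch.zeros(n)

print(f(torch.tensor(5)).shape)  # torch.Size([5])
```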

View File

@@ -33,7 +33,11 @@ from typing import (
TypeVar as _TypeVar,
Union as _Union,
)
from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs
from typing_extensions import (
deprecated as _deprecated,
ParamSpec as _ParamSpec,
TypeIs as _TypeIs,
)
# As a bunch of torch.packages internally still have this check
@@ -1735,7 +1739,10 @@ def _check(cond, message=None): # noqa: F811
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
# TODO add deprecation annotation
@_deprecated(
"_check_is_size will be removed in a future PyTorch release along with guard_size_oblivious. \
Use _check(i >= 0) instead."
)
def _check_is_size(i, message=None, *, max=None):
"""Checks that a given integer is a valid size (i.e., is non-negative).
You should use this over ``_check(i >= 0)`` because it can prevent

View File

@@ -239,7 +239,10 @@ def reduction_combine(
if reduction_type in ("min", "max"):
return f"{reduction_type}_propagate_nan({var}, {next_value})"
if reduction_type == "welford_reduce":
return f"welford_combine({var}, {next_value})"
if helper_val:
return f"welford_combine({var}, {next_value}, &{helper_val})"
else:
return f"welford_combine({var}, {next_value})"
if reduction_type == "welford_combine":
if isinstance(next_value, tuple):
mean, m2, weight = next_value
@@ -2194,10 +2197,8 @@ class CppKernel(Kernel):
# sum and welford
# Note: using helper has non-negligible impact on performance
# keep the original behavior for welford_reduce
# acc helper is not used for scalar welford_reduce
if reduction_type == "welford_reduce":
return not use_scalar
return True
# TODO add supports for more data types when needed
if reduction_type == "sum" and dtype == torch.float:
@@ -2323,9 +2324,15 @@ class CppKernel(Kernel):
reduction_size = functools.reduce(
operator.mul, self.ranges[self.reduction_depth :]
)
helper_val = self.cascade_helper_cse.generate(
self.compute, f"reduction {reduction_key}", write=False
)
# use welford_helper/cascade_helper for vec kernel
if reduction_type == "welford_reduce":
helper_val = self.welford_helper_cse.generate(
self.compute, f"reduction {reduction_key}", write=False
)
else:
helper_val = self.cascade_helper_cse.generate(
self.compute, f"reduction {reduction_key}", write=False
)
# rename the helper variable to distinguish it from vectorized version
scalar_helper_val = f"scalar_{helper_val}"
self._use_acc_helper(
@@ -3092,19 +3099,16 @@ class CppVecKernel(CppKernel):
if self.ranges[self.tiling_idx] % self.tiling_factor
else sympy.Integer(0)
)
# scalar helper for scalar sum is also needed when vec kernel is included
# Note: is it different from welford reduction as welford reduction of scalar version
# does not need helper, and the helper needs the information of reduction size to initialize
if reduction_type == "sum":
scalar_helper_val = f"scalar_{helper_val}"
self._use_acc_helper(
reduction_type,
acc,
scalar_helper_val,
reduction_size,
dtype,
use_scalar=True,
)
# scalar helper for scalar welford_reduce/sum is also needed when vec kernel is included
scalar_helper_val = f"scalar_{helper_val}"
self._use_acc_helper(
reduction_type,
acc,
scalar_helper_val,
reduction_size,
dtype,
use_scalar=True,
)
self._use_acc_helper(
reduction_type, acc, helper_val, helper_vec_range, dtype
)

View File

@@ -22,7 +22,6 @@ from typing import Any, Generic, Literal, TYPE_CHECKING, TypeVar, Union
import torch
from torch._dynamo.utils import counters, set_feature_use
from torch._environment import is_fbcode
from torch._inductor import metrics
from torch._prims_common import compute_required_storage_length
from torch.utils._debug_mode import get_active_debug_mode
@@ -2470,9 +2469,8 @@ def triton_config_reduction(
rnumels[prefix] *= 2
if num_warps is None:
if reduction_hint == ReductionHint.INNER and not is_fbcode():
# r is contiguous, so ensure that each thread has 8 elements for
# vectorized loads, assuming bf16/fp16
if reduction_hint == ReductionHint.INNER:
# r is contiguous, ensure at least 8 elements per thread
# xblock is usually 1-2, default to giving each thread more work
num_warps = r // 128
else:
@@ -2942,7 +2940,7 @@ def _reduction_configs(
)
contiguous_config = make_config(
2 if rnumel <= 2048 and not is_fbcode() else 1, # 1024 or less is persistent
2 if rnumel <= 2048 else 1, # 1024 or less is persistent
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
@@ -2955,7 +2953,7 @@ def _reduction_configs(
outer_config = make_config(64, 8, register_intensive=register_intensive)
# TODO (paulzhan): Test heuristic on AMD and internal testing
# for correctness
if not torch.version.hip and not is_fbcode():
if not torch.version.hip:
outer_config = outer_config_opt()
configs = []

View File

@@ -74,6 +74,22 @@ template <typename T, int N>
struct IsVecMaskType<at::vec::VecMask<T, N>> : std::true_type {};
#endif
template <typename T>
struct GetScalarType {
using type = T;
};
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T>
struct GetScalarType<at::vec::Vectorized<T>> {
using type = T;
};
template <typename T, int N>
struct GetScalarType<at::vec::VectorizedN<T, N>> {
using type = T;
};
#endif
template <typename T, uint64_t kChunkSize>
struct CascadeSumHelper {
// A data struct to help cascade summation:
@@ -139,7 +155,7 @@ struct WelfordHelper {
// 1. Save the reciprocal of weights to avoid redundant divisions.
// 2. Save the welford stack, which is used to combine welford reduction
// with cascade summation to improve numerical stability.
static std::vector<typename T::value_type> weight_recps;
static std::vector<typename GetScalarType<T>::type> weight_recps;
std::vector<Welford<T>> welford_stk{};
uint64_t depth{0}; // depth of welford_stk.
uint64_t num_chunks{0}; // number of chunks stored in welford_stk.
@@ -154,9 +170,9 @@ struct WelfordHelper {
};
template <typename T, uint64_t kChunkSize>
std::vector<typename T::value_type> WelfordHelper<T, kChunkSize>::weight_recps =
[]() {
using scalar_t = typename T::value_type;
std::vector<typename GetScalarType<T>::type>
WelfordHelper<T, kChunkSize>::weight_recps = []() {
using scalar_t = typename GetScalarType<T>::type;
std::vector<scalar_t> temp(kChunkSize);
for (const auto i : c10::irange(kChunkSize)) {
temp[i] = scalar_t(static_cast<double>(1) / static_cast<double>(i + 1));
@@ -202,21 +218,19 @@ Welford<T> welford_combine(
// stability.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
if constexpr (IsVecType<T>::value) {
if (w != nullptr && w->depth > 0 && acc.index == kChunkSize) {
w->welford_stk[0] = welford_combine(w->welford_stk[0], acc);
w->num_chunks += 1;
acc.mean = T(0);
acc.m2 = T(0);
acc.weight = T(0);
acc.index = 0;
uint64_t mask = w->num_chunks;
for (uint64_t j = 1; j < w->depth && (mask & 1) == 0; ++j) {
w->welford_stk[j] =
welford_combine(w->welford_stk[j], w->welford_stk[j - 1]);
w->welford_stk[j - 1] = Welford<T>();
mask >>= 1;
}
if (w != nullptr && w->depth > 0 && acc.index == kChunkSize) {
w->welford_stk[0] = welford_combine(w->welford_stk[0], acc);
w->num_chunks += 1;
acc.mean = T(0);
acc.m2 = T(0);
acc.weight = T(0);
acc.index = 0;
uint64_t mask = w->num_chunks;
for (uint64_t j = 1; j < w->depth && (mask & 1) == 0; ++j) {
w->welford_stk[j] =
welford_combine(w->welford_stk[j], w->welford_stk[j - 1]);
w->welford_stk[j - 1] = Welford<T>();
mask >>= 1;
}
}
// Add a single data point
@@ -224,22 +238,18 @@ Welford<T> welford_combine(
auto new_weight = acc.weight + T(1);
auto delta = data - acc.mean;
T new_mean;
if constexpr (!IsVecType<T>::value) {
new_mean = acc.mean + delta / new_weight;
} else {
// use new_index to fecth 1 / new_weight to avoid divisions
new_mean = acc.mean +
((w == nullptr || acc.index >= w->weight_recps.size())
? delta / new_weight
: delta * T(w->weight_recps[acc.index]));
}
// use new_index to fecth 1 / new_weight to avoid divisions
new_mean = acc.mean +
((w == nullptr || acc.index >= w->weight_recps.size())
? delta / new_weight
: delta * T(w->weight_recps[acc.index]));
auto new_delta = data - new_mean;
auto result =
Welford<T>{new_mean, acc.m2 + delta * new_delta, new_weight, new_index};
return result;
}
template <typename T, uint64_t kChunkSize = 0>
template <typename T, uint64_t kChunkSize>
Welford<T> welford_combine(Welford<T>& acc, WelfordHelper<T, kChunkSize>* w) {
for (const auto i : c10::irange(w->depth)) {
acc = welford_combine(acc, w->welford_stk[i]);
@@ -256,7 +266,7 @@ struct IndexValue {
};
#if INDUCTOR_USE_VECTOR_TYPES()
template <typename T, uint64_t kChunkSize>
template <typename T, uint64_t kChunkSize = 0>
Welford<T> welford_combine(
Welford<T>& acc,
T& data,

View File

@@ -470,6 +470,10 @@ def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> b
return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr]
@deprecated(
"guard_size_oblivious will be removed. Consider using explicit unbacked handling \
potentially utilizing guard_or_false, guard_or_true, or statically_known_true"
)
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
"""
Perform a guard on a symbolic boolean expression in a size oblivious way.
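
For callers migrating off `guard_size_oblivious`, a minimal sketch of the suggested pattern using `guard_or_false` from `torch.fx.experimental.symbolic_shapes` (the helper function below is hypothetical):

```
import torch
from torch.fx.experimental.symbolic_shapes import guard_or_false

def maybe_squeeze_leading(x):
    # Previously: if guard_size_oblivious(x.shape[0] == 1): ...
    # guard_or_false answers False when the comparison involves unbacked
    # sizes and cannot be decided statically, instead of installing a guard.
    if guard_or_false(x.shape[0] == 1):
        return x.squeeze(0)
    return x

print(maybe_squeeze_leading(torch.randn(1, 3)).shape)  # torch.Size([3])
```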

View File

@@ -576,17 +576,6 @@ def insert_deferred_runtime_asserts(
if i0 in constrained_unbacked_symbols:
continue # constrain symbol just once
if i0 in shape_env.size_like:
if export:
graph.call_function(
torch.ops.aten.sym_constrain_range_for_size.default,
(expr_to_proxy[i0].node,),
)
else:
graph.call_function(
torch._check_is_size, (expr_to_proxy[i0].node,)
)
vr = shape_env.var_to_range[i0]
if vr.is_int and vr.upper == sys.maxsize - 1:
# treat upper bound == sys.maxsize - 1 for int symbols as +oo