Compare commits

...

259 Commits

Author SHA1 Message Date
6d21c7a155 Update (base update)
[ghstack-poisoned]
2025-10-20 10:40:33 +00:00
e8cb34dd52 [Inductor] support masked vectorization for the tail_loop for fp8 datatype (#163324)
**Summary:**
Support masked vectorization for the tail_loop for fp8 datatype.

**Example:**
```
import torch

def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```

**Generated code:**

- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = c10::convert<float>(tmp0);
                        auto tmp2 = static_cast<float>(100.0);
                        auto tmp3 = float(tmp1 - tmp2);
                        auto tmp4 = static_cast<float>(0.01);
                        auto tmp5 = float(tmp3 * tmp4);
                        auto tmp6 = c10::convert<float>(tmp5);
                        auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
                        auto tmp8 = float(tmp7 * tmp2);
                        auto tmp9 = std::nearbyint(tmp8);
                        auto tmp10 = float(tmp9 + tmp2);
                        auto tmp11 = static_cast<float>(-128.0);
                        auto tmp12 = max_propagate_nan(tmp10, tmp11);
                        auto tmp13 = static_cast<float>(127.0);
                        auto tmp14 = min_propagate_nan(tmp12, tmp13);
                        auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const at::Float8_e4m3fn* in_ptr0,
                       at::Float8_e4m3fn* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
                {
                    auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                    auto tmp1 = at::vec::convert<float>(tmp0);
                    auto tmp2 = static_cast<float>(100.0);
                    auto tmp3 = at::vec::Vectorized<float>(tmp2);
                    auto tmp4 = tmp1 - tmp3;
                    auto tmp5 = static_cast<float>(0.01);
                    auto tmp6 = at::vec::Vectorized<float>(tmp5);
                    auto tmp7 = tmp4 * tmp6;
                    auto tmp8 = (tmp7);
                    auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
                    auto tmp10 = tmp9 * tmp3;
                    auto tmp11 = tmp10.round();
                    auto tmp12 = tmp11 + tmp3;
                    auto tmp13 = static_cast<float>(-128.0);
                    auto tmp14 = at::vec::Vectorized<float>(tmp13);
                    auto tmp15 = at::vec::maximum(tmp12, tmp14);
                    auto tmp16 = static_cast<float>(127.0);
                    auto tmp17 = at::vec::Vectorized<float>(tmp16);
                    auto tmp18 = at::vec::minimum(tmp15, tmp17);
                    auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
                    tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
        buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
        # [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
        cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen, https://github.com/mingfeima, https://github.com/jansel
ghstack dependencies: #163316
2025-10-20 01:56:00 +00:00
e9d8973427 [Inductor] support masked vectorization for the tail_loop for float64 datatype (#163316)
**Summary:**
Support masked vectorization for the tail_loop for float64 datatype.

**Example:**
```
import torch

def fn(x):
    return x * x

x = torch.randn((22, 22), dtype=torch.double)
with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x)
```

**Generated code:**

- Before
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    for (int64_t x0_tail = static_cast<int64_t>(480L);x0_tail < static_cast<int64_t>(484L); x0_tail++)
                    {
                        auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
                        auto tmp1 = double(tmp0 * tmp0);
                        out_ptr0[static_cast<int64_t>(x0_tail)] = tmp1;
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```
- After
```
cpp_fused_mul_0 = async_compile.cpp_pybinding(['const double*', 'double*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(const double* in_ptr0,
                       double* out_ptr0)
{
    {
        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(484L); x0+=static_cast<int64_t>(16L))
        {
            {
                if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(480L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                }
                if(C10_UNLIKELY(x0 >= static_cast<int64_t>(480L) && x0 < static_cast<int64_t>(484L)))
                {
                    auto tmp0 = at::vec::VectorizedN<double,2>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                    auto tmp1 = tmp0 * tmp0;
                    tmp1.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(4L));
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (22, 22), (22, 1))
        buf0 = empty_strided_cpu((22, 22), (22, 1), torch.float64)
        # [Provenance debug handles] cpp_fused_mul_0:1
        cpp_fused_mul_0(arg0_1, buf0)
        del arg0_1
        return (buf0, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163316
Approved by: https://github.com/mingfeima, https://github.com/jansel
2025-10-20 01:41:38 +00:00
61d9a5180e [Fix XPU CI] [Inductor UT] Fix test cases broken by community. (#165714)
Fixes #165719, Fixes #165771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165714
Approved by: https://github.com/jansel
2025-10-19 23:59:04 +00:00
8a8329b51f [ATen] Switch order of blocked reduce when vectorize loads (#165178)
Performance benchmarking, perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full diff %     Non-Contig diff %    Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015684             0.017056               0.008287             0.016015             0.016929               0.008170                      -2.07%               +0.75%          +1.43%
(256, 256)           sum          0.015774             0.016638               0.007926             0.015811             0.016935               0.008330                      -0.23%               -1.75%          -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013385             0.025742               0.008629             0.013046             0.026005               0.008924                      +2.60%               -1.01%          -3.31%
(512, 512)           sum          0.013390             0.026059               0.009116             0.013054             0.025696               0.008952                      +2.57%               +1.41%          +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014213             0.015467               0.010334             0.013862             0.015082               0.010318                      +2.53%               +2.55%          +0.16%
(1024, 1024)         sum          0.014179             0.015446               0.010774             0.014132             0.015073               0.010350                      +0.33%               +2.47%          +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018234             0.019487               0.014812             0.018482             0.019397               0.014802                      -1.34%               +0.46%          +0.07%
(2048, 2048)         sum          0.018202             0.019529               0.015195             0.018122             0.019485               0.015129                      +0.44%               +0.23%          +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.033582             0.039378               0.030751             0.033810             0.039673               0.031019                      -0.67%               -0.74%          -0.86%
(4096, 4096)         sum          0.033604             0.039777               0.030809             0.033530             0.039386               0.031113                      +0.22%               +0.99%          -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.085824             0.091133               0.084200             0.085431             0.091364               0.084303                      +0.46%               -0.25%          -0.12%
(8192, 8192)         sum          0.085763             0.091442               0.084180             0.085508             0.091419               0.084595                      +0.30%               +0.03%          -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.146480             0.147666               0.138807             0.146515             0.147987               0.138930                      -0.02%               -0.22%          -0.09%
(8192, 16384)        sum          0.146446             0.147593               0.138559             0.146151             0.147982               0.139120                      +0.20%               -0.26%          -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266047             0.265386               0.253837             0.265648             0.265885               0.253652                      +0.15%               -0.19%          +0.07%
(8192, 32768)        sum          0.266093             0.265421               0.253890             0.265458             0.265591               0.253567                      +0.24%               -0.06%          +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.498632             0.508976               0.481865             0.498237             0.508777               0.481476                      +0.08%               +0.04%          +0.08%
(8192, 65536)        sum          0.498917             0.508202               0.481883             0.498104             0.508016               0.481972                      +0.16%               +0.04%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.957633             0.968519               0.938172             0.956766             0.968267               0.938196                      +0.09%               +0.03%          -0.00%
(8192, 131072)       sum          0.956972             0.968140               0.937741             0.957365             0.968404               0.938056                      -0.04%               -0.03%          -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.906661             1.928377               1.861846             1.907327             1.928811               1.862083                      -0.03%               -0.02%          -0.01%
(8192, 262144)       sum          1.905976             1.928362               1.862399             1.907098             1.928844               1.861782                      -0.06%               -0.02%          +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.956852             0.970101               0.936524             0.957263             0.969809               0.936965                      -0.04%               +0.03%          -0.05%
(4096, 262144)       sum          0.957117             0.969933               0.936247             0.956675             0.969451               0.936395                      +0.05%               +0.05%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.498813             0.511299               0.483415             0.498567             0.511482               0.483376                      +0.05%               -0.04%          +0.01%
(2048, 262144)       sum          0.498813             0.510834               0.483641             0.498875             0.511036               0.483338                      -0.01%               -0.04%          +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266157             0.276751               0.255192             0.265966             0.276808               0.255544                      +0.07%               -0.02%          -0.14%
(1024, 262144)       sum          0.266133             0.276709               0.255528             0.265658             0.276685               0.255287                      +0.18%               +0.01%          +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.085941             0.081184               0.087931             0.085591             0.080832               0.088008                      +0.41%               +0.44%          -0.09%
(512, 131072)        sum          0.085962             0.081107               0.088045             0.085882             0.081160               0.088024                      +0.09%               -0.07%          +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014203             0.045859               0.010310             0.013885             0.046132               0.010621                      +2.29%               -0.59%          -2.93%
(1000, 1000)         sum          0.014180             0.046165               0.010756             0.013893             0.046109               0.010338                      +2.07%               +0.12%          +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.012953             0.016751               0.008536             0.012977             0.016714               0.008916                      -0.18%               +0.22%          -4.26%
(1024, 129)          sum          0.013356             0.016806               0.008722             0.013003             0.017071               0.008611                      +2.71%               -1.55%          +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013075             0.016787               0.009102             0.013116             0.016769               0.008679                      -0.31%               +0.11%          +4.87%
(1024, 257)          sum          0.013092             0.016842               0.008786             0.013126             0.017128               0.008771                      -0.26%               -1.67%          +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013662             0.017412               0.010055             0.013659             0.017019               0.010033                      +0.02%               +2.31%          +0.22%
(1024, 587)          sum          0.013636             0.017473               0.010163             0.013642             0.017363               0.010101                      -0.04%               +0.63%          +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015276             0.027873               0.012531             0.015241             0.027783               0.012467                      +0.23%               +0.32%          +0.51%
(2048, 977)          sum          0.015345             0.027949               0.012192             0.015255             0.027839               0.012485                      +0.59%               +0.40%          -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.012806             0.014020               0.008291             0.013137             0.014309               0.007908                      -2.52%               -2.02%          +4.84%
(1024, 128)          sum          0.012769             0.014308               0.007924             0.012788             0.014236               0.008038                      -0.15%               +0.51%          -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014145             0.023049               0.009143             0.014104             0.023298               0.009501                      +0.29%               -1.07%          -3.77%
(8192, 128)          sum          0.014132             0.023082               0.009638             0.014107             0.023331               0.009244                      +0.18%               -1.07%          +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013420             0.025834               0.008949             0.013368             0.025724               0.008918                      +0.39%               +0.43%          +0.35%
(1024, 130)          sum          0.013300             0.025940               0.009113             0.013266             0.025419               0.008922                      +0.26%               +2.05%          +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.013993             0.017883               0.009661             0.014275             0.018220               0.009596                      -1.98%               -1.85%          +0.68%
(8192, 130)          sum          0.014026             0.018297               0.010066             0.014326             0.018257               0.009659                      -2.09%               +0.22%          +4.21%
================================================================================================================================================================================================================================================
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790, #165055
2025-10-19 23:39:05 +00:00
6b80c94901 [FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
2025-10-19 23:10:16 +00:00
8951df03de test_scaled_matmul_cuda: fix infer_scale_swizzle (#165788)
Extend #165747 fix to other cases.
Add parentheses to clarify operator precedence.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165788
Approved by: https://github.com/jeffdaily, https://github.com/slayton58
2025-10-19 21:42:01 +00:00
8139f33fa5 [dynamo] Add recompile reason for set_stance fail_on_recompile (#165445)
Fixes #163500

### Summary:
For `set_stance("fail_on_recompile")` failures will provide the reason why the recompilation occurred

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165445
Approved by: https://github.com/williamwen42
2025-10-19 21:12:19 +00:00
a88587348b [dynamo] Clean up assert in dynamo [1/N] (#165430)
Fixes some part of #162852 and #164878. These two issues have some relationship though.

* __->__ #165430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165430
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42

Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
2025-10-19 21:00:05 +00:00
633a3b7f67 Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
2025-10-19 19:20:45 +00:00
fa0db212e7 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-19 18:00:08 +00:00
15ff1cd28b Remove E721 suppression in flake8 (#165855)
Currently all files pass the E721 check.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165855
Approved by: https://github.com/albanD
2025-10-19 17:51:12 +00:00
c73f5080de Migrating some more callsites (#163580)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163580
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #165582
2025-10-19 15:52:17 +00:00
22ae059d32 AOTI util deprecated flow using the new tracer (#165582)
Reapply of https://github.com/pytorch/pytorch/pull/163260

AOTI utils expect free function sometimes so adjust export API to handle that, haven't seen any methods getting exported. Some AOTI flows also require we populate dynamo_flat_name_to_original_fqn so i just copy how it is done in eval_frame.py. I also cleaned up how we get rid of export_root and fixed some overcomplicated nn_module_stack handling in export code. The logic is simpler now thanks to @anijain2305 .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165582
Approved by: https://github.com/anijain2305
2025-10-19 15:52:16 +00:00
1b121d636e Fix AllocatorConfig parse roundup division bug (#165304)
* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165304
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289, #165291, #165298
2025-10-19 15:34:44 +00:00
1ba808dd97 Refine CUDA BackendStaticInitializer for allocator select (#165298)
* #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165298
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289, #165291
2025-10-19 15:34:44 +00:00
b2f5c25b27 Introduce a generic API torch._C._accelerator_setAllocatorSettings (#165291)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165291
Approved by: https://github.com/albanD
ghstack dependencies: #165288, #165289
2025-10-19 15:34:36 +00:00
a1114beed2 Deprecate overlapped functions in CUDAAllocatorConfig (#165289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165289
Approved by: https://github.com/albanD
ghstack dependencies: #165288
2025-10-19 15:34:26 +00:00
4888ed440e Refine Allocator Config error message friendly (#165288)
* __->__ #165288
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165288
Approved by: https://github.com/albanD
2025-10-19 15:34:17 +00:00
5d62b63a76 [BE] Use Python-3.14 GE build (#165804)
3.14 reached general availability on Oct 7th 2025, so we can remove all pre-release workarounds
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165804
Approved by: https://github.com/yangw-dev, https://github.com/Skylion007, https://github.com/cyyever
2025-10-19 11:45:10 +00:00
57ba575242 [BE][Ez]: Update torch.is_tensor documentation (#165841)
TypeIs propogates the isinstance check with the typing system. They are now equivalent.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165841
Approved by: https://github.com/albanD
2025-10-19 09:24:11 +00:00
ceb11a584d [BE]: Update kleidai submodule to v1.15.0 (#165842)
This mostly just adds a few new kernels and fixes some IMA and performance improvement of prev kernels. Also improves compiler support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165842
Approved by: https://github.com/albanD
2025-10-19 08:25:03 +00:00
33adb276fe [BE][Ez]: Update Eigen to 5.0.0. C++14 support and more! (#165840)
Update Eigen pin to 5.0.0 . Tons of new features and perf improvements. Most importantly updates minimum from C++03 to C++14 giving a ton of performance optimizations like properly implemented move operators, simplified code, etc. Also improved vectorization particularily on ARM. We really only use this library as a fallback for sparse operators, but still useful to update it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165840
Approved by: https://github.com/albanD
2025-10-19 08:00:06 +00:00
e939651972 [audio hash update] update the pinned audio hash (#165807)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165807
Approved by: https://github.com/pytorchbot
2025-10-19 04:45:20 +00:00
3255e7872b Enable all flake8-logging-format rules (#164655)
These rules are enabled by removing existing suppressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164655
Approved by: https://github.com/janeyx99, https://github.com/mlazos
2025-10-19 00:59:28 +00:00
c4f6619330 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.

- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-18 23:33:24 +00:00
f18041cca8 Fix missing closing quote in __init__.py documentation (#165827)
Title says it all.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165827
Approved by: https://github.com/Skylion007
2025-10-18 22:09:18 +00:00
35e51893bd Remove CUDA 11 workarounds for CUB_SUPPORTS_SCAN_BY_KEY and CUB_SUPPORTS_UNIQUE_BY_KEY (#164637)
`CUB_SUPPORTS_SCAN_BY_KEY` and `CUB_SUPPORTS_UNIQUE_BY_KEY` are true since CUDA 12. This PR removes the old branches and source files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164637
Approved by: https://github.com/ezyang
2025-10-18 20:05:54 +00:00
1f43d17ce6 Fix self assignment (#165816)
This PR removes assignments of the form `var=var`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165816
Approved by: https://github.com/jansel
2025-10-18 18:51:52 +00:00
032bed95cd Various C++ code fixes in LSAN integration (#165818)
This PR extracts the C++ code fixes from #154584, which are fixes in enabling LSAN.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165818
Approved by: https://github.com/ezyang
2025-10-18 17:59:23 +00:00
d14cbb4476 Add NVFP4 two-level scaling to scaled_mm (#165774)
Summary:

* Add second-level scaling dispatch to scaled_mm, tying into optional `alpha` passing
* Add two-level tests

Test Plan:

```
pytest -svv -k "nvfp4_global_scale" test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165774
Approved by: https://github.com/drisspg
2025-10-18 13:06:04 +00:00
f510d0dbc0 Clarrifying input output angle unit in the docs for trigonometric fun… (#161248)
…ctions

Fixes #[160995](https://github.com/pytorch/pytorch/issues/160995)

Modified the docs to clarify that input tensor  values for torch.sin, torch.cos and torch.tan should be in radians and the output tensor  values for torch.acos, torch.asin and torch.atan is in radians.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161248
Approved by: https://github.com/isuruf

Co-authored-by: Isuru Fernando <isuruf@gmail.com>
2025-10-18 11:53:48 +00:00
beb6b62e8c Revert "Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)"
This reverts commit 1b397420f22b22f90a1093233ecd9167656e50cb.

Reverted https://github.com/pytorch/pytorch/pull/165716 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165716#issuecomment-3418083391))
2025-10-18 09:15:49 +00:00
4740ce7787 [CP] Fix load balancer incorrectly assuming batch dimension exists (#165792)
https://github.com/pytorch/pytorch/pull/163617 removes the if/else statement to check if the input buffers have the batch dimension.

This PR fixes the issue and also adds a test.

In the future, we should explicitly ask users to unsqueeze the batch dimension. This is a BC of the existing contract but implicitly infers the batch dimension existence is not safe.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165792
Approved by: https://github.com/XilunWu
2025-10-18 09:11:16 +00:00
ad67170c8b [MPS] sparse matmuls (#165232)
Implements matmuls for sparse tensors. With this commit most of the core sparse operations should be implemented. Fixes:
https://github.com/pytorch/pytorch/issues/156540
https://github.com/pytorch/pytorch/issues/129842

Should be merged after:
https://github.com/pytorch/pytorch/pull/165102

To compare MPS and CPU, you can use this script:
```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)

    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```

## Tested on M1 Pro
<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet
2025-10-18 09:04:42 +00:00
fdab48a7c1 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 07:36:18 +00:00
a0948d4d23 [ROCm][inductor] autotune support for persistent reduction kernels (#163908)
After the removal of want_no_x_dim for persistent reduction kernels, we can improve the autotuning setup for persistent reduction kernels.

Currently even with tuning enable, filtering will only try a single config in many cases. Avoid filtering with autotune mode, and override MAX_BLOCK limit. Also we always include tiny_config when autotuning is enabled.

Contributions from several members of the AMD Inductor and Triton teams: @jataylo @iupaikov-amd @AmdSampsa @xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163908
Approved by: https://github.com/jansel, https://github.com/PaulZhang12
2025-10-18 07:33:24 +00:00
0bbdd6b8db [ROCm][inductor] heuristic improvements for pointwise kernels (#163197)
Heuristic improvements for pointwise kernels for MI350.

Contributions from several members of the AMD Inductor and Triton teams:
@jataylo @AmdSampsa @iupaikov-amd @@xiaohuguo2023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163197
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/jansel

Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
2025-10-18 07:23:41 +00:00
24520b8386 Revert "Enable all PIE rules on ruff (#165814)"
This reverts commit c79dfdc6550e872783aa5cb5fc9e86589bf18872.

Reverted https://github.com/pytorch/pytorch/pull/165814 on behalf of https://github.com/cyyever due to Need to cover more files ([comment](https://github.com/pytorch/pytorch/pull/165814#issuecomment-3417931863))
2025-10-18 07:21:08 +00:00
c79dfdc655 Enable all PIE rules on ruff (#165814)
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796  Enum contains duplicate value: {value}
PIE808  Unnecessary start argument in range
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
2025-10-18 06:40:12 +00:00
e595136187 Enable PLC1802 on ruff (#165813)
This PR enables ruff check `PLC1802`, which detects len calls on sequences in a boolean test context.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165813
Approved by: https://github.com/ezyang
2025-10-18 05:44:14 +00:00
aaac8cb0f5 [1/N] Add strict parameter to Python zip calls (#165531)
Add `strict=True/False` to zip calls in test utils. `strict=True` is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165531
Approved by: https://github.com/Skylion007
2025-10-18 05:26:33 +00:00
0f0b4bf029 [1/N] Remove unused header inclusion (#165763)
This PR removes unused header inclusion in C++ files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165763
Approved by: https://github.com/Skylion007
2025-10-18 05:23:11 +00:00
b8194268a6 Remove unnecessary noqa suppressions (#164106)
This PR removes unused `noqa` suppressions in Python code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164106
Approved by: https://github.com/albanD
2025-10-18 04:52:41 +00:00
f02e3947f6 Expand type checking to mypy strict files (#165697)
Expands Pyrefly type checking to check the files outlined in the mypy-strict.ini configuration file:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165697
Approved by: https://github.com/ezyang
2025-10-18 04:34:45 +00:00
9095a9dfae [CD] Apply the fix from #162455 to aarch64+cu129 build (#165794)
When trying to bring cu129 back in https://github.com/pytorch/pytorch/pull/163029, I mainly looked at https://github.com/pytorch/pytorch/pull/163029 and missed another tweak coming from https://github.com/pytorch/pytorch/pull/162455

I discover this issue when testing aarch64+cu129 builds in https://github.com/pytorch/test-infra/actions/runs/18603342105/job/53046883322?pr=7373.  Surprisingly, there is no test running for aarch64 CUDA build from what I see in 79a37055e7.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165794
Approved by: https://github.com/malfet
2025-10-18 04:16:24 +00:00
d9f94e0d7d [dynamo] Support fx.traceback.annotate as decorator (#165805)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165805
Approved by: https://github.com/Lucaskabela, https://github.com/SherlockNoMad, https://github.com/yushangdi
2025-10-18 03:58:11 +00:00
23417ae50f [Submodule] Bump FBGEMM to latest (#165544)
Summary:

* FBGEMM submodule updated to main
* CMake updated to reflect necessary changes
* Notably pulls in NVFP4 grouped gemm kernels

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165544
Approved by: https://github.com/cyyever, https://github.com/jeffdaily
2025-10-18 03:58:08 +00:00
e4d6c56ffb Improve dynamo graph capture stack trace for custom ops (#165693)
For a custom op
```
@torch.library.custom_op("my_lib::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y
```
ppl could call `torch.ops.my_lib.foo()` or directly call `foo()` in the `forward` of an `nn.Module`

These two calling conventions will lead to the same node in the output graph, but different stack traces.

When directly calling `foo()`, the displayed stack_trace in the graph will be
```
# File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs)
```
This is not useful so we filter it out.

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693
Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
2025-10-18 03:48:18 +00:00
017d2985f3 set unbacked bindings in reinplace pass for newly created nodes during generalize_scatter decomp (#164948)
Two fixes:
1. in rein_place pass, set unbacked bindings for newly created nodes.
2. In inductor, ComputeBuffer used to miss detecting some used symbols, fixed that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164948
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #164341
2025-10-18 03:20:30 +00:00
c6a8db0b9a Fix issues with generalized_scatter and setitem allocated unbacked symbols. (#164341)
Three fixes:
1. When doing t[u0] +=1  if u0 is unbacked we could allocate a new unbacked symbol during the the indexing of t[u0] (when we fake trace setitem), namely because meta_select does allocate a new unbacked symbol for the storage offset when we do not know if u0>=0 or u0<0.  but the output size/stride of setitem(), does not depend on that new symbol. it's self consumed in setitem so we shall ignore it.

2. Also when we trace through generalized_scatter the applications of the views could allocate unbacked symints
but those do not effect final output, we also shall ignore them.

3.Before accessing strides in lowering we shall materialize.

Address  https://github.com/pytorch/pytorch/issues/114293 and https://github.com/pytorch/pytorch/issues/131911

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164341
Approved by: https://github.com/bobrenjc93
2025-10-18 03:20:30 +00:00
de09bab4b6 [BE]: Update cudnn frontend submodule to 1.15.0 (#165776)
Update cudnn frontend submodule to 1.15.0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165776
Approved by: https://github.com/eqy
2025-10-18 02:23:27 +00:00
c137e222d4 .venv/ in .gitignore (#165418)
`uv venv` creates venv in `.venv/` directory. So, it's useful to have `.venv/` in `.gitignore`, since perhaps more people are using `uv` in their work. As per comment 3592f5f4e5 (diff-bc37d034bad564583790a46f19d807abfe519c5671395fd494d8cce506c42947)

uv docs  that confirms it: https://docs.astral.sh/uv/pip/environments/#using-arbitrary-python-environments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165418
Approved by: https://github.com/ezyang
2025-10-18 02:00:52 +00:00
cf3a787bbc [annotate] Annotate bw nodes before eliminate dead code (#165782)
Fixes https://github.com/pytorch/torchtitan/pull/1907

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165782
Approved by: https://github.com/SherlockNoMad
2025-10-18 01:54:31 +00:00
de3da77cf7 Thread deterministic config vars to subproc compilation (#165729)
# Summary

TIL (AFTER WAYYYY TOO MUCH INSANITY), that we do not serialize the full set of configs for the subproc compilation.

I found this while working on Flex-attention determinism: https://github.com/meta-pytorch/attention-gym/pull/168

might be good to audit if we need to thread through any more

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165729
Approved by: https://github.com/shunting314, https://github.com/eellison
2025-10-18 01:25:50 +00:00
543ddbf44c [ONNX] Support renaming in dynamic axes to shapes conversion (#165769)
Discovered in ##165748

This PR also deprecates the conversion. ONNX exporter team does not intend to maintain the conversion in long term.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165769
Approved by: https://github.com/justinchuby
2025-10-18 01:11:20 +00:00
e9f4999985 [Code Clean] Replace std::runtime_error with TORCH_CHECK (#165305)
Fixes part of #148114

Including:

- torch/csrc/distributed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165305
Approved by: https://github.com/FFFrog, https://github.com/albanD
2025-10-18 01:08:44 +00:00
29b029648e Fixed issue with GradTrackingTensor not properly propagating sparse layout (#165765)
Fixes #164286

Fixed issue with GradTrackingTensor not properly propagating sparse layout.

@ezyang @jcaip
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165765
Approved by: https://github.com/ezyang
2025-10-18 01:00:53 +00:00
a25a649e70 [Mem Snapshot] Add Metadata Field (#165490)
Summary:
The implementation adds the ability to:

Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()

Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
2025-10-17 23:46:02 +00:00
69c33898fa Revert "[Inductor][CuTeDSL] Move load_template up two directories (#165347) (#165576)"
This reverts commit febb60323018948b2b9d2cff35b3cc4e0d0c55c8.

Reverted https://github.com/pytorch/pytorch/pull/165576 on behalf of https://github.com/seemethere due to This was actually reverted internally, current PR is linked to a stale diff so diff train tools think that this is landed via co-dev when it was actually reverted ([comment](https://github.com/pytorch/pytorch/pull/165576#issuecomment-3417510146))
2025-10-17 23:33:17 +00:00
1b397420f2 Enable more DTensor tests in local tensor mode and fix more integration issues (#165716)
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.

- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
2025-10-17 23:28:22 +00:00
fe80f03726 Add B200 files to labeler and update codeowners (#165767)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165767
Approved by: https://github.com/slayton58
2025-10-17 23:24:17 +00:00
e50dc40d28 Revert "Update gm.print_readable to include Annotation (#165397)"
This reverts commit 7a657700131f31577544e93587eb339618677e97.

Reverted https://github.com/pytorch/pytorch/pull/165397 on behalf of https://github.com/malfet due to I don't know how/why, but it breaks windows tests, see 2e22b1a61e/1 ([comment](https://github.com/pytorch/pytorch/pull/165397#issuecomment-3417428128))
2025-10-17 22:35:50 +00:00
2e22b1a61e [pytorch] Composite backend potential fix for is_backend_available (#165061)
Summary: `is_backend_available` takes in a string and expects it to only be backend, if its given a composite (device:backend) string, it fails.

Reviewed By: prashrock

Differential Revision: D81886736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165061
Approved by: https://github.com/H-Huang
2025-10-17 22:06:36 +00:00
616c6bdf8f [dynamo][ac] Config flag to allow eager and compile AC divergence for side-effects (#165775)
Eager AC/SAC reapplies the mutations (like global dict mutations) in the backward during the recomputation of forward. torch.compile has no easy way to reapply python mutations in the backward. But many users might be ok to skip reapplication of side effects in the backward. They can set this config flag to accept this eager and compile divergence.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165775
Approved by: https://github.com/zou3519
ghstack dependencies: #165734
2025-10-17 22:04:19 +00:00
c18ddfc572 [dynamo][easy] Support torch.accelerator.current_accelerator (#165734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165734
Approved by: https://github.com/Skylion007
2025-10-17 22:04:19 +00:00
86ebce1766 [precompile] Pass tensor_to_context to backend. (#165702)
Summary:

Fixing a VLLM issue https://github.com/vllm-project/vllm/issues/27040 where
aot precompile fails on some models using symbolic shapes in inductor.

Test Plan:
pp HF_HUB_DISABLE_XET=1 VLLM_ENABLE_V1_MULTIPROCESSING=0 VLLM_USE_AOT_COMPILE=1 vllm bench latency --model microsoft/DialoGPT-small --input-len 128 --output-len 256 --num-iters 50 --dtype float16

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165702
Approved by: https://github.com/tugsbayasgalan
2025-10-17 21:52:04 +00:00
8cb2fb44f2 [Inductor] Support fallback for all gemm like ops (#165755)
Summary: Fill op_override field for bmm aten ops so they can be converted properly in the wrapper_fxir backend

Reviewed By: StellarrZ

Differential Revision: D84840948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165755
Approved by: https://github.com/blaine-rister
2025-10-17 21:08:29 +00:00
ab65498d71 Fix _StridedShard incorrect split (#165533)
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)

Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.

Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
2025-10-17 20:54:46 +00:00
06d324365c Revert "Escaped html tags name and target to appear as strings (#165543)"
This reverts commit 080365b7d82a3c99c995cab6dc912b7dfe22aa41.

Reverted https://github.com/pytorch/pytorch/pull/165543 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165543#issuecomment-3417102048))
2025-10-17 20:45:48 +00:00
6c9c6e0936 Enable C407 of flake8 (#165046)
This PR enables C407 on flake8. The description is `C407` is `Unnecessary list comprehension - ‘<builtin>’ can take a generator`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165046
Approved by: https://github.com/albanD
2025-10-17 20:15:39 +00:00
2bcd892c86 [distributed] Replace assert statements in distributed checkpoint with explicit checks (#165256)
Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256
Approved by: https://github.com/albanD
2025-10-17 20:14:35 +00:00
75e2a9fae3 [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 20:10:53 +00:00
a16fd6b488 [NVSHMEM][Triton] Fix NVSHMEM triton test for wacky world sizes (#165704)
Currently assumes divisible by 4? world size

Not as slick as the old setup code but more general

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165704
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
2025-10-17 19:33:26 +00:00
382b0150de [docs] Add usage examples to ConvTranspose1d docstring (#165618)
Fixes #165615

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165618
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:11:57 +00:00
a664b299ac Update docs for torch.mode (#165614)
Currently the docs for `torch.mode` include a note:

`This function is not defined for torch.cuda.Tensor yet.`

However with `torch==2.7.1+cu126` when I try to get the mode of a Tensor that is in cuda memory, I do not face any issues:

```
>>> a = torch.tensor([0, 2, 1, 1, 1, 3, 3])
>>> a.mode()
torch.return_types.mode(
values=tensor(1),
indices=tensor(4))
>>> a.cuda().mode()
torch.return_types.mode(
values=tensor(1, device='cuda:0'),
indices=tensor(4, device='cuda:0'))
```

Am I misunderstanding the note? If not, I suggest removing it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165614
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:06:33 +00:00
9c12651417 Improve error message for non-positive groups in convolution (#165669)
Prevents from segmentation fault for invalid groups value in convolution.

Fixes #142835

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165669
Approved by: https://github.com/mikaylagawarecki
2025-10-17 19:06:05 +00:00
08c97b4a1f Don't run compile inside kernel invocation (#165687)
When we call torch.compile during fake tensor prop, we shouldn't actually compile because we can't guarantee that the compiled artifact can be fake tensor prop-d. (for example, inductor backend). Instead we should just skip compiling. However, the inner compile will be triggered when being executed in runtime.

Fixes: https://github.com/pytorch/pytorch/issues/151328

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165687
Approved by: https://github.com/zou3519
2025-10-17 19:03:57 +00:00
fae74cd52f Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"
This reverts commit a032510db38e8331afa08f7635d146f9cefdd0ab.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3416718767))
2025-10-17 18:55:53 +00:00
7a65770013 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305
2025-10-17 18:35:18 +00:00
e4454947e2 Widen ops support to take in IntHOArrayRef vs only std::vec (#165152)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
2025-10-17 18:32:39 +00:00
3806e9767b Refactor out headeronly ArrayRef (#164991)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991
Approved by: https://github.com/swolchok
2025-10-17 18:32:39 +00:00
b08d8c2e50 Revert "[DebugMode][2/N] add nn.Module tracking (#165498)"
This reverts commit 45afaf08a14ab760d86ea80dea6d50cec8626513.

Reverted https://github.com/pytorch/pytorch/pull/165498 on behalf of https://github.com/seemethere due to First part of the stack was reverted so will need to revert this too ([comment](https://github.com/pytorch/pytorch/pull/165498#issuecomment-3416618198))
2025-10-17 18:22:48 +00:00
ca5b7f8ded torch.compile: populate compiler_config (#165581)
Summary: This starts writing the compiler_config metadata into logger

Test Plan:
Modified existing test case to make sure this is not null.
(Also eyeballed what we're logging tomake sure it's reasonable

Reviewed By: masnesral

Differential Revision: D84014636

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165581
Approved by: https://github.com/masnesral
2025-10-17 18:21:18 +00:00
9a71d96256 Revert "[DebugMode][1/N] refactor logs into _DebugCalls (#165376)"
This reverts commit 556fc09a9f67f24ca5591ec049c5d0c347c5f62a.

Reverted https://github.com/pytorch/pytorch/pull/165376 on behalf of https://github.com/seemethere due to This is failing for internal tests, see D84877379 for more context ([comment](https://github.com/pytorch/pytorch/pull/165376#issuecomment-3416570407))
2025-10-17 18:08:59 +00:00
0d4c2b71e8 [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-17 17:57:51 +00:00
d659bbde62 [DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
2025-10-17 17:57:51 +00:00
58879bfafa [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-17 17:57:51 +00:00
a032510db3 shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529

To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch.

This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.

For more info:  [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/Skylion007, https://github.com/syed-ahmed, https://github.com/kwen2501
2025-10-17 17:55:03 +00:00
39e0a832c9 Fix B200 test fails in scaled_mm (#165747)
Summary:

PR #165528 changes some scale/swizzle inference behavior in scaled_mm
tests - mxfp8 tests on Blackwell can get incorrectly classified,
resulting in failures.

Fix the scale/swizzle inference code to prevent this.

Fixes https://github.com/pytorch/pytorch/issues/165743

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

@jagadish-amd @jeffdaily @drisspg

Subscribers:

@Aidyn-A

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlaytonmeta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165747
Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/jeffdaily
2025-10-17 17:52:19 +00:00
dd3b48e85d Fix bug with serialization after AOTAutogradCache hit (#165474)
Fixes #165447

On AOTAutogradCache load, the serialization function we pick is just lambda: self, because the object itself is an AOTAutogradCacheEntry. However, this isn't safe, because `wrap_post_compile` will make `self` unserializable, since it needs to load triton kernels and stuff!

So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object to begin with, and return that object on a call to serialize(). This effectively makes it so that we save a copy of the pre-hydrated artifact, without needing to do an eager copy until someone actually calls `serialize`.

Test Plan:

Run

```py
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 8)
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

device = "cuda"
m = M().to(device)
sample_inputs = (torch.randn(2, 2, device=device),)
eager_out = m(*sample_inputs)

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn_path = "./m.pt"
    compiled_fn = torch.compile(
        m,
        fullgraph=True
    ).forward.aot_compile((sample_inputs, {}))

    compiled_fn.save_compiled_function(compiled_fn_path)
    torch._dynamo.reset()
    with torch.compiler.set_stance("fail_on_recompile"):
        with open(compiled_fn_path, "rb") as f:
            loaded_fn = torch.compiler.load_compiled_function(f)

assert loaded_fn is not None

compiled_out = loaded_fn(m, *sample_inputs)

assert torch.allclose(eager_out, compiled_out)
```

twice, see that it succeeds.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474
Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
2025-10-17 17:47:24 +00:00
cff1b20771 Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (#164923)
The initial fix for inspect.signature uses not a right approach (https://github.com/pytorch/pytorch/pull/164349#pullrequestreview-3306614010). As @williamwen42 suggests (https://github.com/pytorch/pytorch/pull/164349#issuecomment-3379222885) we can just for now get rid of `inspect.signature` call in flex_attention to resolve this high priority issue (https://github.com/pytorch/pytorch/issues/164247#issuecomment-3378673179). In this PR I did exactly this - limited the scope of fix to just computing `num_positional_args` in `flex_attention._get_mod_type` based on properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing so I added them)

Fixes #164247

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164923
Approved by: https://github.com/williamwen42
2025-10-17 17:44:45 +00:00
da8517fa63 [ROCm][CI] upgrade wheels to 7.0.2 and 6.4.4 patch release (#165756)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165756
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-17 17:41:19 +00:00
45afaf08a1 [DebugMode][2/N] add nn.Module tracking (#165498)
Uses ModTracker to record nn.Module entries, much like CommDebugMode.

Can be switched on with `DebugMode(record_nn_module=True)`:
```
    [nn.Mod] Bar
      [nn.Mod] Bar.abc
        [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
        [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])"""
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165498
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165376
2025-10-17 17:39:48 +00:00
080365b7d8 Escaped html tags name and target to appear as strings (#165543)
Fixes small typo in markdown documentation file - Added escape characters to precede tag pattern.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165543
Approved by: https://github.com/mikaylagawarecki
2025-10-17 17:35:18 +00:00
2928c5c572 Revert "Pyrefly suppressions 2 (#165692)"
This reverts commit 43d78423ac224cce432bf34ed9627035169d5433.

Reverted https://github.com/pytorch/pytorch/pull/165692 on behalf of https://github.com/seemethere due to This is causing merge conflicts when attempting to land internally, see D84890919 for more details ([comment](https://github.com/pytorch/pytorch/pull/165692#issuecomment-3416397240))
2025-10-17 17:13:04 +00:00
630520b346 [dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #165683, #165706
2025-10-17 17:02:18 +00:00
1dc9a05d03 [dynamo][user_defined] Replace UserFunctionVariable with VariableTracker build (#165706)
Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165706
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #165683
2025-10-17 17:02:18 +00:00
bfcdbd0a97 fix wrong accuracy_status when exception. (#165731)
When I debug `XPU` accruacy issue, I found the script output wrong accuracy_status.
When the `try` block raise an exception, we should process the exception, but not return the `fail_accuracy`.

Before fixing, it returned as `fail_accuracy`:
<img width="1109" height="216" alt="image" src="https://github.com/user-attachments/assets/385c354f-fbf6-48e4-a1be-3e37e987341b" />

After fixing, it returned the exception message:
<img width="1101" height="292" alt="image" src="https://github.com/user-attachments/assets/f18c0e3c-8358-4ec7-a6bb-c2e01b69d27f" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165731
Approved by: https://github.com/Stonepia, https://github.com/chuanqi129, https://github.com/Lucaskabela
2025-10-17 16:37:06 +00:00
faff826a46 Revert "[ROCm] new implementation of upsample_bilinear2d_backward (#164572)"
This reverts commit 53f9ae0e50d4dcc47f2ca4bf854803f9d4f875ae.

Reverted https://github.com/pytorch/pytorch/pull/164572 on behalf of https://github.com/seemethere due to Looks like this is failing in our internal builds, will post a suggestion for a fix but want you to double verify that this behavior is correct ([comment](https://github.com/pytorch/pytorch/pull/164572#issuecomment-3416262676))
2025-10-17 16:27:59 +00:00
85c5433d38 Revert "Fix _StridedShard incorrect split (#165533)"
This reverts commit dfc8a1c5ddc8401197e9ab546e03b0f745edc27b.

Reverted https://github.com/pytorch/pytorch/pull/165533 on behalf of https://github.com/seemethere due to Causing a merge conflict internally, see D84829161 ([comment](https://github.com/pytorch/pytorch/pull/165533#issuecomment-3416143176))
2025-10-17 15:57:01 +00:00
935ccdbe75 [MPS] Fix internal assertion in torch.linalg.solve for singular matrices (#165254)
Fixes #163962 by special casing MPS in the negative status code branch in `_linalg_check_errors`.

Checks if info is [`MPSMatrixDecompositionStatus.singular`](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixdecompositionstatus/singular) (which has a raw value of -2). I didn't find an official Apple source with this raw value (besides printing the enum value), so I'm not sure if we can (or should) depend on it? Is there a way to directly get the Objective-C enum value in C++?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165254
Approved by: https://github.com/malfet
2025-10-17 15:35:49 +00:00
3af2f0c12a [inductor] require shape in TritonCSEVariable (#162275)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162275
Approved by: https://github.com/mlazos
ghstack dependencies: #164158
2025-10-17 14:47:45 +00:00
6ece527fc5 [CI] Add aarch64 operator benchmark (#165585)
Running on Graviton4
Skip ConvTranspose1d benchmarks if PyTorch is compiled with ACL, due to https://github.com/pytorch/pytorch/issues/165654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165585
Approved by: https://github.com/huydhn
2025-10-17 14:42:14 +00:00
ce29d0d796 [ATen] Vectorize 8 elements on 16 bit data types for sum/mean (#165055)
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
**BF16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
2025-10-17 13:39:36 +00:00
7231118db3 Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 13:24:46 +00:00
5d4da26ed0 Revert "[export] preserve_node_meta by default (#165524)"
This reverts commit fdd560afd1d413a9f814cbf7cc2a72e0d39b0117.

Reverted https://github.com/pytorch/pytorch/pull/165524 on behalf of https://github.com/lw due to test/functorch/test_control_flow.py::TestControlFlowTraced::test_cond_symint_closure [GH job link](https://github.com/pytorch/pytorch/actions/runs/18586312291/job/52991654051) [HUD commit link](fdd560afd1) ([comment](https://github.com/pytorch/pytorch/pull/165524#issuecomment-3415352522))
2025-10-17 12:27:17 +00:00
574c9fc950 Revert "Remove torch.serialization entries from the doc ignore list (#160224)"
This reverts commit 9fe3b2afbeff12080b483af1ee23e1c9d9fb0421.

Reverted https://github.com/pytorch/pytorch/pull/160224 on behalf of https://github.com/lw due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/18588004962/job/52997748336) [HUD commit link](9fe3b2afbe) ([comment](https://github.com/pytorch/pytorch/pull/160224#issuecomment-3415345175))
2025-10-17 12:24:08 +00:00
80d2ca7566 Revert "[annotate] add annotate_fn function decorator (#165703)"
This reverts commit f1d882212afc3a73ce1e319d80b6406f9dc4a0c8.

Reverted https://github.com/pytorch/pytorch/pull/165703 on behalf of https://github.com/lw due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/18585518705/job/52989521797) [HUD commit link](f1d882212a) ([comment](https://github.com/pytorch/pytorch/pull/165703#issuecomment-3415073467))
2025-10-17 11:23:13 +00:00
4a22139eea [MPS][BE] Fix unused variable warning (#165726)
Namely this one
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:19:18: warning: unused variable 'output_sizes' [-Wunused-variable]
  constant auto& output_sizes = shared_params.output_sizes;
                 ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:85:1: note: in instantiation of function template specialization 'cat<long, float, float>' requested here
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:69:3: note: expanded from macro 'REGISTER_CAT_FOR_INDEX_TYPE'
  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);  \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:55:3: note: expanded from macro 'REGISTER_CAT_OP_ALL_INPUT_TYPES'
  REGISTER_CAT_OP(I, float, T_out);               \
  ^
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Shape.metal:47:15: note: expanded from macro 'REGISTER_CAT_OP'
  kernel void cat<I, T_in, T_out>(                               \
```

Repeated about 20-30 times
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165726
Approved by: https://github.com/Skylion007
2025-10-17 11:16:21 +00:00
cb6e4d7d82 User-passed alpha to scaled_gemm (#165563)
Summary:

Add optional user-passed `alpha` argument to
`at::cuda::blas::scaled_gemm`, necessary for two-level-scaled NVFP4 gemm
calls (where the global de-scales are folded into the `alpha` argument.

Global de-scales are naturally device tensors, but using cublas'
device-pointer mode for `alpha`/`beta` has an interesting lifetime
implication - the `alpha` tensor must be valid & correct until the end
of the matmul call, *not* just the launch (as for host values). To
enable this, I added device-constant memory for `one` and `zero`, along
with a statically-held single-fp32-value tensor, which is valid from the
first passed-`alpha` invocation of `scaled_gemm` to the end of the
program. User-passed values are copied into this perpetual buffer to
ensure lifetime requirements are met.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165563
Approved by: https://github.com/drisspg, https://github.com/eqy
2025-10-17 09:42:33 +00:00
202f83dc4e [ROCm][layer_norm] Use __builtin_amdgcn_rcpf(x) instead of 1.f/x (#165589)
Replace (more) exact calculation with hardware approximation.

Benefits:
Reduced code size.
Improved performance for certain scenarios.

Experiments show low reduction in precision.
Experiments show no significant performance regressions. bfloat16 as well as float16 related calculations may benefit largely from this change.

Co-author: @mhalk @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165589
Approved by: https://github.com/jeffdaily
2025-10-17 09:12:30 +00:00
9fe3b2afbe Remove torch.serialization entries from the doc ignore list (#160224)
Follows the approach done in #158581
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160224
Approved by: https://github.com/janeyx99
2025-10-17 09:06:09 +00:00
d0c24b392c [APF Logging][Error Trait] To fill the errorTraits for ChildFailedError with signal abort (re-attempt of #165476) (#165688)
**Summary**
Land @guoding83128 's PR https://github.com/pytorch/pytorch/pull/165476 on his behalf due to EasyCLA blocking.
Refer his original PR for detail. But in short, elastic leaves 'errorTraits' as unknown when the error dump file is missing,
this PR adds a "system terminated error" to such case so the internal scuba table can correctly aggregate.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165688
Approved by: https://github.com/fduwjj
2025-10-17 08:23:27 +00:00
b44fb14906 Remove unused parameter when query extension attribute (#165623)
# Motivation
This code is no longer needed since SYCL compiler 2025.0. We are now using compiler 2025.2 (two tool uplifts later), so it can be safely removed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165623
Approved by: https://github.com/EikanWang
ghstack dependencies: #165622
2025-10-17 08:16:13 +00:00
51348c0219 Give a friendly message for older Intel GPU (#165622)
# Motivation
Notify the user if the GPU is older than officially supported. This provides a friendly warning that the GPU may work, but the experience could be unstable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165622
Approved by: https://github.com/EikanWang
2025-10-17 08:16:13 +00:00
fdd560afd1 [export] preserve_node_meta by default (#165524)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165524
Approved by: https://github.com/malaybag
2025-10-17 07:55:28 +00:00
e925dfcc6b Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-17 07:27:11 +00:00
f1d882212a [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 07:18:47 +00:00
24879f0de9 [dynamo] Use Variable Builder to build the property fget object (#165683)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165683
Approved by: https://github.com/ezyang, https://github.com/williamwen42
2025-10-17 06:29:24 +00:00
9e94ec76b8 Revert "Turn some const variables into constexpr in C++ code (#165401)"
This reverts commit 5b2afe4c5dc87786ca65bf22ca9a78f7c21a33a4.

Reverted https://github.com/pytorch/pytorch/pull/165401 on behalf of https://github.com/seemethere due to This is breaking test/distributions/test_distributions.py::TestDistributions::test_binomial_sample on HUD, see 5b2afe4c5d ([comment](https://github.com/pytorch/pytorch/pull/165401#issuecomment-3414023134))
2025-10-17 06:14:09 +00:00
364624e209 [codemod][lowrisk] Remove unused exception parameter from some files (#165700)
Summary:
`-Wunused-exception-parameter` has identified an unused exception parameter. This diff removes it.

This:
```
try {
    ...
} catch (exception& e) {
    // no use of e
}
```
should instead be written as
```
} catch (exception&) {
```

If the code compiles, this is safe to land.

Test Plan: Sandcastle

Differential Revision: D84868162

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165700
Approved by: https://github.com/Skylion007
2025-10-17 05:30:06 +00:00
7e150467f7 allow providing full fr trace path (#165639)
Summary:
- allow users to specify the full path instead of fr suffixing the rank id
- this will be used by torchft to provide the global rank id accross all replicas
- we can't just prefix the replica id because analysis tool expects the file name to provide a unique integer

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165639).
* #165638
* #165640
* #165677
* #165642
* __->__ #165639

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165639
Approved by: https://github.com/fduwjj
2025-10-17 04:43:44 +00:00
43d78423ac Pyrefly suppressions 2 (#165692)
This is the last directory to opt in for the regular mypy.ini file. Will put up a diff to remove unused ignores before making sure we're also type checking all the files in the mypy strict configurations

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165692
Approved by: https://github.com/oulgen
2025-10-17 04:15:25 +00:00
fcbde24c1c [ONNX] Remove common imports from torchlib (#165156)
The Rank and IsScalar functions are no longer used in the torchlib. Requires onnxscript v0.5.4

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165156
Approved by: https://github.com/Skylion007, https://github.com/cyyever
2025-10-17 03:25:34 +00:00
861cdb887b use statically_known_leq & *=2 instead of bound_sympy in persistent rblock (#165657)
While these should be equivalent, we've found instances where they are not, and an error was caused. update until we figure out underlying issue.

Differential Revision: [D84835898](https://our.internmc.facebook.com/intern/diff/D84835898)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165657
Approved by: https://github.com/bobrenjc93
2025-10-17 02:48:03 +00:00
3154482072 [CUDA][cuBLAS] Only xFail addmm with reduced precision reductions on non-RTX skus (#165379)
RTX Blackwells don't behave quite like their datacenter counterparts here

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165379
Approved by: https://github.com/Skylion007
2025-10-17 02:45:07 +00:00
9fccbdd4f0 Fix incorrect function signature in template (#165567)
Summary:
In https://github.com/pytorch/pytorch/pull/148305 we refactored the grid
argument out, but it's not reflected in our template.

Test Plan:
Included in commit.
python test/inductor/test_aot_inductor.py
AOTInductorTestABICompatibleGpu.test_cond_symint_input_disable_one_pass_cuda

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165567
Approved by: https://github.com/desertfire
2025-10-17 02:40:56 +00:00
7dabfb07cb [torchfuzz] add support for --stop-at-first-failure flag (#165529)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165529
Approved by: https://github.com/pianpwk
ghstack dependencies: #164749
2025-10-17 02:18:07 +00:00
d0add0be43 [torchfuzz] check in some more ignore regexes (#164749)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164749
Approved by: https://github.com/pianpwk
2025-10-17 02:18:07 +00:00
11e2084308 Revert "[Mem Snapshot] Add Metadata Field (#165490)"
This reverts commit 5b3ea758951558e7d9f681ae784acb57eaa07910.

Reverted https://github.com/pytorch/pytorch/pull/165490 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165490#issuecomment-3413491091))
2025-10-17 02:01:53 +00:00
9726553653 [BE][Ez]: Use sys.executable instead of hardcoded Python (#165679)
Handles edgecase to ensure proper interpreter is called. Inspired by #165633
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165679
Approved by: https://github.com/FindHao
2025-10-17 01:07:40 +00:00
d82527b32a [Windows] Add AOTI cross-compilation CI (#165573)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165573
Approved by: https://github.com/malfet
ghstack dependencies: #165560
2025-10-17 01:05:35 +00:00
5d9b024276 Add mingw to docker (#165560)
Add mingw to `pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11` docker image to support AOTI cross-compilation

This PR will make docker container rebuild, and upgrade python version from 3.13.7 to 3.13.8. and it relies on https://github.com/pytorch/pytorch/pull/165667
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165560
Approved by: https://github.com/malfet
2025-10-17 00:47:01 +00:00
5b2afe4c5d Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 00:40:11 +00:00
b2953f5643 [9/N] Apply ruff UP035 rule (#165515)
This is follow-up of #165214 to continue applying ruff UP035 rule to the code base.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165515
Approved by: https://github.com/Lucaskabela
2025-10-17 00:09:51 +00:00
470e2f61c3 Revert "[Fix] Use sys.executable instead of hardcoded python (#165633)"
This reverts commit 37f3ba274a8ccebc6b3409f52cf068a8b23617d4.

Reverted https://github.com/pytorch/pytorch/pull/165633 on behalf of https://github.com/malfet due to Looks like it broke test_collect_callgrind in slow workflows, see e0fe37fa68/1 ([comment](https://github.com/pytorch/pytorch/pull/165633#issuecomment-3413290813))
2025-10-17 00:06:40 +00:00
e0fe37fa68 [MPS] Move torch.cat impl to Metal (#165373)
After this change, all of the cases tested in [this performance measurement script](10de64c5ac/cat/perf0.py) take either roughly the same runtime or less.

Before:

```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000857 ms, 0.016098 ms, 0.05, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000858 ms, 0.014861 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000806 ms, 0.015145 ms, 0.05, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000829 ms, 0.015355 ms, 0.05, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000582 ms, 1.02, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001076 ms, 0.022387 ms, 0.05, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000708 ms, 0.022300 ms, 0.03, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000640 ms, 0.014367 ms, 0.04, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000777 ms, 0.027506 ms, 0.03, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003383 ms, 0.269277 ms, 0.01, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.526138 ms, 0.650852 ms, 0.81, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.444091 ms, 0.628630 ms, 0.71, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.011870 ms, 0.989525 ms, 2.03, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.100653 ms, 0.948178 ms, 3.27, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.112174 ms, 0.954174 ms, 3.26, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```

After:

```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000790 ms, 0.013111 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000800 ms, 0.014419 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000748 ms, 0.010019 ms, 0.07, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000767 ms, 0.010063 ms, 0.08, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000591 ms, 1.00, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001220 ms, 0.009763 ms, 0.12, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000739 ms, 0.006203 ms, 0.12, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000647 ms, 0.009905 ms, 0.07, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000753 ms, 0.007818 ms, 0.10, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003823 ms, 0.192723 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.576564 ms, 0.733920 ms, 0.79, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.462957 ms, 0.692799 ms, 0.67, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.017181 ms, 0.968345 ms, 2.08, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.203508 ms, 0.986382 ms, 3.25, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.181249 ms, 1.007773 ms, 3.16, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```

Fixes #165350
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165373
Approved by: https://github.com/kulinseth, https://github.com/malfet
2025-10-17 00:03:04 +00:00
d2c82bafb7 Revert "158232 Fix autocast cache incorrectly retaining no_grad state (#165068)"
This reverts commit 5daef30b26b794d237fbbc399c1d47ec0380200a.

Reverted https://github.com/pytorch/pytorch/pull/165068 on behalf of https://github.com/jeffdaily due to This broke ROCm CI. test/test_transformers.py::TestTransformersCUDA::test_transformerencoder_fastpath_use_torchscript_False_enable_nested_tensor_True_use_autocast_True_d_model_256_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18572589089/job/52952074008) [HUD commit link](5daef30b26) ([comment](https://github.com/pytorch/pytorch/pull/165068#issuecomment-3413184445))
2025-10-16 23:08:27 +00:00
98a488c9aa Start recording inductor provenance (#162669)
Summary:
This stores information on where fx graphs come from, which makes it
significantly easier to debug.

One outstanding question

1) I only stored the kernel stack traces, do we also want the node mappings?

Test Plan:
I wrote a explicit logging test which makes a module, fx traces it, compiles it, and makes sure the logging infomration shows up.

```
clr@devvm17763 ~/fbsource/fbcode/caffe2/test/dynamo
 % buck2 test @//mode/opt fbcode//caffe2/test/dynamo:test_dynamo -- test_utils

File changed: fbsource//xplat/caffe2/test/dynamo/test_utils.py
File changed: fbcode//caffe2/test/dynamo/test_utils.py
Buck UI: https://www.internalfb.com/buck2/528dea32-2416-4a62-a1ec-39f3c0efdd2e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229324015574003
Network: Up: 0B  Down: 0B
Executing actions. Remaining     0/2
Command: test.
Time elapsed: 17.3s
Tests finished: Pass 16. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D82037582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162669
Approved by: https://github.com/yushangdi
2025-10-16 23:05:31 +00:00
5b3ea75895 [Mem Snapshot] Add Metadata Field (#165490)
Summary:
The implementation adds the ability to:

Set custom metadata strings that will be attached to all subsequent allocations
Clear or change the metadata at any point
View the metadata in memory snapshots via _dump_snapshot()

Test Plan: Added test in test_cuda.py and check manually in snapshot to see that metadata was added.

Differential Revision: D84654933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165490
Approved by: https://github.com/yushangdi
2025-10-16 22:54:27 +00:00
556fc09a9f [DebugMode][1/N] refactor logs into _DebugCalls (#165376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165376
Approved by: https://github.com/SherlockNoMad
2025-10-16 22:43:52 +00:00
ce109b3f79 Add torch.backends.mkldnn.is_acl_available() method (#165678)
That tells whether or not PyTorch was compiled with Arm Compute Library
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165678
Approved by: https://github.com/Skylion007, https://github.com/atalman, https://github.com/albanD
ghstack dependencies: #165583, #165584, #165676
2025-10-16 22:34:21 +00:00
4d833f859b [BE] [CI] Fix aarch64 arch checks (#165676)
Instead of relying on `TEST_CONFIG` environment variable  to contain `aarch64`, which is prone to errors,  use output of  `$(uname -m)` that is equal to `aarch64` on Linux ARM systems
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165676
Approved by: https://github.com/huydhn, https://github.com/atalman
ghstack dependencies: #165583, #165584
2025-10-16 22:19:53 +00:00
d7e275d4b4 [CI][CUDA] Add periodic b200 distributed job (#159323)
1. Run distributed job with B200 runner, periodically.
2. discovered generic distributed test issue that certain unit test hard-coded ranks, calling for require_exact_world_size(world_size) API instead of require_world_size(world_size).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159323
Approved by: https://github.com/eqy

Co-authored-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
2025-10-16 21:54:04 +00:00
d5db3aee0d [CI] Use 1-GPU runners for rocm-mi355.yml (#165658)
Should only need 1-GPU runners for rocm-mi355.yml since it runs `default` test config which only needs 1 GPU

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165658
Approved by: https://github.com/jeffdaily
2025-10-16 21:53:22 +00:00
5641de7b6b Add suppressions for _inductor/codegen (#165659)
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659
Approved by: https://github.com/oulgen
2025-10-16 21:37:37 +00:00
cbc08c8993 Add NEON acceleration for Vectorized<int[8|16|32|64> (#165273)
Summary:
Adding NEON specializations of Vectorized<T> for int8, int16, int32 and int64.

Correcness has been checked using test_ops.py and the comprehensive torch test

operator_benchmark_test.py has been enhanced by adding cases of bitwise operations, boolean ops and integer ops.
The benchmark, which uses the PyTorch API, shows significant enhancements in a wide variety of operations:

Before:

bitwise xor: 779.882us
boolean any: 636.209us
boolean all: 538.621us
integer mul: 304.457us
integer asr: 447.997us

After:

bitwise xor: 680.221us ---> 15% higher throughput
boolean any: 391.468us ---> 63% higher throughput
boolean all: 390.189us ---> 38% higher throughput
integer mul: 193.532us ---> 57% higher throughput
integer asr: 179.929us---> 149% higher throughput

Test Plan:
Correctness:

buck2 test @mode/opt //caffe2/test:test_ops
buck2 test @mode/opt //caffe2/test:torch
buck2 test @mode/opt //caffe2/test/distributed/launcher/fb:fb_run_test

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D84424638

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165273
Approved by: https://github.com/malfet
2025-10-16 21:35:13 +00:00
1a54d3333d [easy] Fix graph_capture in aot_joint_with_descriptors test (#165660)
when `with_export=True`, `aot_export_joint_with_descriptors` should take the graph produced by `_dynamo_graph_capture_for_export`

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_simple
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_annotate_flex_attention
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165660
Approved by: https://github.com/yushangdi
2025-10-16 21:10:11 +00:00
4c1c341fa0 FakeTensorMode shouldn't cache syms when tracing (#164718)
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way if we create a cache entry when we weren't tracing and then we try to use it when we ARE tracing it gets rerun.

2. If there's a SymNode in the output then bypass tracing.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
2025-10-16 20:57:07 +00:00
5f21cc786a Teach ProxyTorchDispatchMode how to decompose sympy.Expr into known inputs (#164717)
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to  use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph.  We then use that size when doing later operations.

Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere).

At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled.

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
2025-10-16 20:57:06 +00:00
e86942f422 minor proxy_tensor reorg (#165266)
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165266
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-16 20:57:06 +00:00
2cd5fd1588 Enable local tensor mode on DTensor view ops test (#165596)
While enabling this test discovered lack of support for sub meshes. Added limited support
for sub meshes by properly computing rank coordinates for a given sub mesh. The implementation
follows similar approach to collectives. We infer all sub meshes for the given dimensions and
compute each rank's coordinates with respect to is sub mesh.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165596
Approved by: https://github.com/ezyang
2025-10-16 20:52:06 +00:00
7d0f872cb3 Use union syntax in torch/_inductor runtime and fx_passes (#165652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165652
Approved by: https://github.com/aorenste
2025-10-16 20:51:59 +00:00
fb06e49ce8 Revert "[inductor] print 0.0 as 0 for triton (#164291)"
This reverts commit 99b32a6750bfd0cfe2bc84a47823e1da34802b7b.

Reverted https://github.com/pytorch/pytorch/pull/164291 on behalf of https://github.com/malfet due to Broke slow job, see aba8c43594/1  ([comment](https://github.com/pytorch/pytorch/pull/164291#issuecomment-3412768915))
2025-10-16 20:44:29 +00:00
27a98e6ae9 Revert "[DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)"
This reverts commit d61a9b88cf3be04a29c5a7d6e9622ae5e8d51de3.

Reverted https://github.com/pytorch/pytorch/pull/165554 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
b10f463b1a Revert "[DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)"
This reverts commit 99097b6d89c927c15180ff4683c38be01f9955f6.

Reverted https://github.com/pytorch/pytorch/pull/165555 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
431c13cf61 Revert "[DeviceMesh] Simplify unflatten method (#165556)"
This reverts commit 86fd4fc23e697e275d37c36e3cbe521f156434fd.

Reverted https://github.com/pytorch/pytorch/pull/165556 on behalf of https://github.com/malfet due to Looks like it broke serialization test, see aba8c43594/1 ([comment](https://github.com/pytorch/pytorch/pull/165554#issuecomment-3412765681))
2025-10-16 20:41:37 +00:00
aead9270f5 12/n : Remove fbandroid_compiler_flags (#165558)
Summary:
Currently `get_c2_fbandroid_xplat_compiler_flags()` is reading the `caffe2.strip_glog` buckconfig which we want to get rid of.
This diff removes the `fbandroid_compiler_flags` arg and merges it with compiler_flags with a nested select and the select version of the method

The goal is to get rid of all the usages of `get_c2_fbandroid_xplat_compiler_flags()` so that we can get rid of the `caffe2.strip_glog` buckconfig

Test Plan: CI

bifferential Revision: D84626885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165558
Approved by: https://github.com/malfet
2025-10-16 20:41:24 +00:00
9bf5b38c14 [Inductor][Triton][FP8] Refactor scaled_mm template to accept scaling mode (#164318)
Summary: Refactor `scaled_mm` Inductor template to support template choice based on scaling mode. This modification sets up the infrastructure for adding new templates based on new scaling modes, such as deepseek-style scaling (a follow-up diff), as new scaling modes (deepseek, block, group) scale before the accumulation (as opposed to per-tensor and per-row scaling, which apply scaling after accumulation). This modification also further enables Inductor to infer a scaling type based on the shape of the scaling tensors, which makes existing infrastructure more extensible to new scaling modes.

Test Plan:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling_rowwise --atol=20 --rtol=2 2>&1 | tee ~/personal/random.log
```

bifferential Revision: D83591083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164318
Approved by: https://github.com/drisspg, https://github.com/slayton58
2025-10-16 20:40:45 +00:00
aba8c43594 Register var for MTIA (#165382)
Summary: Registers variance kernel

Reviewed By: srsuryadev

Differential Revision: D84546250

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165382
Approved by: https://github.com/malfet
2025-10-16 20:35:15 +00:00
37f3ba274a [Fix] Use sys.executable instead of hardcoded python (#165633)
Replace hardcoded "python" string with sys.executable to ensure correct Python interpreter is used. This fixes failures on systems with multiple Python runtimes or where "python" is not in PATH.

Similar to pytorch/pytorch#155918

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165633
Approved by: https://github.com/Skylion007
2025-10-16 20:26:10 +00:00
585b9dbb5e [async_tp] Support ag+mm with gather_dim lastdim of mat_A (#163068)
Adding ag+mm support for the case, when gather_dim is last dim of matmul (reduction dim).

When we decompose matmul by reduction dimension we result in partials that needs additional reduction,
we allocate memory for accumulator.

Decomposition should not produce small (thin) mms that can not efficiently load the GPU. Limiting for minimal size of the shard 1024 (found empirically by testing in torchtitan).

scaled_mm is not supported yet for this case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
2025-10-16 20:14:39 +00:00
d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This will add pyrefly to lint runner as a warning only - and allow us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: : https://github.com/pytorch/pytorch/issues/163283:

test plan:
`lintrunner init`
`lintrunner`
confirm when pyrefly errors are present results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00
7df9aca529 [ROCm][Windows] Enable AOTriton runtime compile on Windows (#165538)
AOTriton uses prebuilt runtime binaries if the user's ROCm version matches the ones used to generate the prebuilt runtime. However, since there's no prebuilt runtime available for Windows, this check needs to be bypassed for Windows. This PR enables it by changing condition to always build AOTriton runtime from source on Windows.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165538
Approved by: https://github.com/xinyazhang, https://github.com/jeffdaily
2025-10-16 19:51:43 +00:00
d4a713cd9c Change forkserver test to only run below 3.13.8 (#165667)
A multiprocessing bug is fixed in 3.13.8, see [https://docs.python.org/3.13/whatsnew/changelog.html](https://l.workplace.com/l.php?u=https%3A%2F%2Fdocs.python.org%2F3.13%2Fwhatsnew%2Fchangelog.html&h=AT0qUhHJq5c2UJvQaq9_MrSo0mVhwn1VOfq1nDQl2C1UOhDI80RMbzVayhG7LSAT1uYHKtkftKnBDwiGMhbw0YRvQLe5vwE01qejpPFautHvU3LXeOE1KChPykqz3qnCRzk7czu_iNzQ05shR4F1N_qYOzR5YxejA52ZZQ), [gh-126631](https://github.com/python/cpython/issues/126631)

So this test will fail when we update to python 3.13.8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165667
Approved by: https://github.com/malfet
2025-10-16 19:34:10 +00:00
5daef30b26 158232 Fix autocast cache incorrectly retaining no_grad state (#165068)
Fixes #158232
The autocast caching heuristic in `aten/src/ATen/autocast_mode.cpp:139` did not account for gradient mode state when deciding whether to cache. FSDP2 is not directly related.

~~This PR adds `GradMode::is_enabled()` check to caching condition. Caching is now disabled in `no_grad()` contexts to prevent storing tensors with incorrect gradient state. Ensures correctness at the cost of using cache.~~
This PR proposes separate caches for gradient-enabled and gradient-disabled modes.
Adds tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165068
Approved by: https://github.com/ngimel, https://github.com/janeyx99
2025-10-16 19:32:01 +00:00
6dedd34c31 [CD] Skip 12.9 build on Windows (#165665)
Per title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165665
Approved by: https://github.com/Camyll, https://github.com/malfet
2025-10-16 19:11:27 +00:00
a303d6dda9 [inductor] don't try to reorder loops for template (#165601)
fix https://github.com/pytorch/pytorch/issues/165579

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165601
Approved by: https://github.com/yushangdi
2025-10-16 19:05:21 +00:00
7669ac9402 [ROCm] Add scaled_mm v2 support. (#165528)
Add mx fp4 support in Blas.cpp.
Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support.
Modify the tests under test_scaled_matmul_cuda accordingly.

PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 test passed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165528
Approved by: https://github.com/jeffdaily
2025-10-16 18:36:41 +00:00
86fd4fc23e [DeviceMesh] Simplify unflatten method (#165556)
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556
Approved by: https://github.com/fduwjj
ghstack dependencies: #165554, #165555
2025-10-16 18:36:16 +00:00
99097b6d89 [DeviceMesh] Introduce private constructor instead of _create_mesh_from_ranks (#165555)
The refactoring of DeviceMesh is heavily constrained by the signature of its constructor, which is a public API which contains some "legacy" concepts which we'd love to get rid of, such as an explicit/materialized `mesh` Tensor.

In other languages the solution to this would be to add a private overload of the constructor. Python doesn't natively allow this, but in this PR I managed to build something that approximates it.

This new private constructor basically only takes `_layout`, `_global_rank_permutation`, and `mesh_dim_names`.

With such a constructor we can effectively simplify a lot of callsites and get rid of the `_create_mesh_from_ranks` helper method. That's a good thing because it was instantiating many DeviceMeshes in a for loop, which always felt unnecessary.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165555
Approved by: https://github.com/fduwjj, https://github.com/fegin
ghstack dependencies: #165554
2025-10-16 18:36:16 +00:00
eqy
a214371008 [FP8] Add other Blackwell compute-capabiilities to expected fail test_honor_sm_carveout (#165159)
CUTLASS SM hint also isn't working for other Blackwells, need green context for carveout

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165159
Approved by: https://github.com/Skylion007
2025-10-16 18:35:06 +00:00
7d87d7052e [inductor][bucketing] Fx collectives bucketing of multiple dtypes (#162470)
Bucketing of multiple dtypes to be processed in one bucketed collective.

First target is to bucket bf16 and f32, but already can be used with other dtypes.

For now multidtype bucketing is only supported with "custom_ops" mode.
Non custom_ops needs additional work on inductor side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
2025-10-16 18:31:43 +00:00
1a34ff4e04 Fixing get_local_rank() variable missing when compiled (#165432)
Fixes #165215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165432
Approved by: https://github.com/bdhirsh
2025-10-16 18:20:34 +00:00
fe5ccb1a74 bf16 support for per tensor backward (#165362)
Adding bf16 for the backward pass of `torch._fake_quantize_learnable_per_tensor_affine()`.

Note that for testing, we modified the seed to avoid increasing tolerance due to cases where difference in Python vs CPP downcasting causes tensor mismatches. (e.g. 27.87704 vs  27.8408 before downcasting, 27.7500 vs 27.8750 after downcasting for Python vs CPP op)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165362
Approved by: https://github.com/andrewor14
2025-10-16 17:47:01 +00:00
85586d7efc Make c7i the default for _linux-build.yml (#164747)
Use linux.c7i.2xlarge as the default runner for the _linux-build.yml workflow. In testing we found that switching from c5 - c7i grants a 15-20% faster build times despite c7i costing 5% more. This should reduce costs of jobs using _linux-build.yml.

Relates to pytorch/test-infra#7175.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164747
Approved by: https://github.com/atalman
2025-10-16 17:37:51 +00:00
e1d71a6b35 Revert "12/n : Remove fbandroid_compiler_flags (#165558)"
This reverts commit d7ffa8b8a29ba6071c51499c1df3d702d0a26f72.

Reverted https://github.com/pytorch/pytorch/pull/165558 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/165558#issuecomment-3411879769))
2025-10-16 17:18:56 +00:00
d61a9b88cf [DeviceMesh] Prefer using _layout over _mesh for all sorts of things (#165554)
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
2025-10-16 17:01:44 +00:00
99b32a6750 [inductor] print 0.0 as 0 for triton (#164291)
Fixes https://github.com/pytorch/pytorch/issues/164157
Fixes https://github.com/pytorch/pytorch/issues/164086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164291
Approved by: https://github.com/bobrenjc93
2025-10-16 16:37:50 +00:00
783da8b8e7 Repro for property related Dynamo graph break (#165609)
Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165609
Approved by: https://github.com/albanD, https://github.com/gchanan, https://github.com/malfet, https://github.com/anijain2305
2025-10-16 16:22:43 +00:00
ed74dc054d add the option to disable functionalization in AOTDispatcher (#164577)
I'm cleaning this PR up as a proper way of disabling functionalization via config in AOTDispatcher. I removed the non-functionalization related changes from the original version:

(1) preventing proxy mode (and functionalization) from incorrectly decomposing CIA ops (Ed has a PR for it here: https://github.com/pytorch/pytorch/pull/164939)

(2) preventing python-dispatcher-based decomps above autograd from running. I'm not doing this for now, will likely do it in a followup

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164577
Approved by: https://github.com/ezyang
ghstack dependencies: #165372
2025-10-16 15:44:11 +00:00
f33c7e1a43 add and fix OpInfo tests for the default partitioner (#165372)
I noticed the default partitioner was breaking in some dynamic shape tests, so prior to turning off functionalization I want to tweak it to pass all of our OpInfo tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165372
Approved by: https://github.com/ezyang
2025-10-16 15:44:11 +00:00
219fb6aafc Refactor CUDAAllocatorConfig using ConfigTokenizer (#165281)
* #165129
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165281
Approved by: https://github.com/albanD
ghstack dependencies: #165129, #165131, #165135, #165136
2025-10-16 15:26:50 +00:00
515b5ff539 Remove unused code in CUDAAllocatorConfig (#165136)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165136
Approved by: https://github.com/Skylion007
ghstack dependencies: #165129, #165131, #165135
2025-10-16 15:26:50 +00:00
608a6d4a26 Reuse AcceleratorAllocatorConfig in CUDAAllocatorConfig (#165135)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165135
Approved by: https://github.com/Skylion007
ghstack dependencies: #165129, #165131
2025-10-16 15:26:40 +00:00
03e5dbb26e Register CUDAAllocatorConfig to AcceleratorAllocatorConfig (#165131)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165131
Approved by: https://github.com/Skylion007
ghstack dependencies: #165129
2025-10-16 15:26:28 +00:00
7ee45f7503 Restore AcceleratorAllocatorConfig to avoid potential regression (#165129)
# Motivation
This PR aims to restore `AcceleratorAllocatorConfig` to avoid the potential regression mentioned in https://github.com/pytorch/pytorch/pull/160666#issue-3323270375
These code change would be reverted in the following PR https://github.com/pytorch/pytorch/pull/165304
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165129
Approved by: https://github.com/albanD
2025-10-16 15:26:17 +00:00
e6d9d68598 [Bugfix][Dynamo] Fix Sparse tensors by graph break in Dynamo (#164873)
Fixes #164823 by making lack of support for sparse tensors very explicit (in fake tensor, inductor, and lowering code)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164873
Approved by: https://github.com/williamwen42, https://github.com/eellison, https://github.com/mlazos
2025-10-16 15:06:20 +00:00
1a5b7eca7b [BE] Fold cond into TORCH_CHECK(false,...) (#165593)
Replace `if (!foo) { TORCH_CHECK(false, "bar");}` with `TORCH_CHECK(foo,"bar");`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165593
Approved by: https://github.com/albanD
ghstack dependencies: #165594
2025-10-16 15:00:30 +00:00
8573574b32 [MPS] sparse mask implementation (#165102)
sparse mask implementation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165102
Approved by: https://github.com/malfet
2025-10-16 14:31:00 +00:00
e6033f6efb [MPS] Improve index_fill_ error handling (#165594)
It shoudl not throw "Cannot convert a float64 Tensor to MPS", but rather a sensible "Converting complex Scalar to non-complex type is not supported".
Add TODO about the complex support, probably good reason to rip out MPSGraph from index_fill as well
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165594
Approved by: https://github.com/dcci, https://github.com/kulinseth
2025-10-16 14:18:39 +00:00
9272437cde Fx collectives bucketing: add bucket all_reduce (#165351)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165351
Approved by: https://github.com/eellison
2025-10-16 13:27:33 +00:00
f06e669f6c refactor: replace runtime_error with TORCH_CHECK for better error handling (#163628)
Fixes some parts of issue #148114

@pytorchbot label "topic: not user facing"

@FFFrog PTAL
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163628
Approved by: https://github.com/albanD
2025-10-16 11:09:48 +00:00
69b05913fb Revert "Add mingw to docker (#165560)"
This reverts commit 5e480b8ecf870e4a466c165701ab0e9d055f2ceb.

Reverted https://github.com/pytorch/pytorch/pull/165560 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165560#issuecomment-3409814274))
2025-10-16 08:42:11 +00:00
d73c283c3a [CUDA] Large tensor maxpool crash fix (#165374)
Fixes #165297

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165374
Approved by: https://github.com/eqy, https://github.com/malfet
2025-10-16 07:59:46 +00:00
eaeaa08e3a [PowerPC] Disable MKLDNN TF32 on PowerPC to fix build failure (#163454)
The commits f4d8bc46c7706f872abcb4ec41f0b32207d5d826 added TF32 support for x86 CPUs,
which causes build failures on PowerPC systems with mkldnn.

This patch disables TF32 paths on PowerPC while keeping x86 TF32 support intact,
allowing PyTorch to build successfully on PowerPC.

I have run the mkldnn test case on PowerPC, and it passed successfully.

`pytest test/test_mkldnn.py
87 passed, 2 skipped in 1709.02s (0:28:29`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163454
Approved by: https://github.com/jgong5, https://github.com/malfet
2025-10-16 06:13:59 +00:00
d0c32971b4 Refine XPU allocator message when OOM (#165509)
# Motivation
Provide more information and align with other backends to enhance the user experience.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165509
Approved by: https://github.com/EikanWang
ghstack dependencies: #165508
2025-10-16 05:47:49 +00:00
d7ffa8b8a2 12/n : Remove fbandroid_compiler_flags (#165558)
Summary:
Currently `get_c2_fbandroid_xplat_compiler_flags()` is reading the `caffe2.strip_glog` buckconfig which we want to get rid of.
This diff removes the `fbandroid_compiler_flags` arg and merges it with compiler_flags with a nested select and the select version of the method

The goal is to get rid of all the usages of `get_c2_fbandroid_xplat_compiler_flags()` so that we can get rid of the `caffe2.strip_glog` buckconfig

Test Plan: CI

Differential Revision: D84626885

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165558
Approved by: https://github.com/malfet
2025-10-16 05:46:02 +00:00
00afa06800 Add cse for make_block_ptr in Triton codegen (#163399)
Summary: per title

Test Plan: added test cases

Differential Revision: D82648215

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163399
Approved by: https://github.com/jansel, https://github.com/njriasan
2025-10-16 05:29:48 +00:00
5d0b22008d Codemod inductor/fx_passes from Optional to union none (#165606)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165606
Approved by: https://github.com/aorenste
ghstack dependencies: #165604, #165605
2025-10-16 04:59:47 +00:00
ab6014a903 Codemod inductor/runtime from Optional to union none (#165605)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165605
Approved by: https://github.com/aorenste
ghstack dependencies: #165604
2025-10-16 04:59:47 +00:00
f6daffc54d Codemod codecache.py from Optional to union none (#165604)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165604
Approved by: https://github.com/aorenste
2025-10-16 04:59:37 +00:00
66b75693ae Reuse kLargeBuffer in XPUCachingAllocator (#165508)
# Motivation
Reuse the shared code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165508
Approved by: https://github.com/EikanWang
2025-10-16 04:12:52 +00:00
21697feff2 [hop] run local_map with interpreter to preserve fx_traceback annotations (#165336)
We have an issue when using fx_traceback.annotate and HOPs that trace joint graphs. HOPs have bodies that have already been traced by Dynamo, and after Animesh's PR, does have the annotations. But when we lower that Dynamo HOP body to aten in either pre-dispatch or post-dispatch, we need to propagate the annotations to the aten nodes.

AOTAutograd does this indirectly by piggybacking off the `PropagateUnbackedSymInts` fx.Interpreter. I'm not sure if all HOPs should be using it to trace their joints or not. This PR adds an interpreter to local_map's implementation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165336
Approved by: https://github.com/yushangdi
2025-10-16 02:53:17 +00:00
12fa4192c5 [ContextParallel] add process-time based Round-Robin load-balance to CP (#163617)
**Summary**
The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for
implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example
and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/
Longest-processing-time-first with extra padding added for collectives.

- Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for
`flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform
Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order).
- Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance
index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`.

**Test**
`pytest test/distributed/tensor/test_attention.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163617
Approved by: https://github.com/fegin
2025-10-16 02:20:27 +00:00
23fb7e9f4b [CI] Add arch prefix in front of op benchmark results (#165584)
To be able to run x86 and aarch64 benchmarks later on
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165584
Approved by: https://github.com/huydhn
ghstack dependencies: #165583
2025-10-16 01:50:52 +00:00
5e480b8ecf Add mingw to docker (#165560)
Add mingw to `pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11` docker image to support AOTI cross-compilation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165560
Approved by: https://github.com/malfet
ghstack dependencies: #165574
2025-10-16 01:31:50 +00:00
19ba506ca3 Support libtorch and posix mingw flavor (#165574)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165574
Approved by: https://github.com/desertfire
2025-10-16 01:31:50 +00:00
003dd13073 [dynamo, guards] Better error messages when generated guard fails on the same frame (#165242)
Not sure what exactly we want to have in the message, but that's easy to adjust. I tried to find a reliable test to reproduce this message (happens only when a guard fails right after it's created), but I ended up mocking a `guard_manager.check` function to return `False` to trigger this behavior. I think that's fine, because any other case that we pick (like datetime.now()), we want to patch one day anyway, so every time we make the next patch, will need to chase for another repro test

@williamwen42

Fixes #164990

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165242
Approved by: https://github.com/williamwen42
2025-10-16 01:05:31 +00:00
c2bd41ac9f Build vLLM nightly wheels for CUDA 13.0 (#163239)
Now that https://github.com/vllm-project/vllm/pull/24599 has been merged
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163239
Approved by: https://github.com/malfet, https://github.com/atalman
2025-10-16 01:03:26 +00:00
ca8bd5dbed Move toString(ScalarType) and ScalarType ostream operator to headeronly (#164405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164405
Approved by: https://github.com/Skylion007, https://github.com/janeyx99
ghstack dependencies: #164350, #164354
2025-10-16 00:55:43 +00:00
26f3803433 Remove workaround to old CUDA bug (#164354)
As in the title.

A check for https://github.com/pytorch/pytorch/issues/164348 to see if the workaround can be removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164354
Approved by: https://github.com/janeyx99, https://github.com/ngimel, https://github.com/malfet, https://github.com/jeffdaily
ghstack dependencies: #164350
2025-10-16 00:55:43 +00:00
48064acf37 Move AT_FORALL_... macros and ScalarTypeToCPPTypeT to headeronly (#164350)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164350
Approved by: https://github.com/janeyx99
2025-10-16 00:55:42 +00:00
e5a9c247bc [Fix XPU CI] [Inductor UT] Fix test cases broken by community. (#165406)
Fixes #163159, Fixes #164098, Fixes #164097, Fixes #164099, Fixes #165025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165406
Approved by: https://github.com/EikanWang, https://github.com/jansel
2025-10-16 00:53:32 +00:00
36371b8ec7 [ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790
Approved by: https://github.com/ngimel, https://github.com/eqy
ghstack dependencies: #165494
2025-10-15 23:54:51 +00:00
7e6721fb0a [BE] Remove confusing opbenchmark-on-demand-build (#165583)
As it doesn't have a test shard, so what's the point or running the build? Was added in https://github.com/pytorch/pytorch/pull/143733 and looks like test shard never existed for it

Moreover, allow one to specify benchmark size as argument, so one
technically can do a workflow dispatch with different opbenchmark sizes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165583
Approved by: https://github.com/huydhn
2025-10-15 23:48:28 +00:00
901bbcba12 Gate division bitwise numerics under a flag (#165566)
https://github.com/pytorch/pytorch/pull/164144 ensures that division for compile is bitwise equivalent with eager. However, in https://github.com/pytorch/pytorch/issues/164301, the kernel performance is regressed.

On B200:
With standard triton `/`:
6511 GB/s

With triton `div_rn`:
4692 GB/s

Further investigation is required for the generated PTX to see why there is such a large slowdown. For now, enable bitwise equivalent results under `TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING` similar to emulate_precision_cast

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165566
Approved by: https://github.com/ngimel, https://github.com/eellison
2025-10-15 23:41:01 +00:00
febb603230 [Inductor][CuTeDSL] Move load_template up two directories (#165347) (#165576)
Summary:

Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.

Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Reviewed By: drisspg

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165576
Approved by: https://github.com/jananisriram
2025-10-15 23:37:55 +00:00
568d2f3ae7 [Dynamo][Logging] Add sources/types to LazyVariableTracker logging (#165402)
Fixes #162860

This task add the variable source attrition to LazyVariableTracker when output trace bytecode

Test plan -- test/dynamo/test_error_messages.py ErrorMessagesTest.test_variable_tracker_source_attribution

The output is as specified in the prior mentioned Github issue.

<img width="961" height="59" alt="Screenshot 2025-10-13 at 10 19 44 PM" src="https://github.com/user-attachments/assets/fb27da3f-d00b-437b-bf2e-52e892572cd7" />

This is specifically for the log setup with ``TORCH_LOGS=trace_bytecode``

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165402
Approved by: https://github.com/Lucaskabela, https://github.com/williamwen42

Co-authored-by: William Wen <williamwen@meta.com>
2025-10-15 23:23:09 +00:00
b54e466fd0 Megacache integration (#163533)
This diff adds megacache integration for DynamoCache.

Because DynamoCache requires lazy serialization, i.e. it can only be serialized once all relevant backends have been compiled and we're ready for a save, we actually do the DynamoCache saving only on a call to `torch.compiler.save_cache_artifacts`.

Differential Revision: [D82735763](https://our.internmc.facebook.com/intern/diff/D82735763/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163533
Approved by: https://github.com/oulgen, https://github.com/zhxchen17
2025-10-15 22:49:15 +00:00
53f9ae0e50 [ROCm] new implementation of upsample_bilinear2d_backward (#164572)
Changed the implementation from an output-based approach to an input-based one to remove `atomicAdd` operations, and it appears to deliver at least a 20× speedup.

The changes are from Yu-Yun <YuYun.Chang@amd.com>.

# Summary: Refactor of the implementation of the `upsample_bilinear2d_backward` opertion on MI300X/MI325X
- The original "scatter-add" approach
  - Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
- The new "gather-sum" approach
  - Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the `compute_output_range` device function).
# Breakdown of the code changes
- Inversion of the parallelization strategy of the kernel function `upsample_bilinear2d_backward_out_frame`
  - Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (`const size_t o_numel = nc * width2 * height2;`).
    - Each thread processed one output pixel.
  - The new loop is parallelized over the number of elements in the input gradient tensor (`const size_t i_numel = nc * height1 * width1;`).
    - Each thread is responsible for calculating the final gradient for a single input pixel.
  - The kernel launch changes accordingly in the function `upsample_bilinear2d_backward_out_cuda_template`.
- Added a device function for calculating the range of output pixels that could have possibly used that the input pixel (`input_pos`) during the forward pass interpolation
  - This is essentially the mathematical inverse of the forward pass.
  - This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
- Gradient calculation approach switching from "scatter-add" to "gather-sum"
  - Scatter-add
    - For each output pixel, the thread calculated 4 gradient contributions and use `fastAtomicAdd` 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
  - Gather-sum
    - A thread responsible for one input pixel calls `compute_output_range` to determine the small rectangular region of output pixels that influence the input's final gradient value.
    - The thread iterates through this region, and for each output pixel in the regionre, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
    - All these contributions are accumulated into a private, per-thread register variable (`accscalar_t grad_sum = 0;`).
      - W/o any gloabl memory access, this accumulation is extremely fast.
    - When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (`idata[index] = static_cast<scalar_t>(grad_sum);`).
# Why performance gets boosted
- Analysis of the root cause of performance drop
  - Ref. (internal only) - https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1140493327/PyTorch__upsample_bilinear2d_backward
- First and foremost, elimination of the contention of atomic operations
  - Many parallel threads called `atomicAdd` frequently attempting to update the exact same memory location in the input gradient tensor at the same time.
    - The GPU's memory controler has to serialize these operations, effectively nullifying the benefit of parallel capability at those contention points.
  - MI300X/MI325X chiplet-based CDNA 3 architeture amplified the issue.
    - When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
  - The implementation change eliminates hardware-level serialization and cross-chiplet coherence traffic caused by many `atomicAdd`.
- Improved memory access pattern and locality
  - Write coalescing
    - The regular sum writes `idata[index] = static_cast<scalar_t>(grad_sum);` can be perfectly coalesced by GPUs.
  - Read locality
    - Even though there are many (potentially repeated) reads from the output tensor (`static_cast<accscalar_t>(odata[output_idx])`), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
- Trade-off: computation for memory synchronization
  - The recalculation of interpolation weights fits well on high-computational-throughput modern GPUs like MI300X/MI325X.
  - Removal of atomic operations avoids expensive memory synchronization.

---

Optimizations of `grid_sampler_2d_backward` will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164572
Approved by: https://github.com/jeffdaily
2025-10-15 22:35:43 +00:00
b42fe389b9 ROCm unit tests enablement (#165366)
Enables:
test_cuda.py::TestCuda::test_streaming_backwards_multiple_streams
test_cuda.py::TestCuda::test_graph_make_graphed_callables_with_amp_cache_disabled_allow_unused_input
test_cuda.py::TestCuda::test_graph_make_graphed_callables_without_amp_allow_unused_input
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_10000_10000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_1_10000_1000_10000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_1000_1000_1000_cuda_float32
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_bfloat16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_float16
test_matmul_cuda.py::TestMatmulCudaCUDA::test_cublas_baddbmm_large_input_2_100_100_100_cuda_float32

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165366
Approved by: https://github.com/jeffdaily
2025-10-15 22:35:03 +00:00
66ea76ec44 [ROCm][tunableop] Improvements to tunableop Numerical Check (#163079)
Modified the flag PYTORCH_TUNABLEOP_NUMERICAL_CHECK, so that it accepts the numerical tolerances in the format atol_rtol as compared to the previous 0 and 1. Retains previous functionality with default values as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163079
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
2025-10-15 22:26:47 +00:00
e787d532b6 tmp fix for compile internal logger issue (#165568)
Summary: Catch runtime exception when garse and scrub uninteresting configs from inductor config

Test Plan: tested locally

Differential Revision: D84727788

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165568
Approved by: https://github.com/luccafong, https://github.com/oulgen
2025-10-15 22:03:16 +00:00
b3f6d49b69 Overlap scheduler improvements (#165318)
Bucketing a number of smallish improvements:

- Account for bucketing in overlap calculation: if an in-flight collective exists with the same bucket key, reduce new collectives estimated time by its latency time
-  Update compute domination so we are ordering based on compute idx, as opposed to compute depth, so we never reorder compute. this makes it a bit easier to reason about memory, and pre-fetching, although we can exploring reordering in the future.
- When we wait on a collective, force all collectives on the same process group as it that were enqueued prior to the collective to wait as well.

Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.

TODO:
- for each compute node, we know its original memory in the graph. we could limit pre-fetching that goes across peak memory
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we need to proactively schedule them. not an issue yet on examples.
- config some hard coded constants, clean up enablement (can do in subsequent pr)

On small llama 2d backward :
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem, 15.9GB

on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 gb, reshceduled mem 5.8GB

WIP: adding tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
2025-10-15 21:58:47 +00:00
bc1f2108d7 [PP] Update backward_counter and fsdp util to schedule class (#165513)
Fixed one issue with FSDP last reshard not being called.

Rest is mostly refactoring, changing some variables to be class variables so they can be used in https://github.com/pytorch/torchtitan/pull/1721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165513
Approved by: https://github.com/fegin
2025-10-15 21:58:16 +00:00
f071f17911 [Graph Partition] fix partition x memory plan issue (#165514)
For `test_graph_partition_with_memory_plan_reuse`, before this PR, when using graph partition, it would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )

...

  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

When not using graph partition, it would work and give the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

Note that the issue is buf0 is not reused for buf2 when using graph partition.

Why? Because the codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops tailing `MemoryPlanningLine` unless it is in graph output by checking `V.graph.get_output_names()`. However, for graph partition, we should check the output of the current partition instead of the graph before partition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
2025-10-15 21:52:16 +00:00
fa1539594b consolidate fw and inference compile paths (#165457)
By design, fw compile and inference compile stages should share a bunch of code; just consolidating the duplication here.

Differential Revision: D84628978

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165457
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2025-10-15 21:33:50 +00:00
dfc8a1c5dd Fix _StridedShard incorrect split (#165533)
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)

Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.

Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
2025-10-15 20:52:41 +00:00
7f9b745494 [ROCm][tunableop] Modified Online Tuning Mode to add Instant Logging (#163965)
- Added instant logging in online tuning mode, so that each tuned GEMM is instantly written
- Allows us to have saved tuning configs, in cases of crashes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163965
Approved by: https://github.com/naromero77amd, https://github.com/jeffdaily
2025-10-15 20:02:31 +00:00
83f9baf413 [Bugfix][Precompile][vLLM] Support for pickling einops for aot_autograd serialization in vLLM (#165359)
Fixes issue with compiling `Qwen2_5_vl` in https://github.com/vllm-project/vllm/pull/23207 (issue happens with `aot_autograd_cache`)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165359
Approved by: https://github.com/jamesjwu
2025-10-15 20:00:24 +00:00
ffc7552e01 See if we can handle uploading all test data (#165484)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165484
Approved by: https://github.com/izaitsevfb
2025-10-15 19:57:41 +00:00
78f5a1ec60 varlen api (#164502)
**Summary**

Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms       | 0.43171775817871094 ms  |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453 which we can see matches the speedup we get from Varlen (approx 50%). TFLOPs remains around the same, with SDPA slightly larger due to potential higher overhead and total flops scaling with sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
2025-10-15 19:45:55 +00:00
2b71b62045 Add Memory Estimation Tracker (#165059)
Add Memory Tracker utility, which will track live memory given alternate ordering of nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165059
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945
2025-10-15 19:44:29 +00:00
8c4b528403 Revert "[Inductor][CuTeDSL] Move load_template up two directories (#165347)"
This reverts commit 815d6415996d5b32b569fd2a8206f1e57c75bfe3.

Reverted https://github.com/pytorch/pytorch/pull/165347 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165347#issuecomment-3407958496))
2025-10-15 19:30:46 +00:00
066f818eea Refactor and unify v1/v2 _scaled_mm codes (#165436)
Summary:

* Refactor out some core routines (scaled_gemm, auto-tuned scaled_gemm)
* Unify v1/v2 dispatch calls where possible
* Simplify call pattern w.r.t. CUDA/ROCM for easier readability.

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165436
Approved by: https://github.com/drisspg
2025-10-15 19:07:05 +00:00
14af1dc3da [DeviceMesh] Fix layout calculation when flattening non-contiguous dims (#165542)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165542
Approved by: https://github.com/ezyang, https://github.com/fduwjj
2025-10-15 18:55:45 +00:00
2395d7d7da Relax equality check (#165460)
When an object is inherited from multiple types, the previous check would fail. So we should relax it to respect eager semantic

Differential Revision: [D84635322](https://our.internmc.facebook.com/intern/diff/D84635322)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165460
Approved by: https://github.com/avikchaudhuri
2025-10-15 18:32:01 +00:00
0aa7ebaf03 Fix periodic debug tests failing due to FakeProcessGroup things (#165479)
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert

This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665

Disclaimer that I used a lot of AI since I don't how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now)

The first one type is

Truncated:
```
    default_pg, _ = _new_process_group_helper(
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
    backend_class = creator_fn(dist_backend_opts, backend_options)
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
    return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```
and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead.  However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR

The other one type is
```
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
    self.assertEqual(out, "")
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
    raise error_metas.pop()[0].to_error(  # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
-   if is_available() and not torch._C._c10d_init():

To execute this test, run the following from the base repo dir:
    python test/test_testing.py TestImports.test_no_warning_on_import
```
which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
2025-10-15 18:16:08 +00:00
7a97832585 [ROCm] Add more timm models, forward fix #165381 (#165569)
PR #165381 added timm models to cuda and cpu expected accuracy files. ROCm expected accuracy files were not updated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165569
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-10-15 18:11:21 +00:00
84d141e910 Revert "[inductor] Expand use of generic benchmark function (#164938)"
This reverts commit 5c583e2573f29243742e00b9fa36b266c5c78bb3.

Reverted https://github.com/pytorch/pytorch/pull/164938 on behalf of https://github.com/clee2000 due to I think this broke test/inductor/test_cuda_repro.py::CudaReproTests::test_epilogue_fusion_with_view? [GH job link](https://github.com/pytorch/pytorch/actions/runs/18529735968/job/52813191763) [HUD commit link](f58f301313) on both rocm and the slow grad check for linux. It did run successfully on cuda workflow on trunk, I wonder if this a gpu capability thing? no clue though ([comment](https://github.com/pytorch/pytorch/pull/164938#issuecomment-3407600224))
2025-10-15 17:48:38 +00:00
7c6c5d04fe Add scaled_grouped_mm_v2 and python API (#165154)
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and
  extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
2025-10-15 17:47:23 +00:00
b509fb9b5d Revert "add and fix OpInfo tests for the default partitioner (#165372)"
This reverts commit bcfea48ab7fd489218289693b98c1a6a6582d079.

Reverted https://github.com/pytorch/pytorch/pull/165372 on behalf of https://github.com/malfet due to Looks like it broke slow jobs, see 331b7cc054/1 ([comment](https://github.com/pytorch/pytorch/pull/165372#issuecomment-3407567748))
2025-10-15 17:38:52 +00:00
331b7cc054 Fix double dispatch to Python for detach (#163671)
This fixes #71725.

Differential Revision: [D83857880](https://our.internmc.facebook.com/intern/diff/D83857880)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163671
Approved by: https://github.com/ezyang, https://github.com/albanD
2025-10-15 17:24:50 +00:00
815d641599 [Inductor][CuTeDSL] Move load_template up two directories (#165347)
Summary: Moves the function used to load CuTeDSL Jinja templates up one level out of the flex attention folder. This way it can be used for more generate Inductor templates in the future.

Test Plan: `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:flex_flash -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8`

Differential Revision: D84527470

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165347
Approved by: https://github.com/drisspg
2025-10-15 16:34:58 +00:00
ffe3cb226a In pipeline parallelism: Use same dtype for receive and send tensor when initializing p2p communication. (#165539)
When initializing the p2p communication for pipeline parallelism, currently different default dtypes are used for the send and receive tensor here:
5c583e2573/torch/distributed/pipelining/stage.py (L935-L936)

This caused hard to trace issues when training on multiple nodes. Multiple stages on one node seem to work for some reason which probably caused the unit tests not to catch this.

Fixes #165143

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165539
Approved by: https://github.com/H-Huang
2025-10-15 15:05:55 +00:00
7ae123d72c [DeviceMesh] Make _flatten_mapping an object attribute instead of a class attribute (#165521)
The `_flatten_mapping` field was defined as a class attribute with a mutable default value {}:
```
_flatten_mapping: dict[str, "DeviceMesh"] = {}
```
This caused all DeviceMesh instances to share the same dictionary object. When multiple test instances tried to create flattened meshes with the same name (like "dp"), they would conflict because they were all using the same shared dictionary, resulting in the error: "Flatten mesh with mesh_dim_name dp has been created before, Please specify another valid mesh_dim_name."

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165521
Approved by: https://github.com/fegin, https://github.com/lw
2025-10-15 14:47:09 +00:00
7719cb75bf [ATen][CMake] Fix duplicated CUTLASS path (#165424)
Fixes #165110

The `PUBLIC` scope causes CUTLASS of the FBGEMM being included in for all PyTorch targets, including special matmuls (RowwiseScaledMM, ScaledGroupMM and GroupMM). Due to version mismatch between FBGEMM/CUTLASS and PyTorch/CUTLASS it is unacceptable to use FBGEMM/CUTLASS in PyTorch targets. This PR limits the scope of FBGEMM/CUTLASS to `fbgemm_genai` target only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165424
Approved by: https://github.com/cthi, https://github.com/eqy, https://github.com/danielvegamyhre
2025-10-15 14:14:17 +00:00
712f54d453 [ATen] Remove explicit casting of complex nansum during accumulation (#165494)
https://github.com/pytorch/pytorch/pull/164790 modifies aten to perform a different reduction order intra warp. However, this change exposed a large difference in a sum for complex32. Namely the case:

```
import torch

a = torch.tensor([[ 4.82031250+7.34765625j,
           -3.37109375-1.9501953125j],

         [ 3.7832031250-2.43359375j,
           -6.07812500+5.32812500j]], dtype=torch.complex32, device='cuda:0')

sum_out = torch.sum(a)
nansum_out = torch.nansum(a)
torch.testing.assert_close(
    sum_out,
    nansum_out,
    rtol=0,
    atol=0,
)
```

Here, the result of `sum` and `nansum` differed significantly by 1e-2. Further investigation showed that the explicit casting of b back to `arg_t` from `scalar_t` was the root cause. `arg_t` is the dtype of the accumulator, ComplexFloat, and `scalar_t` of the input dtype, ComplexHalf. When we cast in the reduction to the accumulator order, that means the input is still of ComplexHalf, which loses precision as it can store intermediate values.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494
Approved by: https://github.com/ngimel
2025-10-15 13:49:25 +00:00
f58f301313 Fixes bug with tolist calls to GradTrackingTensors (#165184)
Fixes #161943

## The Fix
I implemented a recursive unwrapping helper function in the `tensor_to_list.cpp` file that looks for wrapped tensors and unwraps them. The recursive implementation was needed for multi-level gradTrackingTensors.

Let me know if there is any more suggestions on fixing this issue!

@guilhermeleobas @KimbingNg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
2025-10-15 12:54:28 +00:00
5c583e2573 [inductor] Expand use of generic benchmark function (#164938)
Use the more generic `Benchmarker.benchmark` function to allow benchmarking other devices that support the required functionality, for example prologue and epilogue fusion can be benchmarked for triton CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164938
Approved by: https://github.com/nmacchioni, https://github.com/eellison
2025-10-15 09:18:24 +00:00
0c14f55de6 [ez] fix typo (#165282)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165282
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-15 06:19:24 +00:00
8e510e1095 [MPS] fix empty dot op crash (#165237)
reproducer
```
import torch

# does not crash
a = torch.rand((0), device="cpu")
b = torch.rand((0), device="cpu")
a.dot(b)

# crashes due to internal assert
a = torch.rand((0), device="mps")
b = torch.rand((0), device="mps")
a.dot(b)

```

Discovered when implementing an op for SparseMPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165237
Approved by: https://github.com/malfet
2025-10-15 04:49:29 +00:00
59d30d1b75 [vision hash update] update the pinned vision hash (#165496)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165496
Approved by: https://github.com/pytorchbot
2025-10-15 04:35:50 +00:00
3915898c22 [audio hash update] update the pinned audio hash (#165495)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165495
Approved by: https://github.com/pytorchbot
2025-10-15 04:32:49 +00:00
3044e1a460 Revert "varlen api (#164502)"
This reverts commit 3681312ce03e425e280a110df2153db107616a15.

Reverted https://github.com/pytorch/pytorch/pull/164502 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the doctests failure is legit ([comment](https://github.com/pytorch/pytorch/pull/164502#issuecomment-3404419420))
2025-10-15 03:56:42 +00:00
b11593c31b [8/N] Apply ruff UP035 rule (#165214)
This is follow-up of #164653 to continue applying `UP035` fixes. The purpose is to finally enable this rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214
Approved by: https://github.com/ezyang
2025-10-15 03:18:57 +00:00
36871622f1 [2/N] Mark unused parameters in C++ code (#165121)
This is follow-up of #164912 to mark unused C++ parameters to improve code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165121
Approved by: https://github.com/Skylion007
2025-10-15 03:04:39 +00:00
809 changed files with 14768 additions and 7728 deletions

View File

@ -113,6 +113,7 @@ case "$tag" in
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INSTALL_MINGW=yes
;;
pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
CUDA_VERSION=13.0.0
@ -361,6 +362,7 @@ docker build \
--build-arg "OPENBLAS=${OPENBLAS:-}" \
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
--build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
-f $(dirname ${DOCKERFILE})/Dockerfile \
-t "$tmp_tag" \
"$@" \

View File

@ -83,10 +83,6 @@ function build_cpython {
py_suffix=${py_ver::-1}
py_folder=$py_suffix
fi
# Update to rc2 due to https://github.com/python/cpython/commit/c72699086fe4
if [ "$py_suffix" == "3.14.0" ]; then
py_suffix="3.14.0rc2"
fi
wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz
do_cpython_build $py_ver Python-$py_suffix

View File

@ -0,0 +1,10 @@
#!/bin/bash
set -ex
# Install MinGW-w64 for Windows cross-compilation
apt-get update
apt-get install -y g++-mingw-w64-x86-64-posix
echo "MinGW-w64 installed successfully"
x86_64-w64-mingw32-g++ --version

View File

@ -20,7 +20,7 @@ pip_install \
pip_install coloredlogs packaging
pip_install onnxruntime==1.23.0
pip_install onnxscript==0.5.3
pip_install onnxscript==0.5.4
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

View File

@ -39,9 +39,13 @@ case ${DOCKER_TAG_PREFIX} in
DOCKER_GPU_BUILD_ARG=""
;;
rocm*)
# we want the patch version of 7.0 instead
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete

View File

@ -75,9 +75,13 @@ case ${image} in
DOCKERFILE_SUFFIX="_cuda_aarch64"
;;
manylinux2_28-builder:rocm*)
# we want the patch version of 7.0 instead
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
# we want the patch version of 6.4 instead
if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4"
fi
TARGET=rocm_final
MANY_LINUX_VERSION="2_28"

View File

@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
ARG INSTALL_MINGW
COPY ./common/install_mingw.sh install_mingw.sh
RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
RUN rm install_mingw.sh
ARG TRITON
ARG TRITON_CPU

View File

@ -57,8 +57,8 @@ def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules
logger.info("Successfully cloned %s", target)
return r, commit
except GitCommandError as e:
logger.error("Git operation failed: %s", e)
except GitCommandError:
logger.exception("Git operation failed")
raise

View File

@ -485,6 +485,22 @@ test_inductor_aoti() {
/usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
}
test_inductor_aoti_cross_compile_for_windows() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
# Set WINDOWS_CUDA_HOME environment variable
WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
export WINDOWS_CUDA_HOME
echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
echo "Contents:"
ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
}
test_inductor_cpp_wrapper_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){
export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
if [[ "$(uname -m)" != "aarch64" ]]; then
# Use Intel OpenMP for x86
IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
@ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){
cores=$((cpus / thread_per_core))
# Set number of cores to 16 on aarch64 for performance runs
if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
cores=16
fi
export OMP_NUM_THREADS=$cores
@ -1615,6 +1631,7 @@ test_operator_benchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
TEST_DIR=$(pwd)
ARCH=$(uname -m)
test_inductor_set_cpu_affinity
@ -1629,7 +1646,7 @@ test_operator_benchmark() {
pip_install pandas
python check_perf_csv.py \
--actual "${TEST_REPORTS_DIR}/operator_benchmark_eager_float32_cpu.csv" \
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
--expected "${ARCH}_expected_ci_operator_benchmark_eager_float32_cpu.csv"
}
test_operator_microbenchmark() {
@ -1666,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
fi
python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
test_linux_aarch64
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
test_forward_backward_compatibility
@ -1717,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
test_inductor_micro_benchmark
elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
test_inductor_aoti_cross_compile_for_windows
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
install_torchvision
id=$((SHARD_NUMBER-1))

View File

@ -7,16 +7,12 @@ max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E704,E721,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
E203,E305,E402,E501,E704,E741,F405,F841,F999,W503,W504,C408,E302,W291,E303,F824,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910
# these ignores are from flake8-comprehensions; please fix!
C407,
# these ignores are from flake8-logging-format; please fix!
G100,G101,G200
# these ignores are from flake8-simplify. please fix or ignore with commented reason
SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12,
# SIM104 is already covered by pyupgrade ruff

View File

@ -65,7 +65,7 @@ runs:
cd .ci/lumen_cli
python3 -m pip install -e .
)
MAX_JOBS="$(nproc --ignore=6)"
MAX_JOBS="$(nproc --ignore=10)"
export MAX_JOBS
# Split the comma-separated list and build each target

View File

@ -1 +1 @@
8ad2aa5d354d1bf432339113860185d5a5d1abbd
69bbe7363897764f9e758d851cd0340147d27f94

View File

@ -1 +1 @@
f5c6c2ec6490455e86f67b2a25c10390d60a27f7
faffd5cf673615583da6517275e361cb3dbc77e6

29
.github/labeler.yml vendored
View File

@ -133,3 +133,32 @@
"ciflow/vllm":
- .github/ci_commit_pins/vllm.txt
"ciflow/b200":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/h100":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/rocm":
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

View File

@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel

View File

@ -79,21 +79,21 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
"nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'"
),
"12.9": (
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'"
"nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | "
"nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | "
"nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | "
"nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | "
"nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | "
"nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | "
"nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | "
"nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | "
"nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | "
"nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | "
"nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | "
"nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | "
"nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'"
),
"13.0": (
"nvidia-cuda-nvrtc==13.0.48; platform_system == 'Linux' | "
@ -241,7 +241,11 @@ def generate_libtorch_matrix(
arches += CUDA_ARCHES
arches += ROCM_ARCHES
elif os == "windows":
arches += CUDA_ARCHES
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
# in 2.10
windows_cuda_arches = CUDA_ARCHES.copy()
windows_cuda_arches.remove("12.9")
arches += windows_cuda_arches
if libtorch_variants is None:
libtorch_variants = [
"shared-with-deps",
@ -305,7 +309,11 @@ def generate_wheels_matrix(
if os == "linux":
arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES
elif os == "windows":
arches += CUDA_ARCHES + XPU_ARCHES
# TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up
# in 2.10
windows_cuda_arches = CUDA_ARCHES.copy()
windows_cuda_arches.remove("12.9")
arches += windows_cuda_arches + XPU_ARCHES
elif os == "linux-aarch64":
# Separate new if as the CPU type is different and
# uses different build/test scripts

View File

@ -1092,7 +1092,7 @@ class GitHubPR:
editor = node["editor"]
return GitHubComment(
body_text=node["bodyText"],
created_at=node["createdAt"] if "createdAt" in node else "",
created_at=node.get("createdAt", ""),
author_login=node["author"]["login"],
author_url=node["author"].get("url", None),
author_association=node["authorAssociation"],

View File

@ -26,9 +26,8 @@ name: !{{ build_environment }}
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "!{{ (py_ver.strip('t') + '.4') if '3.14' not in py_ver else '3.14.0-rc.2' }}"
python-version: "!{{ py_ver.strip('t') + ('.4' if '3.14' not in py_ver else '.0') }}"
freethreaded: !{{ "true" if py_ver.endswith('t') else "false" }}
{%- endmacro %}

View File

@ -37,7 +37,7 @@ on:
runner:
required: false
type: string
default: "linux.2xlarge"
default: "linux.c7i.2xlarge"
description: |
Label of the runner this job should run on.
test-matrix:

View File

@ -224,6 +224,46 @@ jobs:
continue-on-error: true
uses: ./.github/actions/download-td-artifacts
- name: Download Windows torch wheel for cross-compilation
if: matrix.win_torch_wheel_artifact != ''
uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
with:
name: ${{ matrix.win_torch_wheel_artifact }}
path: win-torch-wheel
- name: Extract Windows wheel and setup CUDA libraries
if: matrix.win_torch_wheel_artifact != ''
shell: bash
run: |
set -x
# Find the wheel file
WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
if [ -z "$WHEEL_FILE" ]; then
echo "Error: No wheel file found in win-torch-wheel directory"
exit 1
fi
echo "Found wheel file: $WHEEL_FILE"
# Unzip the wheel file
unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
echo "Extracted wheel contents"
# Setup CUDA libraries (cuda.lib and cudart.lib) directory
mkdir -p win-torch-wheel-extracted/lib/x64
if [ -f "win-torch-wheel/cuda.lib" ]; then
mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
fi
if [ -f "win-torch-wheel/cudart.lib" ]; then
mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
fi
# Verify CUDA libraries are present
echo "CUDA libraries:"
ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
- name: Parse ref
id: parse-ref
run: .github/scripts/parse_ref.py

View File

@ -168,6 +168,31 @@ jobs:
run: |
.ci/pytorch/win-build.sh
# Collect Windows torch libs and CUDA libs for cross-compilation
- name: Collect Windows CUDA libs for cross-compilation
if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
shell: bash
run: |
set -ex
# Create directory structure if does not exist
mkdir -p /c/${{ github.run_id }}/build-results
# Copy CUDA libs
CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
fi
if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
fi
# List collected files
echo "Collected CUDA libs:"
ls -lah /c/${{ github.run_id }}/build-results/*.lib
# Upload to github so that people can click and download artifacts
- name: Upload artifacts to s3
if: steps.build.outcome != 'skipped'

62
.github/workflows/b200-distributed.yml vendored Normal file
View File

@ -0,0 +1,62 @@
name: CI for distributed tests on B200
on:
pull_request:
paths:
- .github/workflows/b200-distributed.yml
workflow_dispatch:
push:
tags:
- ciflow/b200-distributed/*
schedule:
- cron: 46 8 * * * # about 1:46am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
if: github.repository_owner == 'pytorch'
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
with:
timeout-minutes: 1200
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@ -27,9 +27,8 @@ jobs:
fail-fast: false
matrix:
python-version: [ '3.12' ]
# TODO (huydhn): Add cu130 after https://github.com/vllm-project/vllm/issues/24464 is resolved
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129' ]
device: [ 'cu128', 'cu129', 'cu130' ]
include:
- platform: manylinux_2_28_x86_64
device: cu128
@ -39,6 +38,10 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_x86_64
device: cu130
manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0'
runner: linux.12xlarge.memory
- platform: manylinux_2_28_aarch64
device: cu128
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8'
@ -47,6 +50,11 @@ jobs:
device: cu129
manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9'
runner: linux.arm64.r7g.12xlarge.memory
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is update to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}"
runs-on: ${{ matrix.runner }}
timeout-minutes: 480
@ -169,7 +177,12 @@ jobs:
fail-fast: false
matrix:
platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ]
device: [ 'cu128', 'cu129' ]
device: [ 'cu128', 'cu129', 'cu130' ]
exclude:
# TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and
# xformers is update to support 13.0
- platform: manylinux_2_28_aarch64
device: cu130
env:
PLATFORM: ${{ matrix.platform }}
BUILD_DEVICE: ${{ matrix.device }}

View File

@ -224,7 +224,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_10-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -473,7 +473,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_11-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -722,7 +722,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_12-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -971,7 +971,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_13-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1220,7 +1220,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_13t-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1469,7 +1469,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_14-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
@ -1718,7 +1718,7 @@ jobs:
ALPINE_IMAGE: "arm64v8/alpine"
build_name: manywheel-py3_14t-cuda-aarch64-12_9
build_environment: linux-aarch64-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
timeout-minutes: 420
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -259,7 +259,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_10-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_10-cuda12_9-test: # Testing
@ -925,7 +925,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_11-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_11-cuda12_9-test: # Testing
@ -1591,7 +1591,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-cuda12_9-test: # Testing
@ -2257,7 +2257,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13-cuda12_9-test: # Testing
@ -2923,7 +2923,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13t-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13t-cuda12_9-test: # Testing
@ -3589,7 +3589,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_14-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_14-cuda12_9-test: # Testing
@ -4255,7 +4255,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_14t-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.3.20; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_14t-cuda12_9-test: # Testing

View File

@ -63,7 +63,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.10.4"
freethreaded: false

View File

@ -59,7 +59,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.10.4"
freethreaded: false
@ -169,7 +168,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.11.4"
freethreaded: false
@ -279,7 +277,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.12.4"
freethreaded: false
@ -389,7 +386,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.13.4"
freethreaded: false
@ -499,7 +495,6 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.13.4"
freethreaded: true
@ -609,9 +604,8 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.14.0-rc.2"
python-version: "3.14.0"
freethreaded: false
- name: Checkout PyTorch
uses: actions/checkout@v4
@ -719,9 +713,8 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v6
with:
# TODO: Removeme once 3.14 is out
# .4 version is min minor for 3.10, and also no-gil version of 3.13 needs at least 3.13.3
python-version: "3.14.0-rc.2"
python-version: "3.14.0"
freethreaded: true
- name: Checkout PyTorch
uses: actions/checkout@v4

View File

@ -788,256 +788,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_9-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cuda12_9-shared-with-deps-debug
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-debug-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cuda12_9-shared-with-deps-debug-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cuda12_9-shared-with-deps-debug
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-debug-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_9-shared-with-deps-debug-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
build_name: libtorch-cuda12_9-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda13_0-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type

View File

@ -788,256 +788,6 @@ jobs:
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_9-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v4.4.0
if: always()
with:
name: libtorch-cuda12_9-shared-with-deps-release
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-cuda12_9-shared-with-deps-release-build
- get-label-type
runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge"
timeout-minutes: 360
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
- name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon
shell: bash
run: |
git config --global core.longpaths true
git config --global core.symlinks true
# https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock
# the directory on Windows and prevent GHA from checking out as reported
# in https://github.com/actions/checkout/issues/1018
git config --global core.fsmonitor false
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-cuda12_9-shared-with-deps-release
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_9-shared-with-deps-release-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: "12.9"
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.10"
build_name: libtorch-cuda12_9-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda13_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
needs: get-label-type

File diff suppressed because it is too large Load Diff

View File

@ -88,27 +88,27 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
]}
secrets: inherit

View File

@ -118,9 +118,9 @@ jobs:
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running all other linters"
if [ "$CHANGED_FILES" = '*' ]; then
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
else
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT ${CHANGED_FILES}" .github/scripts/lintrunner.sh
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
fi
quick-checks:

View File

@ -7,9 +7,11 @@ on:
workflow_dispatch:
inputs:
test_mode:
required: false
type: string
default: 'short'
type: choice
options:
- 'short'
- 'long'
- 'all'
description: tag filter for operator benchmarks, options from long, short, all
schedule:
# Run at 07:00 UTC every Sunday
@ -28,38 +30,49 @@ permissions:
contents: read
jobs:
opbenchmark-build:
x86-opbenchmark-build:
if: github.repository_owner == 'pytorch'
name: opbenchmark-build
name: x86-opbenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
{ config: "cpu_operator_benchmark_${{ inputs.test_mode || 'short' }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit
opbenchmark-on-demand-build:
if: ${{ github.event_name == 'workflow_dispatch' && github.repository_owner == 'pytorch' }}
name: opbenchmark-on-demand-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image-name: ci-image:pytorch-linux-jammy-py3-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_${{ inputs.test_mode }}", shard: 1, num_shards: 1, runner: "linux.12xlarge" },
]}
secrets: inherit
opbenchmark-test:
name: opbenchmark-test
x86-opbenchmark-test:
name: x86-opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opbenchmark-build
needs: x86-opbenchmark-build
with:
build-environment: linux-jammy-py3.10-gcc11-build
docker-image: ${{ needs.opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opbenchmark-build.outputs.test-matrix }}
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
secrets: inherit
aarch64-opbenchmark-build:
if: github.repository_owner == 'pytorch'
name: aarch64-opbenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
]}
secrets: inherit
aarch64-opbenchmark-test:
name: aarch64-opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: aarch64-opbenchmark-build
with:
build-environment: linux-jammy-aarch64-py3.10
docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
secrets: inherit

View File

@ -45,12 +45,12 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
]}
secrets: inherit

View File

@ -200,6 +200,23 @@ jobs:
cuda-arch-list: '8.0'
secrets: inherit
# Test cross-compiled models with Windows libs extracted from wheel
cross-compile-linux-test:
name: cross-compile-linux-test
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- get-label-type
- win-vs2022-cuda12_8-py3-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: |
{ include: [
{ config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
]}
secrets: inherit
verify-cachebench-cpu-build:
name: verify-cachebench-cpu-build
uses: ./.github/workflows/_linux-build.yml

1
.gitignore vendored
View File

@ -374,6 +374,7 @@ third_party/ruy/
third_party/glog/
# Virtualenv
.venv/
venv/
# Log files

View File

@ -209,6 +209,46 @@ command = [
'@{{PATHSFILE}}'
]
[[linter]]
code = 'PYREFLY'
include_patterns = [
'torch/**/*.py',
'torch/**/*.pyi',
'torchgen/**/*.py',
'torchgen/**/*.pyi',
'functorch/**/*.py',
'functorch/**/*.pyi',
]
exclude_patterns = []
command = [
'python3',
'tools/linter/adapters/pyrefly_linter.py',
'--config=pyrefly.toml',
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',
'types-requests==2.27.25',
'types-pyyaml==6.0.2',
'types-tabulate==0.8.8',
'types-protobuf==5.29.1.20250403',
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'optree==0.17.0',
'types-openpyxl==3.1.5.20250919',
'types-python-dateutil==2.9.0.20251008'
]
[[linter]]
code = 'CLANGTIDY'
include_patterns = [

View File

@ -201,3 +201,17 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99
# FlexAttention
/torch/nn/attention/flex_attention.py @drisspg
/torch/_higher_order_ops/flex_attention.py @drisspg
/torch/_inductor/kernel/flex/ @drisspg
/torch/_inductor/codegen/cpp_flex_attention_template.py @drisspg
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -256,6 +256,7 @@ endif()
IF(USE_FBGEMM_GENAI)
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
set(FBGEMM_GENAI_SRCS ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
@ -288,62 +289,69 @@ IF(USE_FBGEMM_GENAI)
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
set(fbgemm_genai_mx8mx8bf16_grouped
set(fbgemm_genai_cuh
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/mx8mx8bf16_grouped/"
"${FBGEMM_GENAI_SRCS}/"
)
target_include_directories(fbgemm_genai PUBLIC
target_include_directories(fbgemm_genai PRIVATE
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${fbgemm_genai_mx8mx8bf16_grouped}
${fbgemm_genai_cuh}
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
else()
if(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
elseif(USE_ROCM)
# Only include the kernels we want to build to avoid increasing binary size.
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
"${FBGEMM_GENAI_SRCS}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
-mllvm
-amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PUBLIC
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
# Only compile for gfx942 for now.
# This is rather hacky, I could not figure out a clean solution :(
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
if("gfx942" IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=gfx942;)
endif()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
hip_add_library(
fbgemm_genai STATIC
${fbgemm_genai_native_rocm_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
target_include_directories(fbgemm_genai PRIVATE
# FBGEMM version of Composable Kernel is used due to some customizations
${FBGEMM_THIRD_PARTY}/composable_kernel/include
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
${FBGEMM_THIRD_PARTY}/cutlass/include
${FBGEMM_THIRD_PARTY}/cutlass/tools/util/include
${FBGEMM_GENAI_SRCS}/common/include/ # includes fbgemm_gpu/quantize/utils.h, fbgemm_gpu/quantize/tuning_cache.hpp
${FBGEMM_GENAI_SRCS}/include/ # includes fbgemm_gpu/torch_ops.h
)
# Add FBGEMM_GENAI include directories for torch_ops.h
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_HIP_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
endif()
@ -692,12 +700,6 @@ if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
# Add FBGEMM_GENAI include directories for torch_ops.h
if(USE_FBGEMM_GENAI)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/common/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
if(CUDA_VERSION VERSION_LESS_EQUAL 12.9)
list(APPEND ATen_CUDA_DEPENDENCY_LIBS

View File

@ -229,10 +229,10 @@ private:
}
static const uint32_t kPhilox10A = 0x9E3779B9;
static const uint32_t kPhilox10B = 0xBB67AE85;
static const uint32_t kPhiloxSA = 0xD2511F53;
static const uint32_t kPhiloxSB = 0xCD9E8D57;
static constexpr uint32_t kPhilox10A = 0x9E3779B9;
static constexpr uint32_t kPhilox10B = 0xBB67AE85;
static constexpr uint32_t kPhiloxSA = 0xD2511F53;
static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
};
typedef philox_engine Philox4_32;

View File

@ -8,6 +8,7 @@
#include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
#include <ATen/cpu/vec/vec128/vec128_float_neon.h>
#include <ATen/cpu/vec/vec128/vec128_half_neon.h>
#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
#endif
#include <ATen/cpu/vec/vec128/vec128_convert.h>

View File

@ -0,0 +1,794 @@
#pragma once
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
namespace at::vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {
#define VEC_INT_NEON_TEMPLATE(vl, bit) \
template <> \
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
\
template <> \
class Vectorized<int##bit##_t> { \
using neon_type = int##bit##x##vl##_t; \
\
private: \
neon_type values; \
\
public: \
using value_type = int##bit##_t; \
using size_type = int; \
static constexpr size_type size() { \
return vl; \
} \
Vectorized() { \
values = vdupq_n_s##bit(0); \
} \
Vectorized(neon_type v) : values(v) {} \
Vectorized(int##bit##_t val); \
template < \
typename... Args, \
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
Vectorized(Args... vals) { \
__at_align__ int##bit##_t buffer[size()] = {vals...}; \
values = vld1q_s##bit(buffer); \
} \
operator neon_type() const { \
return values; \
} \
static Vectorized<int##bit##_t> loadu( \
const void* ptr, \
int64_t count = size()); \
void store(void* ptr, int64_t count = size()) const; \
template <int64_t mask> \
static Vectorized<int##bit##_t> blend( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b); \
static Vectorized<int##bit##_t> blendv( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
const Vectorized<int##bit##_t>& mask_) { \
return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
} \
template <typename step_t> \
static Vectorized<int##bit##_t> arange( \
value_type base = 0, \
step_t step = static_cast<step_t>(1)); \
static Vectorized<int##bit##_t> set( \
const Vectorized<int##bit##_t>& a, \
const Vectorized<int##bit##_t>& b, \
int64_t count = size()); \
const int##bit##_t& operator[](int idx) const = delete; \
int##bit##_t& operator[](int idx) = delete; \
Vectorized<int##bit##_t> abs() const { \
return vabsq_s##bit(values); \
} \
Vectorized<int##bit##_t> real() const { \
return values; \
} \
Vectorized<int##bit##_t> imag() const { \
return vdupq_n_s##bit(0); \
} \
Vectorized<int##bit##_t> conj() const { \
return values; \
} \
Vectorized<int##bit##_t> neg() const { \
return vnegq_s##bit(values); \
} \
int##bit##_t reduce_add() const { \
return vaddvq_s##bit(values); \
} \
int##bit##_t reduce_max() const; \
Vectorized<int##bit##_t> operator==( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator!=( \
const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> operator<( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator<=( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator>( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> operator>=( \
const Vectorized<int##bit##_t>& other) const { \
return Vectorized<value_type>( \
vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
} \
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
}; \
template <> \
Vectorized<int##bit##_t> inline operator+( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vaddq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator-( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vsubq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator&( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vandq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator|( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return vorrq_s##bit(a, b); \
} \
template <> \
Vectorized<int##bit##_t> inline operator^( \
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
return veorq_s##bit(a, b); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq( \
const Vectorized<int##bit##_t>& other) const { \
return (*this == other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne( \
const Vectorized<int##bit##_t>& other) const { \
return (*this != other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt( \
const Vectorized<int##bit##_t>& other) const { \
return (*this > other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge( \
const Vectorized<int##bit##_t>& other) const { \
return (*this >= other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt( \
const Vectorized<int##bit##_t>& other) const { \
return (*this < other) & Vectorized<int##bit##_t>(1); \
} \
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le( \
const Vectorized<int##bit##_t>& other) const { \
return (*this <= other) & Vectorized<int##bit##_t>(1); \
}
VEC_INT_NEON_TEMPLATE(2, 64)
VEC_INT_NEON_TEMPLATE(4, 32)
VEC_INT_NEON_TEMPLATE(8, 16)
VEC_INT_NEON_TEMPLATE(16, 8)
inline int32_t Vectorized<int32_t>::reduce_max() const {
return vmaxvq_s32(values);
}
inline int16_t Vectorized<int16_t>::reduce_max() const {
return vmaxvq_s16(values);
}
inline int8_t Vectorized<int8_t>::reduce_max() const {
return vmaxvq_s8(values);
}
template <>
Vectorized<int32_t> inline operator*(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vmulq_s32(a, b);
}
template <>
Vectorized<int16_t> inline operator*(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vmulq_s16(a, b);
}
template <>
Vectorized<int8_t> inline operator*(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vmulq_s8(a, b);
}
template <>
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
int64x2_t val = a;
return ~val;
}
template <>
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
return vmvnq_s32(a);
}
template <>
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
return vmvnq_s16(a);
}
template <>
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
return vmvnq_s8(a);
}
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
const Vectorized<int64_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
const Vectorized<int32_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
const Vectorized<int16_t>& other) const {
return ~(*this == other);
}
inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
const Vectorized<int8_t>& other) const {
return ~(*this == other);
}
template <>
Vectorized<int32_t> inline minimum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vminq_s32(a, b);
}
template <>
Vectorized<int16_t> inline minimum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vminq_s16(a, b);
}
template <>
Vectorized<int8_t> inline minimum(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vminq_s8(a, b);
}
template <>
Vectorized<int32_t> inline maximum(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
return vmaxq_s32(a, b);
}
template <>
Vectorized<int16_t> inline maximum(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
return vmaxq_s16(a, b);
}
template <>
Vectorized<int8_t> inline maximum(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
return vmaxq_s8(a, b);
}
template <int64_t mask>
Vectorized<int64_t> Vectorized<int64_t>::blend(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint64x2_t maskArray = {
(mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
(mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s64(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int32_t> Vectorized<int32_t>::blend(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint32x4_t maskArray = {
(mask & 1LL) ? 0xFFFFFFFF : 0,
(mask & 2LL) ? 0xFFFFFFFF : 0,
(mask & 4LL) ? 0xFFFFFFFF : 0,
(mask & 8LL) ? 0xFFFFFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s32(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int16_t> Vectorized<int16_t>::blend(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint16x8_t maskArray = {
(mask & 1LL) ? 0xFFFF : 0,
(mask & 2LL) ? 0xFFFF : 0,
(mask & 4LL) ? 0xFFFF : 0,
(mask & 8LL) ? 0xFFFF : 0,
(mask & 16LL) ? 0xFFFF : 0,
(mask & 32LL) ? 0xFFFF : 0,
(mask & 64LL) ? 0xFFFF : 0,
(mask & 128LL) ? 0xFFFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s16(maskArray, b.values, a.values);
}
template <int64_t mask>
Vectorized<int8_t> Vectorized<int8_t>::blend(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
// Build an array of flags: each bit of element is 1 if the corresponding bit
// in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
(mask & 1LL) ? 0xFF : 0,
(mask & 2LL) ? 0xFF : 0,
(mask & 4LL) ? 0xFF : 0,
(mask & 8LL) ? 0xFF : 0,
(mask & 16LL) ? 0xFF : 0,
(mask & 32LL) ? 0xFF : 0,
(mask & 64LL) ? 0xFF : 0,
(mask & 128LL) ? 0xFF : 0,
(mask & 256LL) ? 0xFF : 0,
(mask & 512LL) ? 0xFF : 0,
(mask & 1024LL) ? 0xFF : 0,
(mask & 2048LL) ? 0xFF : 0,
(mask & 4096LL) ? 0xFF : 0,
(mask & 8192LL) ? 0xFF : 0,
(mask & 16384LL) ? 0xFF : 0,
(mask & 32768LL) ? 0xFF : 0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s8(maskArray, b.values, a.values);
}
#define VEC_INT_NEON_OPS(vl, bit) \
inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) { \
values = vdupq_n_s##bit(val); \
} \
inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu( \
const void* ptr, int64_t count) { \
if (count == size()) { \
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr)); \
} else { \
__at_align__ int##bit##_t tmp_values[size()]; \
for (const auto i : c10::irange(size())) { \
tmp_values[i] = 0; \
} \
std::memcpy( \
tmp_values, \
reinterpret_cast<const int##bit##_t*>(ptr), \
count * sizeof(int##bit##_t)); \
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
} \
} \
inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count) \
const { \
if (count == size()) { \
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values); \
} else { \
int##bit##_t tmp_values[size()]; \
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values); \
std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \
} \
}
VEC_INT_NEON_OPS(2, 64)
VEC_INT_NEON_OPS(4, 32)
VEC_INT_NEON_OPS(8, 16)
VEC_INT_NEON_OPS(16, 8)
template <>
Vectorized<int64_t> inline operator*(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return x * y;
}
template <>
Vectorized<int64_t> inline operator/(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return x / y;
}
template <>
Vectorized<int32_t> inline operator/(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t x = a;
int32x4_t y = b;
return x / y;
}
inline int64_t Vectorized<int64_t>::reduce_max() const {
return std::max(values[0], values[1]);
}
template <>
Vectorized<int64_t> inline minimum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return {std::min(x[0], y[0]), std::min(x[1], y[1])};
}
template <>
Vectorized<int64_t> inline maximum(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
return {std::max(x[0], y[0]), std::max(x[1], y[1])};
}
template <typename step_t>
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
int64_t base,
step_t step) {
const Vectorized<int64_t> base_vec(base);
const Vectorized<int64_t> step_vec(step);
const int64x2_t step_sizes = {0, 1};
return base_vec.values + step_sizes * step_vec.values;
}
template <typename step_t>
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
int32_t base,
step_t step) {
const Vectorized<int32_t> base_vec(base);
const Vectorized<int32_t> step_vec(step);
const int32x4_t step_sizes = {0, 1, 2, 3};
return vmlaq_s32(base_vec, step_sizes, step_vec);
}
template <typename step_t>
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
int16_t base,
step_t step) {
const Vectorized<int16_t> base_vec(base);
const Vectorized<int16_t> step_vec(step);
const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
return vmlaq_s16(base_vec, step_sizes, step_vec);
}
template <typename step_t>
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
const Vectorized<int8_t> base_vec(base);
const Vectorized<int8_t> step_vec(step);
const int8x16_t step_sizes = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
return vmlaq_s8(base_vec, step_sizes, step_vec);
}
template <>
Vectorized<int64_t> inline operator>>(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t x = a;
int64x2_t y = b;
uint64x2_t u = vreinterpretq_u64_s64(y);
uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
return x >> vreinterpretq_s64_u64(z);
}
template <>
Vectorized<int32_t> inline operator>>(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t x = a;
int32x4_t y = b;
uint32x4_t bound = vdupq_n_u32(31);
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
return x >> vreinterpretq_s32_u32(z);
}
template <>
Vectorized<int16_t> inline operator>>(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
int16x8_t x = a;
int16x8_t y = b;
uint16x8_t bound = vdupq_n_u16(15);
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
return x >> vreinterpretq_s16_u16(z);
}
template <>
Vectorized<int8_t> inline operator>>(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
int8x16_t x = a;
int8x16_t y = b;
uint8x16_t bound = vdupq_n_u8(7);
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
return x >> z;
}
template <>
Vectorized<int64_t> inline operator<<(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b) {
int64x2_t y = b;
uint64x2_t u = vreinterpretq_u64_s64(y);
uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
return vshlq_s64(a, vreinterpretq_s64_u64(z));
}
template <>
Vectorized<int32_t> inline operator<<(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b) {
int32x4_t y = b;
uint32x4_t bound = vdupq_n_u32(32);
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
return vshlq_s32(a, vreinterpretq_s32_u32(z));
}
template <>
Vectorized<int16_t> inline operator<<(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
int16x8_t y = b;
uint16x8_t bound = vdupq_n_u16(16);
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
return vshlq_s16(a, vreinterpretq_s16_u16(z));
}
template <>
Vectorized<int8_t> inline operator<<(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
int8x16_t y = b;
uint8x16_t bound = vdupq_n_u8(8);
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
return vshlq_s8(a, z);
}
inline Vectorized<int64_t> Vectorized<int64_t>::set(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 2) {
return b;
} else {
int64x2_t c = {b.values[0], a.values[1]};
return c;
}
}
inline Vectorized<int32_t> Vectorized<int32_t>::set(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 4) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint32x4_t maskArray = {
(count >= 1LL) ? 0xFFFFFFFF : 0,
(count >= 2LL) ? 0xFFFFFFFF : 0,
(count >= 3LL) ? 0xFFFFFFFF : 0,
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s32(maskArray, b.values, a.values);
}
}
inline Vectorized<int16_t> Vectorized<int16_t>::set(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 8) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint16x8_t maskArray = {
static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s16(maskArray, b.values, a.values);
}
}
inline Vectorized<int8_t> Vectorized<int8_t>::set(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b,
int64_t count) {
if (count == 0) {
return a;
} else if (count >= 16) {
return b;
} else {
// Build an array of flags: each bit of element is 1 if the corresponding
// bit in 'mask' is set, 0 otherwise.
uint8x16_t maskArray = {
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
0};
// Use BSL to select elements from b where the mask is 1, else from a
return vbslq_s8(maskArray, b.values, a.values);
}
}
template <>
Vectorized<int16_t> inline operator/(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& b) {
Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
int32x4_t highBitsResult = highBitsA / highBitsB;
int32x4_t lowBitsResult = lowBitsA / lowBitsB;
return vuzp1q_s16(
vreinterpretq_s16_s32(lowBitsResult),
vreinterpretq_s16_s32(highBitsResult));
}
template <>
Vectorized<int8_t> inline operator/(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& b) {
Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
int16x8_t highBitsResult = highBitsA / highBitsB;
int16x8_t lowBitsResult = lowBitsA / lowBitsB;
return vuzp1q_s8(
vreinterpretq_s8_s16(lowBitsResult),
vreinterpretq_s8_s16(highBitsResult));
}
template <>
Vectorized<int64_t> inline clamp(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& min,
const Vectorized<int64_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int32_t> inline clamp(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& min,
const Vectorized<int32_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int16_t> inline clamp(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& min,
const Vectorized<int16_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int8_t> inline clamp(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& min,
const Vectorized<int8_t>& max) {
return minimum(max, maximum(min, a));
}
template <>
Vectorized<int64_t> inline clamp_max(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int32_t> inline clamp_max(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int16_t> inline clamp_max(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int8_t> inline clamp_max(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& max) {
return minimum(max, a);
}
template <>
Vectorized<int64_t> inline clamp_min(
const Vectorized<int64_t>& a,
const Vectorized<int64_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int32_t> inline clamp_min(
const Vectorized<int32_t>& a,
const Vectorized<int32_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int16_t> inline clamp_min(
const Vectorized<int16_t>& a,
const Vectorized<int16_t>& min) {
return maximum(min, a);
}
template <>
Vectorized<int8_t> inline clamp_min(
const Vectorized<int8_t>& a,
const Vectorized<int8_t>& min) {
return maximum(min, a);
}
} // namespace CPU_CAPABILITY
} // namespace at::vec

View File

@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s8x8 = vget_low_s8(src);
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
Vectorized<float> inline convert_int8_half_register_to_float(
at::vec::Vectorized<int8_t> src) {
auto s8x8 = vld1_s8(src.operator const int8_t*());
auto s8x8 = vget_low_s8(src);
auto s16x8 = vmovl_s8(s8x8);
auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));

View File

@ -16,6 +16,8 @@
#include <c10/util/irange.h>
#include <c10/core/ScalarType.h>
#include <ATen/cuda/detail/BLASConstants.h>
#ifdef USE_ROCM
#include <c10/cuda/CUDAStream.h>
#include <hipblaslt/hipblaslt-ext.hpp>
@ -1954,13 +1956,15 @@ void scaled_gemm(
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum) {
bool use_fast_accum,
const std::optional<Tensor>& alpha) {
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
// of input arguments to this function.
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
const float alpha_val = 1.0;
const float beta_val = 0.0;
// Note: alpha_val may change later depending on user-passed argument
float alpha_val = 1.0;
float beta_val = 0.0;
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@ -2031,6 +2035,33 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
// Handle user-passed alpha
float *alpha_ptr = &alpha_val;
float *beta_ptr = &beta_val;
if (alpha.has_value()) {
auto& a = alpha.value();
// if device-tensor
if (a.is_cuda()) {
// NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
// valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
// we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
// managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
// for beta respectively.
float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
user_alpha.copy_(a);
// Tell cublasLt we're using device-side pointers for alpha/beta
auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
alpha_ptr = user_alpha.data_ptr<float>();
beta_ptr = at::cuda::detail::get_cublas_device_zero();
} else {
alpha_val = a.item<float>();
}
}
// For other data types, use the get_scale_mode function based on scaling type
// The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
@ -2048,6 +2079,7 @@ void scaled_gemm(
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
@ -2088,10 +2120,10 @@ void scaled_gemm(
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
alpha_ptr,
Adesc.descriptor(),
Bdesc.descriptor(),
&beta_val,
beta_ptr,
Cdesc.descriptor(),
Ddesc.descriptor(),
all_algos[i].algo,
@ -2110,17 +2142,14 @@ void scaled_gemm(
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
&alpha_val,
alpha_ptr,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
&beta_val,
#ifdef USE_ROCM
beta_ptr,
// NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
#else
nullptr,
#endif // ifdef USE_ROCM
Cdesc.descriptor(),
result_ptr,
Ddesc.descriptor(),

View File

@ -161,7 +161,8 @@ void scaled_gemm(
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum);
bool use_fast_accum,
const std::optional<Tensor>& alpha);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

View File

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
detail::check_rng_state(new_state);

View File

@ -183,11 +183,6 @@ struct CUDACachingHostAllocatorImpl
return true;
}
bool pinned_use_background_threads() override {
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
pinned_use_background_threads();
}
EventPool::Event create_event_internal(DeviceIndex idx) {
// Leak the event pool to avoid shutdown issue.
static auto* event_pool = new EventPool();

View File

@ -177,7 +177,6 @@ inline void segmented_sort_pairs(
}
}
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
@ -193,7 +192,6 @@ inline void unique_by_key(
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif
namespace impl {
@ -579,7 +577,6 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
#endif
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
@ -607,7 +604,6 @@ inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT
#endif
}
#endif
template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
void unique(InputIteratorT input, OutputIteratorT output,

View File

@ -28,22 +28,6 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif
// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// cub support for cub::FutureValue is added to cub 1.15 in:
// https://github.com/NVIDIA/cub/pull/305
#if CUB_VERSION >= 101500

View File

@ -0,0 +1,54 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h>
#include <mutex>
namespace at {
namespace cuda {
namespace detail {
__device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device;
float *get_cublas_device_one() {
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
const float one = 1.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
});
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
return ptr;
}
float *get_cublas_device_zero() {
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
const float zero = 0.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
});
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
return ptr;
}
float *get_user_alpha_ptr() {
static float *alpha_ptr;
static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
});
return alpha_ptr;
}
} // namespace detail
} // namespace cuda
} // namespace at

View File

@ -0,0 +1,11 @@
#pragma once
#include <ATen/core/TensorBase.h>
namespace at::cuda::detail {
float *get_cublas_device_one();
float *get_cublas_device_zero();
float *get_user_alpha_ptr();
} // namespace at::cuda::detail

View File

@ -13,6 +13,7 @@
#include <c10/core/ScalarType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
@ -150,6 +151,7 @@ inline std::string ScalarTypeToBLASType(c10::ScalarType scalar_type) {
BLASType = "unknown";
}
return BLASType;
}
// Similar to Compute Type in GemmRocblas.h
@ -244,33 +246,25 @@ inline std::string to_string_epilogue(const at::cuda::blas::GEMMAndBiasActivatio
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size, const NumericalCheckConfig& config) {
if (!config.enabled) {
return true; // skip when disabled
}
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG3("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
const bool ok = at::allclose(ref_float, oth_float, config.rtol, config.atol);
if (ok) {
TUNABLE_LOG3("├──verify numerics: PASSED with atol=", config.atol, ", rtol=", config.rtol);
} else {
TUNABLE_LOG3("├──verify numerics: FAILED with atol=", config.atol, ", rtol=", config.rtol);
}
return ok;
}
}
@ -355,8 +349,10 @@ struct GemmParams : OpParams {
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -449,8 +445,10 @@ struct GemmAndBiasParams : OpParams {
}
TuningStatus NumericalCheck(GemmAndBiasParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -546,8 +544,10 @@ struct GemmStridedBatchedParams : OpParams {
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};
@ -663,7 +663,9 @@ struct ScaledGemmParams : OpParams {
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
auto* ctx = getTuningContext();
auto cfg = ctx->GetNumericalCheckConfig();
return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T), cfg) ? OK : FAIL;
}
char transa{};

View File

@ -145,7 +145,7 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins
| PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. |
| PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. |
| PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is 0. Set to 1 to enable. |
| PYTORCH_TUNABLEOP_NUMERICAL_CHECK | Default is off. Set 'atol_rtol' to enable, for example "1e-5_1e-5". |
| PYTORCH_TUNABLEOP_ROCBLAS_ENABLED | Default is 1. Set to 0 to disable rocblas being considered during tuning. |
| PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED | Default is 1. Set to 0 to disable hipblaslt being considered during tuning. |
| PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS | Default is 30. Unit is milliseconds. |
@ -173,10 +173,9 @@ All python APIs exist in the `torch.cuda.tunable` module.
| get_max_tuning_iterations() -> int | |
| set_filename(filename: str, insert_device_ordinal: bool = False) -> None | |
| get_filename() -> str | |
| set_numerical_check_tolerances(enable: bool, atol: float, rtol: float) -> None | Enable or disable numerical checking; atol and rtol default to 1e-5.
| get_results() -> Tuple[str, str, str, float] | |
| get_validators() -> Tuple[str, str] | |
| write_file_on_exit(val: bool) -> None | Default is True. |
| write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). |
| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. |
| mgpu_tune_gemm_in_file(filename_pattern: str, num_gpus: int) -> None: -> None | read one or more untuned files and tune all unique GEMMs on one or more GPUs. |

View File

@ -107,14 +107,30 @@ void TuningResultsManager::AddImpl(const std::string& op_signature,
}
void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
std::scoped_lock l{lock_};
bool is_new = false;
ResultEntry inserted = ResultEntry::Null();
auto it = results_.find(op_signature);
if (it == results_.end()) {
it = results_.insert({op_signature, {}}).first;
// ---- mutate maps under results lock ----
{
std::scoped_lock l{lock_};
auto& km = results_[op_signature]; // creates if missing
is_new = (km.find(params_signature) == km.end());
AddImpl(op_signature, params_signature, std::move(best), km);
if (is_new) {
inserted = km.at(params_signature); // snapshot for I/O after unlocking
}
}
if (!is_new) return; // only write once per unique (op, params)
TuningContext* ctx = getTuningContext();
if (ctx->IsTuningEnabled() && !ctx->IsRecordUntunedEnabled()) {
InitRealtimeAppend(ctx->GetFilename(), ctx->GetTuningResultsValidator().GetAllValidators());
if (is_new && realtime_out_ && realtime_out_->good()) {
AppendResultLine(op_signature, params_signature, inserted);
}
}
AddImpl(op_signature, params_signature, std::move(best), it->second);
}
void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
@ -150,6 +166,77 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
}
}
void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const std::unordered_map<std::string, std::string>& validators) {
std::scoped_lock fl{realtime_file_mutex_};
if (realtime_out_ && realtime_out_->good() && realtime_filename_ == filename) {
return;
}
if (realtime_out_ && realtime_filename_ != filename) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
validators_written_ = false;
}
bool file_exists = false;
bool file_empty = true;
{
std::ifstream check_file(filename);
if (check_file.good()) {
file_exists = true;
file_empty = (check_file.peek() == std::ifstream::traits_type::eof());
}
}
realtime_out_ = std::make_unique<std::ofstream>(filename, std::ios::out | std::ios::app);
if (!realtime_out_->good()) {
TORCH_WARN("TunableOp realtime append: failed to open '", filename,"'");
realtime_out_.reset();
return;
}
if(!file_exists || file_empty) {
for(const auto& [key, val] : validators) {
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
realtime_out_->flush();
}
validators_written_ = true;
TUNABLE_LOG2("Wrote validators to realtime output file");
}
realtime_filename_ = filename;
}
void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std::string& param_sig, const ResultEntry& result) {
std::scoped_lock fl{realtime_file_mutex_};
if(!realtime_out_ || !realtime_out_->good()) {
return;
}
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
realtime_out_->flush(); //ensure immediate write to disk
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
}
void TuningResultsManager::CloseRealtimeAppend() {
std::scoped_lock fl{realtime_file_mutex_};
if(realtime_out_) {
realtime_out_->flush();
realtime_out_->close();
realtime_out_.reset();
TUNABLE_LOG2("Closed realtime output file");
}
}
void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
std::scoped_lock l{lock_};
@ -396,7 +483,6 @@ TuningContext::TuningContext() :
tuning_enable_{true},
record_untuned_enable_{false},
manager_initialized_{false},
write_file_on_exit_{true},
numerics_check_enable_{false},
max_tuning_duration_ms_{30},
max_tuning_iterations_{100},
@ -417,20 +503,8 @@ TuningContext::~TuningContext() {
// but doesn't do any computation itself.
return;
}
auto filename = GetFilename();
if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
if (results_count_from_input_file_ > 0) {
TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
}
else {
TUNABLE_LOG1("writing file ", filename);
}
if (!WriteFile(filename)) {
TUNABLE_LOG1("failed to write file ", filename);
}
}
}
TUNABLE_LOG1("Closing File");
GetTuningResultsManager().CloseRealtimeAppend(); // Since, we do instant logging by default now.
if (untuned_file_.good()) {
untuned_file_.close();
@ -511,20 +585,54 @@ std::ofstream& TuningContext::GetUntunedFile(){
return untuned_file_;
}
void TuningContext::WriteFileOnExit(bool value) {
write_file_on_exit_ = value;
}
void TuningContext::EnableNumericsCheck(bool value) {
numerics_check_enable_ = value;
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto env = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (env == "1") {
return true;
NumericalCheckConfig TuningContext::GetNumericalCheckConfig() const {
const auto env_opt = c10::utils::get_env("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
if (!env_opt.has_value()) {
return numerics_cfg_;
}
return numerics_check_enable_;
const std::string& env = env_opt.value();
if (env == "0") {
return NumericalCheckConfig(false, 1e-5, 1e-5);
}
const size_t underscore = env.find('_');
TORCH_CHECK(
underscore != std::string::npos,
"Invalid PYTORCH_TUNABLEOP_NUMERICAL_CHECK format. "
"Expected 'atol_rtol', got: ",
env);
double atol = 0.0;
double rtol = 0.0;
try {
atol = std::stod(env.substr(0, underscore));
rtol = std::stod(env.substr(underscore + 1));
} catch (const std::exception& e) {
TORCH_CHECK(false, "Failed to parse PYTORCH_TUNABLEOP_NUMERICAL_CHECK: ", e.what());
}
TORCH_CHECK( atol > 0.0 && rtol > 0.0, "Tolerance values must be positive. atol=", atol, ", rtol=", rtol);
return NumericalCheckConfig(true, atol, rtol);
}
void TuningContext::SetNumericalCheckConfig(bool enabled, double atol, double rtol) {
TORCH_CHECK(atol > 0.0 && rtol > 0.0, "Numerical check tolerances must be positive");
numerics_cfg_ = {enabled, atol, rtol};
}
bool TuningContext::IsNumericsCheckEnabled() const {
const auto cfg = GetNumericalCheckConfig();
return cfg.enabled || numerics_check_enable_;
}
void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
@ -634,11 +742,6 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() {
auto filename = GetFilename();
if (!filename.empty() && !IsRecordUntunedEnabled()) {
ReadFile(filename);
// attempt immediately to open file for writing to catch errors early
std::ofstream file(filename, std::ios::out | std::ios::app);
if (!file.good()) {
TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
}
}
});
return manager_;
@ -744,27 +847,6 @@ bool TuningContext::ReadFile(const std::string& filename_) {
return true;
}
bool TuningContext::WriteFile(const std::string& filename_) {
std::string filename = filename_.empty() ? GetFilename() : filename_;
std::ofstream file(filename, std::ios::out | std::ios::trunc);
if (!file.good()) {
TUNABLE_LOG1("error opening tuning results file for writing ", filename);
return false;
}
auto validators = GetTuningResultsValidator().GetAllValidators();
for (const auto& [key, val] : validators) {
file << "Validator," << key << "," << val << std::endl;
}
auto results = GetTuningResultsManager().Dump();
for (const auto& [op_sig, kernelmap] : results) {
for (const auto& [param_sig, result] : kernelmap) {
file << op_sig << "," << param_sig << "," << result << std::endl;
}
}
file.close();
return true;
}
namespace {
struct MaybeDelete {

View File

@ -103,10 +103,24 @@ class TORCH_CUDA_CPP_API TuningResultsManager {
void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature,
const std::string& params_signature, const std::string& blas_signature);
void InitRealtimeAppend(
const std::string& filename,
const std::unordered_map<std::string, std::string>& validators);
void AppendResultLine(const std::string& op_sig,
const std::string& param_sig,
const ResultEntry& result);
void CloseRealtimeAppend(); // For clean shutdown
private:
std::mutex lock_;
std::mutex realtime_file_mutex_;
std::unique_ptr<std::ofstream> realtime_out_;
std::string realtime_filename_;
ResultsMap results_;
UntunedMap untuned_results_;
bool validators_written_ = false;
};
@ -134,6 +148,16 @@ class TORCH_CUDA_CPP_API TuningResultsValidator {
GetValidateFuncs validators_;
};
struct NumericalCheckConfig {
bool enabled{false};
double atol{1e-5};
double rtol{1e-5};
NumericalCheckConfig() = default;
NumericalCheckConfig(bool e, double a, double r) : enabled(e), atol(a), rtol(r) {}
};
class TORCH_CUDA_CPP_API TuningContext {
public:
TuningContext();
@ -155,6 +179,8 @@ class TORCH_CUDA_CPP_API TuningContext {
void EnableNumericsCheck(bool value);
bool IsNumericsCheckEnabled() const;
void SetNumericalCheckConfig(bool enabled, double atol, double rtol);
NumericalCheckConfig GetNumericalCheckConfig() const;
void SetMaxTuningDurationMs(int max_duration_ms);
int GetMaxTuningDurationMs() const;
@ -185,10 +211,7 @@ class TORCH_CUDA_CPP_API TuningContext {
void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
std::string GetFilename() const;
void WriteFileOnExit(bool value);
bool ReadFile(const std::string& filename={});
bool WriteFile(const std::string& filename={});
template<class... Types>
void Log(int level, Types... args) {
@ -207,7 +230,6 @@ class TORCH_CUDA_CPP_API TuningContext {
bool tuning_enable_;
bool record_untuned_enable_;
bool manager_initialized_;
bool write_file_on_exit_;
bool numerics_check_enable_;
int max_tuning_duration_ms_;
int max_tuning_iterations_;
@ -222,6 +244,8 @@ class TORCH_CUDA_CPP_API TuningContext {
std::ofstream untuned_file_;
size_t results_count_from_input_file_;
bool is_shutting_down_;
NumericalCheckConfig numerics_cfg_{};
};
TORCH_CUDA_CPP_API TuningContext* getTuningContext();

View File

@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum);
params->use_fast_accum,
std::nullopt /* alpha */);
return OK;
}
};

View File

@ -267,27 +267,10 @@ class TunableOp {
for (size_t i = 0; i < op_names_.size(); i++) {
auto* candidate = ops_[op_names_[i]].get(); // borrow pointer
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
else {
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
auto status = candidate->Call(reusable_params[0]);
if (status != OK) {
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
// collect a small profile
@ -310,6 +293,22 @@ class TunableOp {
continue;
}
if (do_numerics_check) {
ParamsT* numerical_params = params->DeepCopy(false);
auto status = candidate->Call(numerical_params);
if (status != OK) {
numerical_params->Delete();
TUNABLE_LOG3("├──unsupported id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
status = reference_params->NumericalCheck(numerical_params);
numerical_params->Delete();
if (status != OK) {
TUNABLE_LOG3("├──numerics check failed for id=", i, ", ", op_sig, '(', params_sig, ") ", op_names_[i]);
continue;
}
}
// for warmup does user set max duration, max iters, or both?
// warmup is skipped by default, i.e. warmup_iter = 0
// warmup will be set to the non-zero value of max_warmup_duration

View File

@ -160,6 +160,10 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,
DispatchKey::SparseCPU,
DispatchKey::SparseCUDA,
DispatchKey::SparseCsrCPU,
DispatchKey::SparseCsrCUDA,
});
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

View File

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
namespace at::native {
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

View File

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

View File

@ -658,6 +658,7 @@ static void check_shape_forward(const at::Tensor& input,
TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");
TORCH_CHECK(!params.is_dilation_neg(), "dilation should be greater than zero");
TORCH_CHECK(groups > 0, "expected groups to be greater than 0, but got groups=", groups);
TORCH_CHECK(weight_dim == k,
"Expected ", weight_dim, "-dimensional input for ", weight_dim,

View File

@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
constexpr static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
if (k < std::size(kTailValues)) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);

View File

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
static constexpr scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@ -62,7 +62,7 @@
#include <utility>
#include <vector>
static const int MIOPEN_DIM_MAX = 5;
static constexpr int MIOPEN_DIM_MAX = 5;
namespace at::meta {

View File

@ -1906,11 +1906,9 @@ Tensor& index_fill_(
"This also applies to advanced indexing e.g. tensor[mask] = scalar");
}
if (!self.is_complex() && source.isComplex()) {
TORCH_CHECK(
false,
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
}
TORCH_CHECK(
self.is_complex() || !source.isComplex(),
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
// Handle the case when `self` is 0-dim
Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;

View File

@ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
// next broadcast all index tensors together
try {
indices = expand_outplace(indices);
} catch (std::exception& e) {
} catch (std::exception&) {
TORCH_CHECK_INDEX(
false,
"shape mismatch: indexing tensors could not be broadcast together"

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider as deprecated.
// See HelperInterpNearestExact as replacement
static const int interp_size = 1;
static constexpr int interp_size = 1;
static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
struct HelperInterpLinear : public HelperInterpBase {
static const int interp_size = 2;
static constexpr int interp_size = 2;
// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
struct HelperInterpCubic : public HelperInterpBase {
static const int interp_size = 4;
static constexpr int interp_size = 4;
// Compute indices and weights for each interpolated dimension
// indices_weights = {

File diff suppressed because it is too large Load Diff

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};

View File

@ -38,12 +38,41 @@ __device__ inline int min(int a, int b) {
#define BLOCK_STRIDE_BWD 2 // increasing block_stride to lower # of blocks launched
#endif
static __device__ inline int p_start(int size, int pad, int kernel, int dilation, int stride) {
return (size + pad < ((kernel - 1) * dilation + 1)) ? 0 : (size + pad - ((kernel - 1) * dilation + 1)) / stride + 1;
template <typename index_t>
static __device__ inline index_t p_start(index_t size, int pad, int kernel, int dilation, int stride) {
const auto kernel_extent = static_cast<index_t>((kernel - 1) * dilation + 1);
return (size + pad < kernel_extent) ? index_t(0) : (size + pad - kernel_extent) / stride + 1;
}
static __device__ inline int p_end(int size, int pad, int pooled_size, int stride) {
return min((size + pad) / stride + 1, pooled_size);
template <typename index_t>
static __device__ inline index_t p_end(index_t size, int pad, index_t pooled_size, int stride) {
return std::min((size + pad) / stride + 1, pooled_size);
}
static inline bool can_use_int32_nhwc(
int64_t nbatch, int64_t channels,
int64_t height, int64_t width,
int64_t pooled_height, int64_t pooled_width,
int64_t in_stride_n, int64_t in_stride_c,
int64_t in_stride_h, int64_t in_stride_w)
{
constexpr int64_t int_max = std::numeric_limits<int>::max();
int64_t max_intra_batch =
(height ? (height - 1) * in_stride_h : 0) +
(width ? (width - 1) * in_stride_w : 0) +
(channels? (channels - 1) * in_stride_c : 0);
int64_t max_input_offset = (nbatch ? (nbatch - 1) * in_stride_n : 0) + max_intra_batch;
if (max_input_offset > int_max) return false;
int64_t out_batch_stride = pooled_height * pooled_width * channels;
if ((nbatch ? (nbatch - 1) * out_batch_stride : 0) > int_max) return false;
if (height * width > int_max) return false;
return true;
}
// kernels borrowed from Caffe
@ -85,21 +114,25 @@ __global__ void max_pool_forward_nchw(const int nthreads, const scalar_t* bottom
}
}
template <typename scalar_t>
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(CUDA_MAX_THREADS)
__global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nbatch,
const int64_t channels, const int64_t height,
const int64_t width, const int pooled_height, const int pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int in_stride_n, const int in_stride_c,
const int in_stride_h, const int in_stride_w,
const int kernel_stride_C, const int kernel_size_C,
scalar_t* top_data, int64_t* top_mask) {
extern __shared__ int smem[];
int *out_mask_cached = smem;
scalar_t *out_cached = reinterpret_cast<scalar_t*>(&out_mask_cached[kernel_size_C*blockDim.x*blockDim.y*blockDim.z]);
__global__ void max_pool_forward_nhwc(
const scalar_t* bottom_data,
const int nbatch,
const index_t channels, const index_t height, const index_t width,
const index_t pooled_height, const index_t pooled_width,
const int kernel_h, const int kernel_w, const int stride_h,
const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const index_t in_stride_n, const index_t in_stride_c,
const index_t in_stride_h, const index_t in_stride_w,
const int kernel_stride_C, const int kernel_size_C,
scalar_t* top_data, int64_t* top_mask) {
extern __shared__ unsigned char smem_raw[];
index_t *out_mask_cached = reinterpret_cast<index_t*>(smem_raw);
scalar_t *out_cached = reinterpret_cast<scalar_t*>(
out_mask_cached + kernel_size_C*blockDim.x*blockDim.y*blockDim.z);
// flattening cta for pre-computation & smem initialization;
int thread_id = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
@ -118,26 +151,26 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
int channel_id = blockIdx.x / nbatch;
int channel_offset = threadIdx.x + channel_id * blockDim.x;
top_data = top_data + batch_id * pooled_height * pooled_width * channels;
top_mask = top_mask + batch_id * pooled_height * pooled_width * channels;
bottom_data = bottom_data + batch_id * in_stride_n;
top_data = top_data + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
top_mask = top_mask + static_cast<index_t>(batch_id) * (pooled_height * pooled_width * channels);
bottom_data = bottom_data + static_cast<index_t>(batch_id) * in_stride_n;
out_cached = &out_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
out_mask_cached = &out_mask_cached[(threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x];
out_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
out_mask_cached += (threadIdx.z * blockDim.y + threadIdx.y) * kernel_size_C*blockDim.x;
int oH = (pooled_height + gridDim.z-1) / gridDim.z;
int oW = (pooled_width + gridDim.y-1) / gridDim.y;
int oH = (static_cast<int>(pooled_height) + gridDim.z - 1) / gridDim.z;
int oW = (static_cast<int>(pooled_width) + gridDim.y - 1) / gridDim.y;
int ostartH = threadIdx.z + blockIdx.z*oH;
int oendH = ::min(ostartH+oH, pooled_height);
int oendH = ::min(ostartH+oH, static_cast<int>(pooled_height));
int ostartW = threadIdx.y + blockIdx.y*oW;
int oendW = ::min(ostartW+oW, pooled_width);
int oendW = ::min(ostartW+oW, static_cast<int>(pooled_width));
for (int oh = ostartH; oh < oendH; oh+=blockDim.z) {
int hstart = oh * stride_h - pad_h;
int hend = min(hstart + (kernel_h - 1) * dilation_h + 1, height);
index_t hstart = static_cast<index_t>(oh) * stride_h - pad_h;
index_t hend = std::min(hstart + static_cast<index_t>((kernel_h - 1) * dilation_h + 1), height);
for (int ow = ostartW; ow < oendW; ow+=blockDim.y) {
int wstart = ow * stride_w - pad_w;
int wend = min(wstart + (kernel_w - 1) * dilation_w + 1, width);
index_t wstart = static_cast<index_t>(ow) * stride_w - pad_w;
index_t wend = std::min(wstart + static_cast<index_t>((kernel_w - 1) * dilation_w + 1), width);
while(hstart < 0)
hstart += dilation_h;
while(wstart < 0)
@ -185,12 +218,12 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
// Else do it Non-Prefetch...
else
#endif
for (int ih = hstart; ih < hend; ih += dilation_h) {
for (int iw = wstart; iw < wend; iw += dilation_w) {
for (index_t ih = hstart; ih < hend; ih += dilation_h) {
for (index_t iw = wstart; iw < wend; iw += dilation_w) {
int cached_index = threadIdx.x;
const scalar_t *ptr_input = bottom_data + ih * in_stride_h + iw * in_stride_w;
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
scalar_t val = ptr_input[c*in_stride_c];
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
scalar_t val = ptr_input[c * in_stride_c];
if ((val > out_cached[cached_index]) || at::_isnan(val)) {
out_cached[cached_index] = val;
out_mask_cached[cached_index] = ih * width + iw;
@ -200,15 +233,15 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
}
}
scalar_t *ptr_output_data = top_data + (oh * pooled_width + ow) * channels;
int64_t *ptr_output_mask = top_mask + (oh * pooled_width + ow) * channels;
scalar_t *ptr_output_data = top_data + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
int64_t *ptr_output_mask = top_mask + (static_cast<index_t>(oh) * pooled_width + ow) * channels;
int cached_index = threadIdx.x;
for(int c = channel_offset; c < channels; c+= blockDim.x*kernel_stride_C) {
for (index_t c = channel_offset; c < channels; c += static_cast<index_t>(blockDim.x) * kernel_stride_C) {
ptr_output_data[c] = out_cached[cached_index];
ptr_output_mask[c] = out_mask_cached[cached_index];
ptr_output_mask[c] = static_cast<int64_t>(out_mask_cached[cached_index]);
out_cached[cached_index] = at::numeric_limits<scalar_t>::lower_bound();
out_mask_cached[cached_index] = 0;
out_mask_cached[cached_index] = index_t(0);
cached_index += blockDim.x;
}
}
@ -216,7 +249,7 @@ __global__ void max_pool_forward_nhwc(const scalar_t* bottom_data, const int nba
}
static const int BLOCK_THREADS = 256;
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)
@ -462,6 +495,11 @@ const Tensor& indices) {
maxThreadsDim[0], std::min<int>(lastPow2(nInputPlane), max_threads / block_y / block_z));
const dim3 block(block_x, block_y, block_z);
bool use_int32 = can_use_int32_nhwc(
nbatch, nInputPlane, inputHeight, inputWidth,
outputHeight, outputWidth,
in_stride_n, in_stride_c, in_stride_h, in_stride_w);
int kernel_stride_C = ceil_div(
safe_downcast<int, int64_t>(nInputPlane), block_x * 4);
int kernel_size_C = ceil_div(
@ -476,18 +514,41 @@ const Tensor& indices) {
ceil_div(safe_downcast<int, int64_t>(outputHeight), block_z*BLOCK_STRIDE_FWD));
const dim3 grid(grid_x, grid_y, grid_z);
size_t shmem_size = (kernel_size_C * block_x*block_y*block_z) * (sizeof(int) + sizeof(scalar_t));
AT_ASSERT(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock);
size_t shmem_size;
size_t mask_elems = static_cast<size_t>(kernel_size_C) * block_x * block_y * block_z;
max_pool_forward_nhwc<scalar_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, nbatch,
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
in_stride_n, in_stride_c,
in_stride_h, in_stride_w,
kernel_stride_C, kernel_size_C,
output_data, indices_data);
if (use_int32) {
shmem_size = mask_elems * (sizeof(int32_t) + sizeof(scalar_t));
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
"shared memory too small");
max_pool_forward_nhwc<scalar_t, int32_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, static_cast<int>(nbatch),
static_cast<int32_t>(nInputPlane),
static_cast<int32_t>(inputHeight),
static_cast<int32_t>(inputWidth),
static_cast<int32_t>(outputHeight),
static_cast<int32_t>(outputWidth),
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
static_cast<int32_t>(in_stride_n),
static_cast<int32_t>(in_stride_c),
static_cast<int32_t>(in_stride_h),
static_cast<int32_t>(in_stride_w),
kernel_stride_C, kernel_size_C,
output_data, indices_data);
} else {
shmem_size = mask_elems * (sizeof(int64_t) + sizeof(scalar_t));
TORCH_CHECK(shmem_size <= at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock,
"shared memory too small");
max_pool_forward_nhwc<scalar_t, int64_t>
<<<grid, block, shmem_size, at::cuda::getCurrentCUDAStream()>>>(
input_data, static_cast<int>(nbatch),
nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
kH, kW, dH, dW, padH, padW, dilationH, dilationW,
in_stride_n, in_stride_c, in_stride_h, in_stride_w,
kernel_stride_C, kernel_size_C,
output_data, indices_data);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
break;
}

View File

@ -15,9 +15,7 @@
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -36,9 +34,9 @@ namespace at::native {
namespace {
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
static constexpr int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
static constexpr int BLOCKDIMY = 32;
#endif
template
@ -240,10 +238,6 @@ __global__ void renorm_kernel(
} // anonymous namespace
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
int64_t num_weights, int64_t padding_idx,
@ -306,7 +300,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -333,11 +326,6 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice
num_indices
);
});
#else
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
});
#endif
}
return embedding_backward_cuda_kernel(grad, orig_indices,

View File

@ -10,9 +10,7 @@
#include <c10/macros/Macros.h>
#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -196,18 +194,9 @@ __global__ void compute_num_of_partial_segments(const index_t *partials_per_segm
partials_per_segment_offset[num_of_segments-1];
}
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
*num_of_segments_ptr = num_of_segments;
}
#endif
} // anon namespace
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif
Tensor embedding_backward_cuda_kernel(
const Tensor &grad,
@ -234,20 +223,12 @@ Tensor embedding_backward_cuda_kernel(
auto segment_offsets = at::empty({numel}, orig_indices.options());
auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#else
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
cuda::cub::unique_by_key(
sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
segment_offsets.mutable_data_ptr<index_t>(),
num_of_segments_ptr, sorted_indices.numel());
});
#endif
int64_t max_segments = std::min<int64_t>(numel, num_weights);

View File

@ -31,16 +31,10 @@
#include <c10/macros/Macros.h>
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif
namespace at::native {
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif
namespace {
@ -199,7 +193,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
if (scale_grad_by_freq) {
count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -226,11 +219,6 @@ Tensor embedding_bag_backward_cuda_sum_avg(
num_indices
);
});
#else
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
});
#endif
}
return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000;
constexpr int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

View File

@ -1,90 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif
#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/constant_iterator.h>
namespace at::native {
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
auto num_indices = count.numel();
// Compute an increasing sequence per unique item in sortedIndices:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 1 2 3 1 2 1 1 2
auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
thrust::inclusive_scan_by_key(
policy,
sorted_data,
sorted_data + num_indices,
thrust::make_constant_iterator(1),
count_data
);
// Take the maximum of each count per unique key in reverse:
// sorted: 2 5 5 5 7 7 8 9 9
// count: 1 3 3 3 2 2 1 2 2
thrust::inclusive_scan_by_key(
policy,
thrust::make_reverse_iterator(sorted_data + num_indices),
thrust::make_reverse_iterator(sorted_data),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::make_reverse_iterator(count_data + num_indices),
thrust::equal_to<index_t>(),
thrust::maximum<index_t>()
);
}
template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);
#endif
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
auto stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
auto policy = thrust::cuda::par(allocator).on(stream);
const ptrdiff_t numel = sorted_indices.numel();
auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
auto ends = thrust::unique_by_key_copy(
policy,
sorted_indices_dev,
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
return thrust::get<0>(ends) - dummy_dev;
}
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);
} // namespace at::native

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846;
static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,

View File

@ -413,14 +413,12 @@ struct ReduceOp {
value = thread_reduce<output_vec_size>(input_slice);
}
if (config.should_block_y_reduce()) {
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
__syncthreads();
if (config.should_block_x_reduce()) {
value = block_x_reduce<output_vec_size>(value, shared_memory);
}
if (config.should_block_y_reduce()) {
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
using out_ptr_vec_t = std::array<out_scalar_t*, output_vec_size>;
using offset_vec_t = std::array<index_t, output_vec_size>;
offset_vec_t base_offsets;
@ -655,8 +653,14 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// todo for AMD
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);
@ -1091,11 +1095,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
// Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization.

View File

@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
} else {
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}
}
static void mean_kernel_cuda(TensorIterator& iter) {

View File

@ -13,24 +13,19 @@ namespace at::native {
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
void operator()(TensorIterator& iter) {
#ifdef USE_ROCM
// Half and BFloat16 can be packed in groups of up to 8 elements and
// can use *_DWORDX4 instructions to achieve that.
const bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
if (is_16_bits) {
const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
};
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
return;
iter, func_wrapper<out_t>(sum_combine)
);
} else {
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>(sum_combine)
);
}
#endif
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
}
};
@ -77,8 +72,8 @@ struct nansum_functor_complex {
#if AT_USE_JITERATOR()
void operator()(TensorIterator& iter) {
std::string func = jiterator_stringify(
arg_t combine(arg_t a, scalar_t b) {
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
arg_t combine(arg_t a, arg_t b) {
return a + (std::isnan(b) ? arg_t{0.} : b);
}
);
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(

View File

@ -19,7 +19,6 @@
namespace at::native {
// TODO: remove this when CUDA <11.6 is no longer supported
void topk_out_with_sort(
const Tensor& self,
int64_t k, int64_t dim, bool largest,
@ -31,21 +30,12 @@ void topk_out_with_sort(
indices.copy_(sorted_indices.narrow(dim, 0, k));
}
// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk();
bool should_use_sort(const Tensor& self, int64_t dim) {
#if defined(USE_ROCM)
if (self.dtype() == kBool) return false; // Bool sort not supported in ROCm: https://github.com/pytorch/pytorch/issues/139972
return (self.numel() >= 10000 && self.numel() == self.size(dim)); // based on the experiments in https://github.com/pytorch/pytorch/pull/146387
#else
if (disable_sort_for_topk()) return false;
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/68632
if (self.dim() == 0) return false;
if (self.dtype() == kBool) return false; // Bool is not support by topk
int64_t slice_size = self.size(dim);
if (slice_size == 0) return false;
int64_t num_slices = self.numel() / slice_size;
return num_slices <= 10 && slice_size >= 100000;
return false;
#endif
}

View File

@ -21,11 +21,6 @@ using namespace at::native;
namespace at::native {
// TODO: remove this when CUDA <11.6 is no longer supported
bool disable_sort_for_topk() {
return CUB_SUPPORTS_SCAN_BY_KEY();
}
namespace sbtopk { // single_block_topk
template <typename T>
@ -418,10 +413,6 @@ __global__ void computeBlockwiseWithinKCounts(
}
__syncthreads();
#if !CUB_SUPPORTS_SCAN_BY_KEY()
return;
#endif
Bitwise desired_digit = at::cuda::Bitfield<Bitwise>::getBitfield(desired, current_bit, RADIX_BITS);
// if largest, then only threads that has tidx > desired_digit are active
@ -477,7 +468,6 @@ __global__ void computeBlockwiseWithinKCounts(
}
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
// Assumption: slice_size can not be larger than UINT32_MAX
template <typename Bitwise>
__global__ void computeBlockwiseKthCounts(
@ -609,7 +599,6 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo<const T, IndexType> inpu
}
}
}
#endif
int get_items_per_thread(uint64_t num_slices, uint64_t slice_size) {
// occupancy of this kernel is limited by registers per threads
@ -687,16 +676,12 @@ void launch(
uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));
#if CUB_SUPPORTS_SCAN_BY_KEY()
auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
AT_CUDA_CHECK(cudaMemsetAsync(withinKCounts, 0, num_blocks * sizeof(uint32_t), stream));
auto kthCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
uint32_t* kthCounts = reinterpret_cast<uint32_t*>(kthCounts_buffer.get());
#else
uint32_t* withinKCounts = nullptr;
#endif
Bitwise desiredMask = 0;
dim3 grid;
@ -743,7 +728,6 @@ void launch(
}
desired = desired_in;
#if CUB_SUPPORTS_SCAN_BY_KEY()
computeBlockwiseKthCounts<Bitwise><<<std::min(((int64_t)numInputSlices + 255) / 256, (int64_t)1073741824), 256, 0, stream>>>(
desired, counts, num_blocks, blocks_per_slice, kthCounts);
C10_CUDA_KERNEL_LAUNCH_CHECK();
@ -759,28 +743,6 @@ void launch(
topK, topKWithinSliceStride, indices, indicesWithinSliceStride, items_per_thread,
blocks_per_slice, kthValues, withinKCounts, kthCounts, num_blocks);
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// Find topk values based on kth values
{
dim3 grid;
TORCH_INTERNAL_ASSERT(getGridFromTiles(numInputSlices, grid), "Too many slices for topk");
int warp_size = at::cuda::warp_size();
dim3 block(std::min(at::ceil_div((int64_t)inputSliceSize, (int64_t)warp_size) * (int64_t)warp_size, (int64_t)1024));
sbtopk::gatherTopK<T, IndexType, Dim, /* WithKthValues= */true><<<grid, block, 0, stream>>>(
input,
inputSliceSize,
outputSliceSize,
largest,
numInputSlices,
inputWithinSliceStride,
topK,
topKWithinSliceStride,
indices,
indicesWithinSliceStride,
kthValues);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}
} // namespace mbtopk
@ -788,7 +750,6 @@ void launch(
bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
if (num_slices > std::numeric_limits<uint32_t>::max() ||
slice_size > std::numeric_limits<uint32_t>::max()) return false;
#if CUB_SUPPORTS_SCAN_BY_KEY()
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/74267
return (num_slices <= 20 && slice_size >= 20000) ||
(num_slices > 20 && num_slices <= 40 && slice_size >= 10000) ||
@ -797,12 +758,6 @@ bool should_use_multiblock(int64_t num_slices, int64_t slice_size) {
(num_slices >= 200 && num_slices < 800 && slice_size >= 3000) ||
(num_slices >= 800 && num_slices <= 4000 && slice_size >= 800) ||
(num_slices > 4000 && slice_size >= 400);
#else
// This heuristics is based on the experiment in https://github.com/pytorch/pytorch/pull/71081
return (num_slices <= 400 && slice_size >= 5000) ||
(num_slices > 400 && num_slices < 4000 && slice_size >= 1000) ||
(num_slices >= 4000 && slice_size >= 300);
#endif
}
void launch_gather_topk_kernel(

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static const int size = 2;
static constexpr int size = 2;
};
// taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static const int size = 4;
static constexpr int size = 4;
};
template <typename accscalar_t>

View File

@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
#endif
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else{
return {0.f, curr_sum.sigma2 + val * val, 0};
@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
auto coef = __builtin_amdgcn_rcpf(count);
#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
#endif
auto nA = dataA.count * coef;
auto nB = dataB.count * coef;
mean = nA*dataA.mean + nB*dataB.mean;

View File

@ -466,7 +466,11 @@ struct ReduceJitOp {
__syncthreads();
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);

View File

@ -487,9 +487,7 @@ std::unique_ptr<fe::graph::Graph> build_graph(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) {
@ -707,9 +705,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA_NESTEDTENSOR")
.set_is_inference(return_softmaxstats == false)
// TODO(eqy): switch to this API once cuDNN FE is upgraded
// .set_generate_stats(return_softmaxstats)
.set_generate_stats(return_softmaxstats)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
.set_seq_len_q(SEQ_LEN_Q_)

View File

@ -160,8 +160,12 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
}
static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
#if defined(__x86_64__) || defined(_M_X64)
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
#else
return false; //TF32 not supported on power system
#endif
}
static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {

View File

@ -74,8 +74,12 @@ static bool use_mkldnn_bf32_linear() {
}
static bool use_mkldnn_tf32_linear() {
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
#if defined(__x86_64__) || defined(_M_X64)
return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
cpuinfo_has_x86_amx_fp16();
#else
return false; // TF32 not supported on power system
#endif
}
Tensor mkldnn_linear(

View File

@ -114,8 +114,13 @@ static bool use_mkldnn_bf32_matmul() {
return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
}
static bool use_mkldnn_tf32_matmul() {
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
#if defined(__x86_64__) || defined(_M_X64)
return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
#else
return false; // TF32 not supported on power system
#endif
}
// returns an ideep::tensor
@ -411,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;

View File

@ -712,7 +712,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) {
} else if (scalar.isBoolean()) {
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool));
} else if (scalar.isComplex()) {
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexDouble));
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kComplexFloat));
} else {
TORCH_INTERNAL_ASSERT(scalar.isIntegral(false));
tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kLong));

View File

@ -1,16 +1,16 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatSharedParams {
int32_t ndim;
int32_t cat_dim;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
};
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatInputParams {
idx_type_t cat_dim_offset;
idx_type_t input_element_offset;
::c10::metal::array<idx_type_t, N> input_strides;

View File

@ -6,26 +6,25 @@
using namespace metal;
using namespace c10::metal;
template <typename T_in, typename T_out>
kernel void cat_large(
template <typename I, typename T_in, typename T_out>
kernel void cat(
constant T_in* input [[buffer(0)]],
device T_out* output [[buffer(1)]],
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
constant CatLargeInputParams<>& input_params [[buffer(3)]],
constant CatSharedParams<I>& shared_params [[buffer(2)]],
constant CatInputParams<I>& input_params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
constant auto& output_strides = shared_params.output_strides;
constant auto& output_sizes = shared_params.output_sizes;
auto cat_dim_offset = input_params.cat_dim_offset;
auto input_element_offset = input_params.input_element_offset;
constant auto& input_strides = input_params.input_strides;
constant auto& input_sizes = input_params.input_sizes;
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
int64_t input_offset = 0;
int64_t output_offset = 0;
auto input_element_idx = static_cast<I>(tid) + input_element_offset;
I input_offset = 0;
I output_offset = 0;
for (auto dim = ndim - 1; dim >= 0; dim--) {
auto dim_size = input_sizes[dim];
@ -42,41 +41,45 @@ kernel void cat_large(
output[output_offset] = static_cast<T_out>(input[input_offset]);
}
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
kernel void cat_large<T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
#define REGISTER_CAT_OP(I, T_in, T_out) \
template [[host_name("cat_" #I "_" #T_in "_" #T_out)]] \
kernel void cat<I, T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatSharedParams<I> & shared_params [[buffer(2)]], \
constant CatInputParams<I> & input_params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
REGISTER_CAT_LARGE_OP(float, T_out); \
REGISTER_CAT_LARGE_OP(half, T_out); \
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
REGISTER_CAT_LARGE_OP(int, T_out); \
REGISTER_CAT_LARGE_OP(uint, T_out); \
REGISTER_CAT_LARGE_OP(long, T_out); \
REGISTER_CAT_LARGE_OP(ulong, T_out); \
REGISTER_CAT_LARGE_OP(short, T_out); \
REGISTER_CAT_LARGE_OP(ushort, T_out); \
REGISTER_CAT_LARGE_OP(char, T_out); \
REGISTER_CAT_LARGE_OP(uchar, T_out); \
REGISTER_CAT_LARGE_OP(bool, T_out);
#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
REGISTER_CAT_OP(I, float, T_out); \
REGISTER_CAT_OP(I, half, T_out); \
REGISTER_CAT_OP(I, bfloat, T_out); \
REGISTER_CAT_OP(I, int, T_out); \
REGISTER_CAT_OP(I, uint, T_out); \
REGISTER_CAT_OP(I, long, T_out); \
REGISTER_CAT_OP(I, ulong, T_out); \
REGISTER_CAT_OP(I, short, T_out); \
REGISTER_CAT_OP(I, ushort, T_out); \
REGISTER_CAT_OP(I, char, T_out); \
REGISTER_CAT_OP(I, uchar, T_out); \
REGISTER_CAT_OP(I, bool, T_out);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
#define REGISTER_CAT_FOR_INDEX_TYPE(I) \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool); \
\
REGISTER_CAT_OP(I, float2, float2); \
REGISTER_CAT_OP(I, half2, half2);
REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
REGISTER_CAT_FOR_INDEX_TYPE(int32_t);

View File

@ -54,6 +54,10 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) {
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;
if (self.numel() == 0 & other.numel() == 0) {
return zeros({}, self.options());
}
dot_check(self, other);
auto output = at::empty({}, self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);

View File

@ -907,6 +907,8 @@ Tensor& index_fill_mps_(Tensor& self, int64_t dim, const Tensor& index, const Te
TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
"index_fill_(): Expected dtype int32 or int64 for index");
TORCH_CHECK(dim == 0 || dim < self.dim(), "index_fill_(): Indexing dim ", dim, " is out of bounds of tensor");
TORCH_CHECK(self.is_complex() || !source.is_complex(),
"index_fill_(): Converting complex Scalar to non-complex type is not supported");
// MPS.scatter crashes if used with complex dtypes
TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "index_fill_(): Complex types are yet not supported");

View File

@ -196,6 +196,28 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output)
other.size(0) > max_stride_size || other.size(1) > max_stride_size);
}
void map_mps_decomposition_error_code_to_blas(const Tensor& status) {
const auto& status_flat = status.view(-1);
for (const auto i : c10::irange(status_flat.size(0))) {
int code = status_flat[i].item<int>();
switch (code) {
case MPSMatrixDecompositionStatusSuccess:
status_flat[i] = 0;
break;
case MPSMatrixDecompositionStatusNonPositiveDefinite:
case MPSMatrixDecompositionStatusSingular:
status_flat[i] = 2;
break;
case MPSMatrixDecompositionStatusFailure:
status_flat[i] = -1;
break;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown MPSMatrixDecompositionStatus enum value: ", code);
}
}
}
} // anonymous namespace
static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
@ -487,6 +509,9 @@ static void linalg_solve_out_mps_impl(const Tensor& A,
"mpsmatrixdecompositionstatus for details.");
}
}
map_mps_decomposition_error_code_to_blas(info);
if (!left) {
// If this was a right solve, transpose the result back
result.copy_(result_t.transpose(-2, -1).contiguous());

View File

@ -3,6 +3,7 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/Pool.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>
@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
}
}
// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
template <typename T>
std::string get_type_str();
template <>
std::string get_type_str<int64_t>() {
return "int64_t";
}
template <>
std::string get_type_str<int32_t>() {
return "int32_t";
}
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatLargeSharedParams shared_params;
template <typename idx_type_t>
static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatSharedParams<idx_type_t> shared_params;
shared_params.ndim = output.dim();
shared_params.cat_dim = dimension;
for (const auto dim : c10::irange(output.dim())) {
shared_params.output_strides[dim] = output.stride(dim);
shared_params.output_sizes[dim] = output.size(dim);
shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
}
int64_t cat_dim_offset = 0;
idx_type_t cat_dim_offset = 0;
size_t input_idx = 0;
MPSStream* stream = getCurrentMPSStream();
// Launch a separate kernels for each input. This will produce some overhead,
// but that should be relatively minimal since at least one of the inputs is
// very large. In order to launch only one kernel to process all inputs, we
// would have to copy all the input tensor data into a packed buffer, which
// would not be ideal.
// Launch a separate kernels for each input. This will produce some overhead.
// In order to launch only one kernel to process all inputs, we would have to
// copy all the input tensor data into a packed buffer, which would not be
// ideal.
for (const Tensor& input : inputs) {
if (input.numel() == 0) {
continue;
@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
auto num_threads = std::min(max_num_threads, numel_remaining);
CatLargeInputParams input_params;
CatInputParams<idx_type_t> input_params;
input_params.cat_dim_offset = cat_dim_offset;
input_params.input_element_offset = input.numel() - numel_remaining;
input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);
for (const auto dim : c10::irange(input.dim())) {
input_params.input_strides[dim] = input.stride(dim);
input_params.input_sizes[dim] = input.size(dim);
input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
get_type_str<idx_type_t>(),
scalarToMetalTypeString(input),
scalarToMetalTypeString(output)));
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
" and out is on ",
out.device());
// TODO: For better performance by eliminating input tensor gathering and post transpose,
// TODO: it is better to keep the out tensor's memory format.
// TODO: dimension needs to be recomputed as:
// TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
if (needsGather(out)) {
out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
std::vector<int64_t> size(notSkippedTensor.sizes().vec());
// Compute size of the result in the cat dimension
@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
has_large_tensor |= isTooLargeForMPSGraph(out);
if (has_large_tensor) {
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
@autoreleasepool {
std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
if (!all_same_dtype) {
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
} else {
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
}
for (auto idx : skipped_tensor_indices) {
key += "," + std::to_string(idx);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
newCachedGraph->inputTensors_.reserve(len_tensor_array);
for (const auto idx : c10::irange(len_tensor_array)) {
const Tensor& tensor = input_tensors[idx];
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
name:@"castInput"];
} else {
castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
}
}
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
});
std::vector<Placeholder> inputPlaceholders;
int i = 0;
int t_idx = 0;
for (const Tensor& tensor : materialized_inputs) {
if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
t_idx++;
}
i++;
}
auto outputDataType = getMPSScalarType(out.scalar_type());
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
} else {
return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
}
}

View File

@ -1370,6 +1370,7 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
SparseMPS: bmm_sparse_mps
NestedTensorCPU: bmm_nested
NestedTensorCUDA: bmm_nested_cuda
tags: core
@ -1385,6 +1386,7 @@
MTIA: bmm_out_mtia
SparseCPU: bmm_out_sparse_cpu
SparseCUDA: bmm_out_sparse_cuda
SparseMPS: bmm_out_sparse_mps
SparseCsrCUDA: bmm_out_sparse_csr_cuda
- func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@ -4173,7 +4175,7 @@
structured_delegate: mm.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA: _sparse_mm
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
tags: core
@ -6531,6 +6533,7 @@
dispatch:
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: core
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
@ -7111,6 +7114,7 @@
MTIA: addmm_out_mtia
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
SparseMPS: addmm_out_sparse_dense_mps
SparseCsrCPU: addmm_out_sparse_compressed_cpu
SparseCsrCUDA: addmm_out_sparse_compressed_cuda
@ -7120,6 +7124,7 @@
dispatch:
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
SparseMPS: addmm_sparse_dense_mps
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
tags: core
@ -7183,6 +7188,12 @@
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda_v2
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
dispatch:
@ -7378,7 +7389,7 @@
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_mask
SparseCPU, SparseCUDA, SparseMPS: sparse_mask
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
autogen: sparse_mask.out

View File

@ -184,15 +184,23 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
0 & \text{ else }
\end{cases}
*/
float scale_val = scale[0].item<float>();
float inv_scale_val = 1.0f / scale_val;
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point, quant_min, quant_max, false);
TORCH_CHECK(dY.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point.scalar_type() == ScalarType::Float);
TORCH_CHECK(X.numel() == dY.numel(), "`X` and `dY` are not the same size");
bool is_bfloat16 = (X.scalar_type() == at::kBFloat16);
at::Tensor X_ = is_bfloat16 ? X.to(ScalarType::Float) : X;
at::Tensor dY_ = is_bfloat16 ? dY.to(ScalarType::Float) : dY;
at::Tensor scale_ = is_bfloat16 ? scale.to(ScalarType::Float) : scale;
at::Tensor zero_point_ = is_bfloat16 ? zero_point.to(ScalarType::Float) : zero_point;
float scale_val = scale_[0].item<float>();
float inv_scale_val = 1.0f / scale_val;
int64_t zero_point_val = native::_get_zero_point_from_tensor(zero_point_, quant_min, quant_max, false);
TORCH_CHECK(dY_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.scalar_type() == ScalarType::Float);
TORCH_CHECK(scale_.scalar_type() == ScalarType::Float);
TORCH_CHECK(zero_point_.scalar_type() == ScalarType::Float);
TORCH_CHECK(X_.numel() == dY_.numel(), "`X` and `dY` are not the same size");
TORCH_CHECK(
quant_min <= 0 && quant_max >= 0,
"`quant_min` should be less than or \
@ -200,28 +208,28 @@ std::tuple<Tensor, Tensor, Tensor> _fake_quantize_learnable_per_tensor_affine_ba
TORCH_CHECK(
zero_point_val >= quant_min && zero_point_val <= quant_max,
"`zero_point` must be between `quant_min` and `quant_max`.");
if (X.numel() <= 0) {
if (X_.numel() <= 0) {
return std::make_tuple(X, scale, zero_point);
}
auto dX = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X, X.options(), MemoryFormat::Preserve);
auto dX = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dScale_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto dZeroPoint_vec = at::empty_like(X_, X_.options(), MemoryFormat::Preserve);
auto iter = TensorIteratorConfig()
.add_output(dX)
.add_output(dScale_vec)
.add_output(dZeroPoint_vec)
.add_input(X)
.add_input(dY)
.add_input(X_)
.add_input(dY_)
.build();
fake_quant_grad_learnable_tensor_stub(
X.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
X_.device().type(), iter, scale_val, inv_scale_val, zero_point_val, quant_min, quant_max, grad_factor);
// The total sums over the scale and zero point gradient vectors are what will be returned in the end.
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());
auto dScale = dScale_vec.sum().unsqueeze(0).to(scale_.device());
auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point_.device());
return std::make_tuple(dX, dScale, dZeroPoint);
}

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none";
constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

Some files were not shown because too many files have changed in this diff Show More