Mirror of https://github.com/pytorch/pytorch.git
Update (base update)
[ghstack-poisoned]
@@ -38,17 +38,19 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
      return "CUFFT_INVALID_SIZE";
    case CUFFT_UNALIGNED_DATA:
      return "CUFFT_UNALIGNED_DATA";
    case CUFFT_INCOMPLETE_PARAMETER_LIST:
      return "CUFFT_INCOMPLETE_PARAMETER_LIST";
    case CUFFT_INVALID_DEVICE:
      return "CUFFT_INVALID_DEVICE";
    case CUFFT_PARSE_ERROR:
      return "CUFFT_PARSE_ERROR";
    case CUFFT_NO_WORKSPACE:
      return "CUFFT_NO_WORKSPACE";
    case CUFFT_NOT_IMPLEMENTED:
      return "CUFFT_NOT_IMPLEMENTED";
#if !defined(USE_ROCM)
#if CUDA_VERSION <= 12090
    case CUFFT_INCOMPLETE_PARAMETER_LIST:
      return "CUFFT_INCOMPLETE_PARAMETER_LIST";
    case CUFFT_PARSE_ERROR:
      return "CUFFT_PARSE_ERROR";
#endif
#if !defined(USE_ROCM) && CUDA_VERSION <= 12090
    case CUFFT_LICENSE_ERROR:
      return "CUFFT_LICENSE_ERROR";
#endif
@@ -5,15 +5,86 @@

namespace c10 {
namespace metal {
namespace detail {
template <typename T>
struct simd_type {
  using t = T;
};

// Helper that allows one to run simd ops over bfloat16 by upcasting them to fp32
template <typename T>
using simd_type_t = typename simd_type<T>::t;

#if __METAL_VERSION__ >= 310
template <>
struct simd_type<bfloat> {
  using t = float;
};
#endif
} // namespace detail

template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
  return ::metal::simd_sum(val);
  return T(::metal::simd_sum(detail::simd_type_t<T>(val)));
}

template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T val) {
  return ::metal::simd_product(val);
  return T(::metal::simd_product(detail::simd_type_t<T>(val)));
}

// Extend simd_broadcast to 64-bit integral types using int2 trick
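// The int2 trick: simd_broadcast has no 64-bit integer overload, so the value
// is reinterpreted as two 32-bit lanes with as_type<int2>, broadcast as an
// int2, and reinterpreted back to the original type; the bit pattern is
// preserved throughout, only the type of the view changes.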
template <
    typename T,
    ::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) == 8, bool> =
        true>
inline T simd_broadcast(T val, ushort lane_id) {
  return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), lane_id));
}

template <
    typename T,
    ::metal::enable_if_t<!::metal::is_integral_v<T> || sizeof(T) != 8, bool> =
        true>
inline T simd_broadcast(T val, ushort lane_id) {
  return ::metal::simd_broadcast(val, lane_id);
}

// Floating simd_min/max with nan propagation
template <
    typename T,
    ::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_max(T val) {
  if (::metal::simd_any(::metal::isnan(val))) {
    return ::metal::numeric_limits<T>::quiet_NaN();
  }
  return T(::metal::simd_max(detail::simd_type_t<T>(val)));
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_min(T val) {
  if (::metal::simd_any(::metal::isnan(val))) {
    return ::metal::numeric_limits<T>::quiet_NaN();
  }
  return T(::metal::simd_min(detail::simd_type_t<T>(val)));
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
        true>
inline T simd_max(T val) {
  return ::metal::simd_max(val);
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
        true>
inline T simd_min(T val) {
  return ::metal::simd_min(val);
}

// Metal does not support SIMD reductions over 64-bit types, but it could be
@@ -28,7 +99,7 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
    val += as_type<T>(
        ::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
  }
  return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
  return simd_broadcast(val, 0);
}

template <typename T>
@@ -37,7 +108,78 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
    val *= as_type<T>(
        ::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
  }
  return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
  return simd_broadcast(val, 0);
}

template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_max(T val) {
  for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
    val = ::metal::max(
        val,
        as_type<T>(::metal::simd_shuffle_and_fill_down(
            as_type<int2>(val), int2(0), i)));
  }
  return simd_broadcast(val, 0);
}

template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_min(T val) {
  for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
    val = ::metal::min(
        val,
        as_type<T>(::metal::simd_shuffle_and_fill_down(
            as_type<int2>(val), int2(0), i)));
  }
  return simd_broadcast(val, 0);
}

// argmin/argmax helpers using simd_ballot
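// simd_ballot yields a vote mask with bit i set when lane i's predicate holds,
// so ctz of the mask is the lowest lane that attained the extremum and ties
// resolve to the smallest lane index. For floating types the NaN-propagating
// simd_min/simd_max return NaN whenever any lane holds NaN, and the
// `|| isnan(val)` term then makes the first NaN lane win: e.g. for lanes
// {2.0, NaN, 1.0}, simd_argmax returns {NaN, 1}.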
template <
    typename T,
    ::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
  const auto rc = simd_min(val);
  const auto vote = ::metal::simd_ballot(val == rc);
  return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
  const auto rc = simd_min(val);
  const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
  return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
  const auto rc = simd_max(val);
  const auto vote = ::metal::simd_ballot(val == rc);
  return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
}

template <
    typename T,
    ::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
  const auto rc = simd_max(val);
  const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
  return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
}

template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmin(ARG_T val, IDX_T idx_val) {
  auto rc = simd_argmin(val);
  return {rc.first, simd_broadcast(idx_val, rc.second)};
}

template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmax(ARG_T val, IDX_T idx_val) {
  auto rc = simd_argmax(val);
  return {rc.first, simd_broadcast(idx_val, rc.second)};
}

// Below algorithms are written with hardcoded assumption that simdgroup is 32
@@ -88,6 +230,44 @@ opmath_t<T> threadgroup_prod(
  return data[0];
}

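// Two-stage reduction for threadgroup max/min below: each simdgroup reduces
// its own values, lane 0 of every simdgroup writes the partial into `data`,
// and when more than one simdgroup participates the first
// ceil(size / simdgroup_size) threads reduce those partials. This assumes the
// partial count fits in a single simdgroup, i.e. size <= simdgroup_size^2
// (1024 threads for a 32-wide simdgroup).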
template <typename T>
T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) {
  auto rc = simd_max(val);
  if (idx % simdgroup_size == 0) {
    data[idx / simdgroup_size] = rc;
  }
  if (size > simdgroup_size) {
    ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
    if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
      auto rc1 = simd_max(data[idx]);
      if (idx == 0) {
        data[0] = rc1;
      }
    }
  }
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  return data[0];
}

template <typename T>
T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) {
  auto rc = simd_min(val);
  if (idx % simdgroup_size == 0) {
    data[idx / simdgroup_size] = rc;
  }
  if (size > simdgroup_size) {
    ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
    if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
      auto rc1 = simd_min(data[idx]);
      if (idx == 0) {
        data[0] = rc1;
      }
    }
  }
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  return data[0];
}

template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
@@ -123,28 +303,6 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
  return rc;
}

template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  T rc = data[0];
  for (unsigned idx = 1; idx < size; ++idx) {
    rc = ::c10::metal::max(rc, data[idx]);
  }
  return rc;
}

template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
  T rc = data[0];
  for (unsigned idx = 1; idx < size; ++idx) {
    rc = ::c10::metal::min(rc, data[idx]);
  }
  return rc;
}

template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
  // TODO: This should be moved to the callee
@@ -330,5 +330,11 @@ inline float log1p(float x) {
  return rc;
}

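// Minimal tuple-like aggregate; the simd_argmin/simd_argmax helpers return it
// to pair a reduced value with the lane (or index) it came from.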
template <typename T1, typename T2 = T1>
struct pair {
  T1 first;
  T2 second;
};

} // namespace metal
} // namespace c10

@@ -120,7 +120,6 @@ dtensor_fails = {
    xfail("chunk"),
    xfail("combinations"),
    xfail("complex"),
    xfail("constant_pad_nd"),
    xfail("count_nonzero"),
    xfail("cross"),
    xfail("cummax"),
@@ -167,7 +166,6 @@ dtensor_fails = {
    xfail("grid_sampler_2d"),
    xfail("gradient"),
    xfail("heaviside"),
    xfail("histc"),
    xfail("histogram"),
    xfail("histogramdd"),
    xfail("index_add"),
@@ -244,7 +242,6 @@ dtensor_fails = {
    xfail("nanquantile"),
    xfail("nansum"),
    xfail("native_batch_norm"),
    xfail("native_dropout_backward"),
    xfail("narrow_copy"),
    xfail("ne"),
    xfail("new_empty"),
@@ -309,10 +306,8 @@ dtensor_fails = {
    xfail("nn.functional.mish"),
    xfail("nn.functional.mse_loss"),
    xfail("nn.functional.multi_margin_loss"),
    xfail("nn.functional.multi_head_attention_forward"),
    xfail("nn.functional.multilabel_margin_loss"),
    xfail("nn.functional.multilabel_soft_margin_loss"),
    xfail("nn.functional.pad", "constant"),
    xfail("nn.functional.pad", "reflect"),
    xfail("nn.functional.pad", "replicate"),
    xfail("nn.functional.pad", "replicate_negative"),

@@ -454,16 +454,30 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
        None
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
    _origin_op_strategy_funcs = None
    _origin_op_strategy_schema = None
    try:
        # register the op strategy
        if op_overload in propagator.op_strategy_funcs:
            _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
            del propagator.op_strategy_funcs[op_overload]
        if op_overload in propagator.op_to_schema_info:
            _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
            del propagator.op_to_schema_info[op_overload]
        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
        yield
    finally:
        # clear this op strategy cache
        if op_overload in propagator.op_strategy_funcs:
            del propagator.op_strategy_funcs[op_overload]
        if op_overload in propagator.op_to_schema_info:
            del propagator.op_to_schema_info[op_overload]
        if _origin_op_strategy_funcs is None:
            if op_overload in propagator.op_strategy_funcs:
                del propagator.op_strategy_funcs[op_overload]
        else:
            propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
        if _origin_op_strategy_schema is None:
            if op_overload in propagator.op_to_schema_info:
                del propagator.op_to_schema_info[op_overload]
        else:
            propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
        propagator.propagate_op_sharding.cache.cache_clear()

@@ -607,5 +621,28 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
        self.assertEqual(output_dt.placements, [Replicate(), Replicate()])


class TestStrategyHashing(DTensorTestBase):
    @with_comms
    def test_call_with_different_nontensor_args(self):
        mesh = self.build_device_mesh()
        global_tensor = torch.tensor(
            [
                [29.0, 45.0, 3.0, 61.0],
                [25.0, 6.0, 21.0, 0.0],
                [1.0, 63.0, 49.0, 38.0],
                [48.0, 9.0, 55.0, 18.0],
            ]
        )
        shard_spec = [Shard(1)]
        sharded_dtensor = distribute_tensor(global_tensor, mesh, shard_spec)
        with op_strategy_context(torch.ops.aten.sort.default, replicate_op_strategy):
            # intentionally do not supply `schema_info=RuntimeSchemaInfo(1)`
            torch.sort(sharded_dtensor, dim=0)  # sort each column
            out1, _ = torch.sort(sharded_dtensor, dim=1)  # sort each row
        with op_strategy_context(torch.ops.aten.sort.default, replicate_op_strategy):
            out2, _ = torch.sort(sharded_dtensor, dim=1)
        self.assertEqual(out1.full_tensor(), out2.full_tensor())


if __name__ == "__main__":
    run_tests()

@@ -1155,7 +1155,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
    @requires_gloo()
    def test_gather_noncontiguous_input(self):
        # Take a column of a 2D tensor, such that memory is not dense
        self._test_gather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])
        self._test_gather_basics(lambda t: t.expand(2, 2).tril().contiguous()[:, 0])

    def _test_gather_stress(self, inputs, fn):
        store = c10d.FileStore(self.file_name, self.world_size)

@@ -130,6 +130,10 @@ class MPSBasicTests(TestCase):

        self.common(fn, (torch.eye(64),), check_lowp=False)

    def test_reduced_max(self):
        # inductor tests do not validate that the max of, say, 16K half elements can be computed
        self.common(torch.max, (torch.rand(16384, dtype=torch.half),), check_lowp=False)


class MPSBasicTestsAOTI(TestCase):
    def check_model(self, m, inp, dynamic_shapes=None):

@@ -12464,7 +12464,7 @@ class TestMetalLibrary(TestCaseMPS):
        lib = torch.mps.compile_shader("#include <c10/metal/special_math.h>")
        self.assertIsNotNone(lib)

    @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64])
    @parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64])
    def test_reduction_utils(self, dtype):
        from torch._inductor.codegen.mps import DTYPE_TO_METAL
        lib = torch.mps.compile_shader(f"""
@@ -12474,14 +12474,40 @@ class TestMetalLibrary(TestCaseMPS):
                           uint idx [[thread_position_in_grid]]) {{
          out[idx] = c10::metal::simd_sum(inp[idx]);
        }}

        kernel void do_max(device {DTYPE_TO_METAL[dtype]}* out0,
                           device int* out1,
                           constant {DTYPE_TO_METAL[dtype]}* inp,
                           uint idx [[thread_position_in_grid]]) {{
          auto rc = c10::metal::simd_argmax(inp[idx]);
          out0[idx] = rc.first;
          out1[idx] = rc.second;
        }}

        """)
        x = torch.testing.make_tensor(28, device="mps", dtype=dtype)
        y = torch.empty_like(x)
        z0 = torch.empty_like(x)
        z1 = torch.empty_like(x, dtype=torch.int32)
        lib.do_sum(y, x)
        lib.do_max(z0, z1, x)
        x_sum = x.sum()
        x_max, x_max_idx = x.max(dim=0)
        max_err = (y - x_sum).abs().max().item()
        self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5,
                        f"results are {y}, but all elements should have been {x_sum.item()}")
        self.assertTrue((z0 == x_max).all().item(),
                        f"results are {z0}, but all elements should have been {x_max.item()}")
        self.assertTrue((z1 == x_max_idx).all().item(),
                        f"results are {z1}, but all elements should have been {x_max_idx.item()}")
        # Test nan propagation
        if not dtype.is_floating_point:
            return

        x[5] = torch.nan
        lib.do_max(z0, z1, x)
        self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements should have been nan")
        self.assertTrue((z1 == 5).all().item(), f"results are {z1}, but all elements should have been 5")

    @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
    def test_atomic_add(self, dtype):

@@ -818,7 +818,7 @@ class FxGraphHashDetails:

        # Global settings affecting matmul codegen.
        self.cuda_matmul_settings = (
            torch.backends.cuda.matmul.allow_tf32,
            torch.backends.cuda.matmul.fp32_precision,
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction,
        )

@@ -587,6 +587,7 @@ class MetalKernel(SIMDKernel):
            reduction_idx += f"{rd.name} * {acc_buf_size}"
            acc_buf_size *= rd.numel
        acc_buf_size = min(acc_buf_size, self.max_threadgroup_size)
        shmem_buf_size = ceildiv(acc_buf_size, self.simd_group_size)

        if reduction_type == "any":
            acc = self._new_idxvar(dtype)
@@ -610,9 +611,7 @@ class MetalKernel(SIMDKernel):

        if reduction_type in ["prod", "sum"]:
            acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype]
            acc_buf = self._new_idxvar(
                acc_dtype, ceildiv(acc_buf_size, self.simd_group_size)
            )
            acc_buf = self._new_idxvar(acc_dtype, shmem_buf_size)
            if not self.multistage_reduction_entry:
                val = value
            else:
@@ -628,7 +627,27 @@ class MetalKernel(SIMDKernel):
                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})",
                dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
            )
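        # For max/min the shared-memory accumulator holds one partial per
        # simdgroup; in the multistage case each thread first folds its values
        # into a per-thread register seeded with the dtype's identity
        # (lowest() for max, max() for min) before the final threadgroup
        # reduction.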
        if reduction_type in ["max", "min", "argmin", "argmax"]:
        if reduction_type in ["max", "min"]:
            acc_buf = self._new_idxvar(src_dtype, shmem_buf_size)
            src_metal_type = DTYPE_TO_METAL[src_dtype]
            cast_value = f"static_cast<{src_metal_type}>({value})"
            if not self.multistage_reduction_entry:
                val = cast_value  # type: ignore[assignment]
            else:
                lim_fn = "lowest" if reduction_type.endswith("max") else "max"
                limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()"
                val = self._new_idxvar(
                    src_dtype, default_value=limit_val, is_threadgroup=False
                )
                self.compute.splice(
                    f"{val} = ::c10::metal::{reduction_type}({val}, {cast_value});"
                )
            return self.cse.generate(
                self.stores,
                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})",
                dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
            )
        if reduction_type in ["argmin", "argmax"]:
            acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
            acc_thread_var = f"{acc_buf}[{reduction_idx}]"
            src_metal_type = DTYPE_TO_METAL[src_dtype]
@@ -645,31 +664,20 @@ class MetalKernel(SIMDKernel):
            self.indexing_code.writeline(
                f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();"
            )
            if reduction_type.startswith("arg"):
                idx_var = next(
                    t for t in self.range_tree_nodes.values() if t.is_reduction
                )
                idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size)
                cmp_op = ">" if reduction_type == "argmax" else "<"
                idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]"
                self.indexing_code.splice(f"{idx_thread_var} = -1;")
                self.compute.splice(f"""
                if ({value} {cmp_op} {acc_thread_var}) {{
                    {acc_thread_var} = {value};
                    {idx_thread_var} = {idx_var.name};
                }}
                """)
                return self.cse.generate(
                    self.stores,
                    f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]",
                    dtype=dtype,
                )
            self.compute.writeline(
                f"{acc_thread_var} = ::c10::metal::{reduction_type}({acc_thread_var}, {value});"
            )
            idx_var = next(t for t in self.range_tree_nodes.values() if t.is_reduction)
            idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size)
            cmp_op = ">" if reduction_type == "argmax" else "<"
            idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]"
            self.indexing_code.splice(f"{idx_thread_var} = -1;")
            self.compute.splice(f"""
            if ({value} {cmp_op} {acc_thread_var}) {{
                {acc_thread_var} = {value};
                {idx_thread_var} = {idx_var.name};
            }}
            """)
            return self.cse.generate(
                self.stores,
                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
                f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]",
                dtype=dtype,
            )
        if reduction_type == "welford_reduce":

@@ -1775,7 +1775,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
  }

  // Set single input tensor on all processes.
  GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]);
  at::Tensor flatInputTensor = flattenDenseTensors(inputs[0]);
  GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
  gloo::gather(opts);

  // Unflatten into output tensors on root process.

@@ -140,6 +140,8 @@ PyObject* dynamo__custom_eval_frame(
  auto fail = [&]() { clear_old_frame_if_python_312_plus(tstate, frame); };

#if IS_PYTHON_3_12_PLUS
  // skip tracing the frame if CPython is in a tracing state (e.g.
  // sys.monitoring call)
  if (tstate->tracing > 0) {
    eval_default();
    return eval_result;

@@ -172,7 +172,9 @@ class OpDispatcher:
                # on args first, which could potentially modify args (i.e. allgather certain arg)
                assert output_sharding.redistribute_schema is not None
                self.redistribute_local_args(
                    op_info, output_sharding.redistribute_schema
                    op_info,
                    output_sharding.redistribute_schema,
                    output_sharding.use_val_from_redistribute_schema,
                )

            local_tensor_args = (
@@ -293,6 +295,7 @@ class OpDispatcher:
    def redistribute_local_args(
        op_info: OpInfo,
        suggested_input_schema: OpSchema,
        use_val_from_redistribute_schema: bool,
    ) -> None:
        # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
        if op_info.args_tree_spec is not None:
@@ -315,7 +318,12 @@ class OpDispatcher:
                else:
                    new_local_args.append(local_tensor)
            else:
                new_local_args.append(reshard_arg_spec)
                if use_val_from_redistribute_schema:
                    # args can be updated for view-related ops; we take the
                    # updated values from redistribute_schema.
                    new_local_args.append(reshard_arg_spec)
                else:
                    new_local_args.append(arg_spec)

        op_info.local_args = tuple(new_local_args)

@@ -559,9 +559,14 @@ class OutputSharding:
    exactly the same as the operator OpSchema, except the DTensorSpecs
    """

    # specifies the output sharding pattern
    output_spec: OutputSpecType
    # schema for redistribution if needed
    redistribute_schema: Optional[OpSchema] = None
    # flag indicating if inputs need redistribution
    needs_redistribute: bool = False
    # flag to use values from `redistribute_schema`
    use_val_from_redistribute_schema: bool = False

    @cached_property
    def mesh(self):

@@ -303,7 +303,7 @@ class ShardingPropagator:
        # because SymInts are not hashable.
        # This is generally ok because this only happens during tracing in torch.compile,
        # and compile autograd initial tracing, which do not need to be as fast as
        # eagermode DTensor usages.
        # eager mode DTensor usages.
        if _are_we_tracing():
            output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
        else:
@@ -336,6 +336,8 @@ class ShardingPropagator:

        # check if we need to redistribute the input
        needs_redistribute = False
        # check if we want to use arg values from redistribute_schema
        use_val_from_redistribute_schema = False
        expected_input_specs: list[DTensorSpec] = []

        # in case where the op does not specify input_specs and output_specs
@@ -378,6 +380,7 @@ class ShardingPropagator:
                    out_tensor_meta, schema, output_strategy.output_spec
                )
                needs_redistribute = True
                use_val_from_redistribute_schema = True

            # construct output spec for the op
            if op_schema.return_type_tuple_tensor_like():
@@ -410,6 +413,7 @@ class ShardingPropagator:
                output_specs,
                suggestion_schema,
                needs_redistribute=needs_redistribute,
                use_val_from_redistribute_schema=use_val_from_redistribute_schema,
            )
        elif isinstance(op_strategy, TupleStrategy):
            # tuple strategy output sharding processing
@@ -478,6 +482,7 @@ class ShardingPropagator:
                tuple(out_spec_list) if out_tensor_meta is not None else None,
                suggestion_schema,
                needs_redistribute=needs_redistribute,
                use_val_from_redistribute_schema=False,
            )
        else:
            raise ValueError("Unsupported op strategy type")