Update (base update)

[ghstack-poisoned]
This commit is contained in:
Yu, Guangye
2025-08-04 14:49:54 +00:00
parent 147baaa248
commit ef87ca45ae
15 changed files with 331 additions and 74 deletions

View File

@ -38,17 +38,19 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
return "CUFFT_INVALID_SIZE";
case CUFFT_UNALIGNED_DATA:
return "CUFFT_UNALIGNED_DATA";
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_INVALID_DEVICE:
return "CUFFT_INVALID_DEVICE";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
case CUFFT_NO_WORKSPACE:
return "CUFFT_NO_WORKSPACE";
case CUFFT_NOT_IMPLEMENTED:
return "CUFFT_NOT_IMPLEMENTED";
#if !defined(USE_ROCM)
#if CUDA_VERSION <= 12090
case CUFFT_INCOMPLETE_PARAMETER_LIST:
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
case CUFFT_PARSE_ERROR:
return "CUFFT_PARSE_ERROR";
#endif
#if !defined(USE_ROCM) && CUDA_VERSION <= 12090
case CUFFT_LICENSE_ERROR:
return "CUFFT_LICENSE_ERROR";
#endif

View File

@ -5,15 +5,86 @@
namespace c10 {
namespace metal {
namespace detail {
template <typename T>
struct simd_type {
using t = T;
};
// Helper that allows one to run simd ops over bfloat16 by upcasting it to fp32
template <typename T>
using simd_type_t = typename simd_type<T>::t;
#if __METAL_VERSION__ >= 310
template <>
struct simd_type<bfloat> {
using t = float;
};
#endif
} // namespace detail
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
return ::metal::simd_sum(val);
return T(::metal::simd_sum(detail::simd_type_t<T>(val)));
}
template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T val) {
return ::metal::simd_product(val);
return T(::metal::simd_product(detail::simd_type_t<T>(val)));
}
// Extend simd_broadcast to 64-bit integral types using int2 trick
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) == 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), lane_id));
}
template <
typename T,
::metal::enable_if_t<!::metal::is_integral_v<T> || sizeof(T) != 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return ::metal::simd_broadcast(val, lane_id);
}
// Floating simd_min/max with nan propagation
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_max(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_max(detail::simd_type_t<T>(val)));
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_min(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_min(detail::simd_type_t<T>(val)));
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_max(T val) {
return ::metal::simd_max(val);
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_min(T val) {
return ::metal::simd_min(val);
}
// Metal does not support SIMD reductions over 64-bit types, but it could be
@ -28,7 +99,7 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
val += as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}
template <typename T>
@ -37,7 +108,78 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
val *= as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_max(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::max(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}
template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_min(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::min(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}
// argmin/argmax helpers using simd_ballot
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<uint>(static_cast<ulong>(vote))))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<uint>(static_cast<ulong>(vote))))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<uint>(static_cast<ulong>(vote))))};
}
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<uint>(static_cast<ulong>(vote))))};
}
template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmin(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmin(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}
template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmax(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmax(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}
// The algorithms below are written with the hardcoded assumption that the simdgroup size is 32
@ -88,6 +230,44 @@ opmath_t<T> threadgroup_prod(
return data[0];
}
template <typename T>
T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_max(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_max(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
template <typename T>
T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_min(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_min(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}
template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
@ -123,28 +303,6 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
return rc;
}
template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::max(rc, data[idx]);
}
return rc;
}
template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::min(rc, data[idx]);
}
return rc;
}
template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
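The simd_ballot-based argmin/argmax and the two-stage threadgroup_max/min added above can be summarized by the following illustration-only Python models (not part of this commit); the 32-wide simdgroup is the header's own hardcoded assumption.

SIMD_WIDTH = 32

def simd_argmin_model(lane_vals):
    # Model of the ballot trick: reduce to the minimum, build a vote bitmask of
    # the lanes that hold it, and take the lowest set bit as the winning lane.
    rc = min(lane_vals)                                              # simd_min
    vote = sum(1 << i for i, v in enumerate(lane_vals) if v == rc)   # simd_ballot
    return rc, (vote & -vote).bit_length() - 1                       # ctz

def threadgroup_max_model(vals):
    # Model of the new two-stage threadgroup_max: a per-simdgroup reduction into
    # shared memory, then one more simd reduction over the partial results.
    partials = [max(vals[i:i + SIMD_WIDTH]) for i in range(0, len(vals), SIMD_WIDTH)]
    return max(partials)

print(simd_argmin_model([7, 3, 9, 3]))          # (3, 1): value 3, first lane holding it
print(threadgroup_max_model(list(range(100))))  # 99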

View File

@ -330,5 +330,11 @@ inline float log1p(float x) {
return rc;
}
template <typename T1, typename T2 = T1>
struct pair {
T1 first;
T2 second;
};
} // namespace metal
} // namespace c10

View File

@ -120,7 +120,6 @@ dtensor_fails = {
xfail("chunk"),
xfail("combinations"),
xfail("complex"),
xfail("constant_pad_nd"),
xfail("count_nonzero"),
xfail("cross"),
xfail("cummax"),
@ -167,7 +166,6 @@ dtensor_fails = {
xfail("grid_sampler_2d"),
xfail("gradient"),
xfail("heaviside"),
xfail("histc"),
xfail("histogram"),
xfail("histogramdd"),
xfail("index_add"),
@ -244,7 +242,6 @@ dtensor_fails = {
xfail("nanquantile"),
xfail("nansum"),
xfail("native_batch_norm"),
xfail("native_dropout_backward"),
xfail("narrow_copy"),
xfail("ne"),
xfail("new_empty"),
@ -309,10 +306,8 @@ dtensor_fails = {
xfail("nn.functional.mish"),
xfail("nn.functional.mse_loss"),
xfail("nn.functional.multi_margin_loss"),
xfail("nn.functional.multi_head_attention_forward"),
xfail("nn.functional.multilabel_margin_loss"),
xfail("nn.functional.multilabel_soft_margin_loss"),
xfail("nn.functional.pad", "constant"),
xfail("nn.functional.pad", "reflect"),
xfail("nn.functional.pad", "replicate"),
xfail("nn.functional.pad", "replicate_negative"),

View File

@ -454,16 +454,30 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
None
"""
propagator = DTensor._op_dispatcher.sharding_propagator
_origin_op_strategy_funcs = None
_origin_op_strategy_schema = None
try:
# register the op strategy
if op_overload in propagator.op_strategy_funcs:
_origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload]
del propagator.op_strategy_funcs[op_overload]
if op_overload in propagator.op_to_schema_info:
_origin_op_strategy_schema = propagator.op_to_schema_info[op_overload]
del propagator.op_to_schema_info[op_overload]
register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
yield
finally:
# clear this op strategy cache
if op_overload in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[op_overload]
if op_overload in propagator.op_to_schema_info:
del propagator.op_to_schema_info[op_overload]
if _origin_op_strategy_funcs is None:
if op_overload in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[op_overload]
else:
propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs
if _origin_op_strategy_schema is None:
if op_overload in propagator.op_to_schema_info:
del propagator.op_to_schema_info[op_overload]
else:
propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
propagator.propagate_op_sharding.cache.cache_clear()
@ -607,5 +621,28 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(output_dt.placements, [Replicate(), Replicate()])
class TestStrategyHashing(DTensorTestBase):
@with_comms
def test_call_with_different_nontensor_args(self):
mesh = self.build_device_mesh()
global_tensor = torch.tensor(
[
[29.0, 45.0, 3.0, 61.0],
[25.0, 6.0, 21.0, 0.0],
[1.0, 63.0, 49.0, 38.0],
[48.0, 9.0, 55.0, 18.0],
]
)
shard_spec = [Shard(1)]
sharded_dtensor = distribute_tensor(global_tensor, mesh, shard_spec)
with op_strategy_context(torch.ops.aten.sort.default, replicate_op_strategy):
# intentionally do not supply `schema_info=RuntimeSchemaInfo(1)`
torch.sort(sharded_dtensor, dim=0) # sort each column
out1, _ = torch.sort(sharded_dtensor, dim=1) # sort each row
with op_strategy_context(torch.ops.aten.sort.default, replicate_op_strategy):
out2, _ = torch.sort(sharded_dtensor, dim=1)
self.assertEqual(out1.full_tensor(), out2.full_tensor())
if __name__ == "__main__":
run_tests()
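The save/restore added to op_strategy_context above follows a generic pattern; here is a self-contained Python sketch of it (illustration only, not DTensor code): remember what was registered before entering, and on exit either drop the temporary entry or put the original one back so that earlier registrations survive the context.

from contextlib import contextmanager

registry = {}

@contextmanager
def temporary_entry(key, value):
    original = registry.get(key)      # None if nothing was registered before
    registry[key] = value
    try:
        yield
    finally:
        if original is None:
            registry.pop(key, None)   # nothing to restore, just clean up
        else:
            registry[key] = original  # put the pre-existing registration back

registry["aten.sort.default"] = "original strategy"
with temporary_entry("aten.sort.default", "replicate strategy"):
    assert registry["aten.sort.default"] == "replicate strategy"
assert registry["aten.sort.default"] == "original strategy"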

View File

@ -1155,7 +1155,7 @@ class ProcessGroupGlooTest(MultiProcessTestCase):
@requires_gloo()
def test_gather_noncontiguous_input(self):
# Take a column of a 2D tensor, such that its memory is not dense
self._test_gather_basics(lambda t: t.expand(2, 2).contiguous()[:, 0])
self._test_gather_basics(lambda t: t.expand(2, 2).tril().contiguous()[:, 0])
def _test_gather_stress(self, inputs, fn):
store = c10d.FileStore(self.file_name, self.world_size)

View File

@ -130,6 +130,10 @@ class MPSBasicTests(TestCase):
self.common(fn, (torch.eye(64),), check_lowp=False)
def test_reduced_max(self):
# inductor tests do not validate that the max of, say, 16K half elements can be computed
self.common(torch.max, (torch.rand(16384, dtype=torch.half),), check_lowp=False)
class MPSBasicTestsAOTI(TestCase):
def check_model(self, m, inp, dynamic_shapes=None):

View File

@ -12464,7 +12464,7 @@ class TestMetalLibrary(TestCaseMPS):
lib = torch.mps.compile_shader("#include <c10/metal/special_math.h>")
self.assertIsNotNone(lib)
@parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64])
@parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64])
def test_reduction_utils(self, dtype):
from torch._inductor.codegen.mps import DTYPE_TO_METAL
lib = torch.mps.compile_shader(f"""
@ -12474,14 +12474,40 @@ class TestMetalLibrary(TestCaseMPS):
uint idx [[thread_position_in_grid]]) {{
out[idx] = c10::metal::simd_sum(inp[idx]);
}}
kernel void do_max(device {DTYPE_TO_METAL[dtype]}* out0,
device int* out1,
constant {DTYPE_TO_METAL[dtype]}* inp,
uint idx [[thread_position_in_grid]]) {{
auto rc = c10::metal::simd_argmax(inp[idx]);
out0[idx] = rc.first;
out1[idx] = rc.second;
}}
""")
x = torch.testing.make_tensor(28, device="mps", dtype=dtype)
y = torch.empty_like(x)
z0 = torch.empty_like(x)
z1 = torch.empty_like(x, dtype=torch.int32)
lib.do_sum(y, x)
lib.do_max(z0, z1, x)
x_sum = x.sum()
x_max, x_max_idx = x.max(dim=0)
max_err = (y - x_sum).abs().max().item()
self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5,
f"results are {y}, but all elements should have been {x_sum.item()}")
self.assertTrue((z0 == x_max).all().item(),
f"results are {z0}, but all elements should have been {x_max.item()}")
self.assertTrue((z1 == x_max_idx).all().item(),
f"results are {z1}, but all elements should have been {x_max_idx.item()}")
# Test nan propagation
if not dtype.is_floating_point:
return
x[5] = torch.nan
lib.do_max(z0, z1, x)
self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements should have been nan")
self.assertTrue((z1 == 5).all().item(), f"results are {z1}, but all elements should have been 5")
@parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
def test_atomic_add(self, dtype):

View File

@ -818,7 +818,7 @@ class FxGraphHashDetails:
# Global settings affecting matmul codegen.
self.cuda_matmul_settings = (
torch.backends.cuda.matmul.allow_tf32,
torch.backends.cuda.matmul.fp32_precision,
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction,
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction,
)

View File

@ -587,6 +587,7 @@ class MetalKernel(SIMDKernel):
reduction_idx += f"{rd.name} * {acc_buf_size}"
acc_buf_size *= rd.numel
acc_buf_size = min(acc_buf_size, self.max_threadgroup_size)
shmem_buf_size = ceildiv(acc_buf_size, self.simd_group_size)
if reduction_type == "any":
acc = self._new_idxvar(dtype)
@ -610,9 +611,7 @@ class MetalKernel(SIMDKernel):
if reduction_type in ["prod", "sum"]:
acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype]
acc_buf = self._new_idxvar(
acc_dtype, ceildiv(acc_buf_size, self.simd_group_size)
)
acc_buf = self._new_idxvar(acc_dtype, shmem_buf_size)
if not self.multistage_reduction_entry:
val = value
else:
@ -628,7 +627,27 @@ class MetalKernel(SIMDKernel):
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})",
dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
)
if reduction_type in ["max", "min", "argmin", "argmax"]:
if reduction_type in ["max", "min"]:
acc_buf = self._new_idxvar(src_dtype, shmem_buf_size)
src_metal_type = DTYPE_TO_METAL[src_dtype]
cast_value = f"static_cast<{src_metal_type}>({value})"
if not self.multistage_reduction_entry:
val = cast_value # type: ignore[assignment]
else:
lim_fn = "lowest" if reduction_type.endswith("max") else "max"
limit_val = f"::metal::numeric_limits<{src_metal_type}>::{lim_fn}()"
val = self._new_idxvar(
src_dtype, default_value=limit_val, is_threadgroup=False
)
self.compute.splice(
f"{val} = ::c10::metal::{reduction_type}({val}, {cast_value});"
)
return self.cse.generate(
self.stores,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})",
dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype],
)
if reduction_type in ["argmin", "argmax"]:
acc_buf = self._new_idxvar(src_dtype, acc_buf_size)
acc_thread_var = f"{acc_buf}[{reduction_idx}]"
src_metal_type = DTYPE_TO_METAL[src_dtype]
@ -645,31 +664,20 @@ class MetalKernel(SIMDKernel):
self.indexing_code.writeline(
f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();"
)
if reduction_type.startswith("arg"):
idx_var = next(
t for t in self.range_tree_nodes.values() if t.is_reduction
)
idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size)
cmp_op = ">" if reduction_type == "argmax" else "<"
idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]"
self.indexing_code.splice(f"{idx_thread_var} = -1;")
self.compute.splice(f"""
if ({value} {cmp_op} {acc_thread_var}) {{
{acc_thread_var} = {value};
{idx_thread_var} = {idx_var.name};
}}
""")
return self.cse.generate(
self.stores,
f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]",
dtype=dtype,
)
self.compute.writeline(
f"{acc_thread_var} = ::c10::metal::{reduction_type}({acc_thread_var}, {value});"
)
idx_var = next(t for t in self.range_tree_nodes.values() if t.is_reduction)
idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size)
cmp_op = ">" if reduction_type == "argmax" else "<"
idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]"
self.indexing_code.splice(f"{idx_thread_var} = -1;")
self.compute.splice(f"""
if ({value} {cmp_op} {acc_thread_var}) {{
{acc_thread_var} = {value};
{idx_thread_var} = {idx_var.name};
}}
""")
return self.cse.generate(
self.stores,
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]",
dtype=dtype,
)
if reduction_type == "welford_reduce":

View File

@ -1775,7 +1775,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
}
// Set single input tensor on all processes.
GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]);
at::Tensor flatInputTensor = flattenDenseTensors(inputs[0]);
GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
gloo::gather(opts);
// Unflatten into output tensors on root process.

View File

@ -140,6 +140,8 @@ PyObject* dynamo__custom_eval_frame(
auto fail = [&]() { clear_old_frame_if_python_312_plus(tstate, frame); };
#if IS_PYTHON_3_12_PLUS
// skip tracing the frame if CPython is in a tracing state (e.g.
// sys.monitoring call)
if (tstate->tracing > 0) {
eval_default();
return eval_result;
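For context, a hedged illustration (not part of this commit) of the situation the comment above guards against on Python 3.12+: a torch.compile-d function invoked from inside a trace callback runs while CPython's per-thread tracing counter is non-zero, so dynamo now falls back to the default frame evaluation instead of attempting to trace it.

import sys
import torch

@torch.compile
def compiled_add(x):
    return x + 1

def tracer(frame, event, arg):
    if event == "call" and frame.f_code.co_name == "plain_call":
        compiled_add(torch.ones(2))   # evaluated while tstate->tracing > 0
    return tracer

def plain_call():
    pass

sys.settrace(tracer)
plain_call()
sys.settrace(None)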

View File

@ -172,7 +172,9 @@ class OpDispatcher:
# on args first, which could potentially modify args (i.e. allgather certain arg)
assert output_sharding.redistribute_schema is not None
self.redistribute_local_args(
op_info, output_sharding.redistribute_schema
op_info,
output_sharding.redistribute_schema,
output_sharding.use_val_from_redistribute_schema,
)
local_tensor_args = (
@ -293,6 +295,7 @@ class OpDispatcher:
def redistribute_local_args(
op_info: OpInfo,
suggested_input_schema: OpSchema,
use_val_from_redistribute_schema: bool,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
if op_info.args_tree_spec is not None:
@ -315,7 +318,12 @@ class OpDispatcher:
else:
new_local_args.append(local_tensor)
else:
new_local_args.append(reshard_arg_spec)
if use_val_from_redistribute_schema:
# args can be updated for view-related ops; in that case we take the
# updated values from redistribute_schema.
new_local_args.append(reshard_arg_spec)
else:
new_local_args.append(arg_spec)
op_info.local_args = tuple(new_local_args)
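A stripped-down model (illustration only, not DTensor code) of how the new use_val_from_redistribute_schema flag changes argument selection above: non-tensor args are taken from the suggested schema only when sharding propagation says their values were rewritten (e.g. for view-like ops); otherwise the caller's original values are kept.

def pick_local_args(original_args, suggested_args, is_tensor_arg, use_val_from_redistribute_schema):
    new_local_args = []
    for orig, suggested, is_tensor in zip(original_args, suggested_args, is_tensor_arg):
        if is_tensor:
            new_local_args.append(suggested)   # tensor args follow the (resharded) suggestion
        elif use_val_from_redistribute_schema:
            new_local_args.append(suggested)   # e.g. updated sizes for a view-like op
        else:
            new_local_args.append(orig)        # keep the value the caller passed in
    return tuple(new_local_args)

print(pick_local_args((3, "t"), (4, "t'"), (False, True), use_val_from_redistribute_schema=False))  # (3, "t'")
print(pick_local_args((3, "t"), (4, "t'"), (False, True), use_val_from_redistribute_schema=True))   # (4, "t'")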

View File

@ -559,9 +559,14 @@ class OutputSharding:
exactly the same as the operator OpSchema, except the DTensorSpecs
"""
# specifies the output sharding pattern
output_spec: OutputSpecType
# schema for redistribution if needed
redistribute_schema: Optional[OpSchema] = None
# flag indicating if inputs need redistribution
needs_redistribute: bool = False
# flag to use values from `redistribute_schema`
use_val_from_redistribute_schema: bool = False
@cached_property
def mesh(self):

View File

@ -303,7 +303,7 @@ class ShardingPropagator:
# because SymInts are not hashable.
# This is generally ok because this only happens during tracing in torch.compile,
# and compile autograd initial tracing, which do not need to be as fast as
# eagermode DTensor usages.
# eager mode DTensor usages.
if _are_we_tracing():
output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
else:
@ -336,6 +336,8 @@ class ShardingPropagator:
# check if we need to redistribute the input
needs_redistribute = False
# check if we want to use args value from redistribute_schema
use_val_from_redistribute_schema = False
expected_input_specs: list[DTensorSpec] = []
# in case where the op does not specify input_specs and output_specs
@ -378,6 +380,7 @@ class ShardingPropagator:
out_tensor_meta, schema, output_strategy.output_spec
)
needs_redistribute = True
use_val_from_redistribute_schema = True
# construct output spec for the op
if op_schema.return_type_tuple_tensor_like():
@ -410,6 +413,7 @@ class ShardingPropagator:
output_specs,
suggestion_schema,
needs_redistribute=needs_redistribute,
use_val_from_redistribute_schema=use_val_from_redistribute_schema,
)
elif isinstance(op_strategy, TupleStrategy):
# tuple strategy output sharding processing
@ -478,6 +482,7 @@ class ShardingPropagator:
tuple(out_spec_list) if out_tensor_meta is not None else None,
suggestion_schema,
needs_redistribute=needs_redistribute,
use_val_from_redistribute_schema=False,
)
else:
raise ValueError("Unsupported op strategy type")