As per title. Fixes https://github.com/pytorch/pytorch/issues/145731

Touches only compute; the CPU overhead can potentially be reduced further.

Before:

```python
In [3]: n = 512

In [4]: A = torch.rand(n, n)

In [5]: B = torch.rand(n, n)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
2.04 ms ± 70 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [7]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
2.85 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
2.9 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [9]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
4.07 ms ± 262 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

After:

```python
In [2]: n = 512

In [3]: A = torch.rand(n, n)

In [4]: B = torch.rand(n, n)

In [5]: %timeit torch.tensordot(A, B, [[0, 1], [0, 1]])
30.7 µs ± 2.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [6]: %timeit torch.tensordot(A, B, [[0, 1], [1, 0]])
141 µs ± 6.52 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [7]: %timeit torch.tensordot(A, B, [[1, 0], [0, 1]])
142 µs ± 4.03 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [8]: %timeit torch.tensordot(A, B, [[1, 0], [1, 0]])
62.8 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145936
Approved by: https://github.com/albanD, https://github.com/ngimel
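For intuition, here is a minimal sketch (a hypothetical helper, not the code in this PR) of why a full contraction of two 2-D operands can bypass the general permute+reshape+matmul path: each of the four dim orderings benchmarked above reduces to an elementwise multiply followed by a sum.

```python
import torch

def tensordot_2d_full_contraction(A, B, dims):
    # Hypothetical sketch only: fully contracting two matrices yields a
    # scalar, so no permute/reshape/matmul round trip is required.
    if dims in ([[0, 1], [0, 1]], [[1, 0], [1, 0]]):
        return (A * B).sum()    # sum_ij A[i, j] * B[i, j]
    if dims in ([[0, 1], [1, 0]], [[1, 0], [0, 1]]):
        return (A * B.T).sum()  # sum_ij A[i, j] * B[j, i]
    return torch.tensordot(A, B, dims)  # fall back to the general path

A = torch.rand(512, 512, dtype=torch.float64)
B = torch.rand(512, 512, dtype=torch.float64)
for dims in ([[0, 1], [0, 1]], [[0, 1], [1, 0]], [[1, 0], [0, 1]], [[1, 0], [1, 0]]):
    assert torch.allclose(
        tensordot_2d_full_contraction(A, B, dims), torch.tensordot(A, B, dims)
    )
```

The semantics of `torch.tensordot` are unchanged; the sketch only illustrates why these fully-contracted cases need far less data movement than the general path.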
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import unittest
import warnings

import torch
import torch.distributed as dist
import torch.testing._internal.common_methods_invocations as common_ops
from torch.distributed.tensor import DeviceMesh, DTensor
from torch.overrides import resolve_name
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_utils import run_tests, suppress_warnings
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorConverter,
    DTensorOpTestBase,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map


# rewrite common size variables to something that can be sharded evenly;
# we can enable uneven shards later, but that requires more adjustment to
# the sample inputs (i.e. view/reshape would need to adjust shape sizes as well)
common_ops.L = 24
common_ops.M = 12
common_ops.S = 4
common_ops.XS = 2
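# e.g., with the 4-rank world size used below (OP_DB_WORLD_SIZE), a dim of
# size L = 24 shards evenly into 24 / 4 = 6 elements per rank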


# Copied from functorch
def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, True)


def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    return (op_name, variant_name, device_type, dtypes, False)


def skipOps(test_case_name, base_test_name, to_skip):
    all_opinfos = op_db
    for xfail in to_skip:
        op_name, variant_name, device_type, dtypes, expected_failure = xfail
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == op_name and o.variant_test_name == variant_name
        ]
        assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}"
        for opinfo in matching_opinfos:
            decorators = list(opinfo.decorators)
            if expected_failure:
                decorator = DecorateInfo(
                    unittest.expectedFailure,
                    test_case_name,
                    base_test_name,
                    device_type=device_type,
                    dtypes=dtypes,
                )
                decorators.append(decorator)
            else:
                decorator = DecorateInfo(
                    unittest.skip("Skipped!"),
                    test_case_name,
                    base_test_name,
                    device_type=device_type,
                    dtypes=dtypes,
                )
                decorators.append(decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped
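

# Note that skipOps mutates the decorators of the matching OpInfo entries in
# op_db as a side effect; the decorator it returns is a no-op. It is applied
# below as:
#
#     @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
#     def test_dtensor_op_db(self, dtype, op): ...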


# To re-generate this failed list, turn on the dry_run flag of the function
# below, i.e. check_dtensor_func(self, test, op, dry_run=True), then run
# something like: python test/distributed/tensor/test_dtensor_ops.py > failed.expect
dtensor_fails = {
    # these sometimes pass and sometimes fail
    # we need to remove many of them from this list once ops
    # get full support with varying sharding specs
xfail("__getitem__"),
|
|
xfail("__rsub__"),
|
|
xfail("_chunk_cat"),
|
|
xfail("_native_batch_norm_legit"),
|
|
xfail("_upsample_bilinear2d_aa"),
|
|
xfail("addbmm"),
|
|
xfail("addmv"),
|
|
xfail("addr"),
|
|
xfail("all"),
|
|
xfail("allclose"),
|
|
xfail("alias_copy"),
|
|
xfail("aminmax"),
|
|
xfail("any"),
|
|
xfail("arange"),
|
|
xfail("argmax"),
|
|
xfail("argmin"),
|
|
xfail("argsort"),
|
|
xfail("as_strided"),
|
|
xfail("as_strided", "partial_views"),
|
|
xfail("as_strided_copy"),
|
|
xfail("as_strided_scatter"),
|
|
xfail("bernoulli"),
|
|
xfail("_batch_norm_with_update"),
|
|
xfail("block_diag"),
|
|
xfail("broadcast_shapes"),
|
|
xfail("cartesian_prod"),
|
|
xfail("cauchy"),
|
|
xfail("cdist"),
|
|
xfail("cholesky"),
|
|
xfail("cholesky_inverse"),
|
|
xfail("cholesky_solve"),
|
|
xfail("chunk"),
|
|
xfail("clamp"),
|
|
xfail("clamp_max"),
|
|
xfail("clamp_min"),
|
|
xfail("combinations"),
|
|
xfail("complex"),
|
|
xfail("constant_pad_nd"),
|
|
xfail("count_nonzero"),
|
|
xfail("cross"),
|
|
xfail("cummax"),
|
|
xfail("cummin"),
|
|
xfail("diagonal_scatter"),
|
|
xfail("dist"),
|
|
xfail("empty"),
|
|
xfail("empty_strided"),
|
|
xfail("empty_like"),
|
|
xfail("empty_permuted"),
|
|
xfail("expand_copy"),
|
|
xfail("exponential"),
|
|
xfail("equal"),
|
|
xfail("eye"),
|
|
xfail("fft.fft2"),
|
|
xfail("fft.fft"),
|
|
xfail("fft.fftn"),
|
|
xfail("fft.fftshift"),
|
|
xfail("fft.ifft2"),
|
|
xfail("fft.ifft"),
|
|
xfail("fft.ifftshift"),
|
|
xfail("fft.ihfft2"),
|
|
xfail("fft.ihfft"),
|
|
xfail("fft.ihfftn"),
|
|
xfail("fft.irfft2"),
|
|
xfail("fft.irfftn"),
|
|
xfail("fft.rfft2"),
|
|
xfail("fft.rfft"),
|
|
xfail("fft.rfftn"),
|
|
xfail("fill"),
|
|
xfail("flatten"),
|
|
xfail("flip"),
|
|
xfail("fliplr"),
|
|
xfail("flipud"),
|
|
xfail("floor_divide"),
|
|
xfail("fmax"),
|
|
xfail("fmin"),
|
|
xfail("frexp"),
|
|
xfail("full"),
|
|
xfail("full_like"),
|
|
xfail("gather"),
|
|
xfail("geometric"),
|
|
xfail("geqrf"),
|
|
xfail("grid_sampler_2d"),
|
|
xfail("gradient"),
|
|
xfail("heaviside"),
|
|
xfail("histc"),
|
|
xfail("histogram"),
|
|
xfail("histogramdd"),
|
|
xfail("index_add"),
|
|
xfail("index_copy"),
|
|
xfail("index_fill"),
|
|
xfail("index_put"),
|
|
xfail("index_reduce", "prod"),
|
|
xfail("index_reduce", "mean"),
|
|
xfail("index_reduce", "amax"),
|
|
xfail("index_reduce", "amin"),
|
|
xfail("index_select"),
|
|
xfail("isin"),
|
|
xfail("kthvalue"),
|
|
xfail("kron"),
|
|
xfail("linalg.cholesky"),
|
|
xfail("linalg.cholesky_ex"),
|
|
xfail("linalg.cross"),
|
|
xfail("linalg.det"),
|
|
xfail("linalg.eig"),
|
|
xfail("linalg.eigvals"),
|
|
xfail("linalg.householder_product"),
|
|
xfail("linalg.inv"),
|
|
xfail("linalg.inv_ex"),
|
|
xfail("linalg.ldl_factor"),
|
|
xfail("linalg.ldl_factor_ex"),
|
|
xfail("linalg.ldl_solve"),
|
|
xfail("linalg.lstsq"),
|
|
xfail("linalg.lstsq", "grad_oriented"),
|
|
xfail("linalg.lu"),
|
|
xfail("linalg.lu_factor"),
|
|
xfail("linalg.lu_factor_ex"),
|
|
xfail("linalg.lu_solve"),
|
|
xfail("linalg.matrix_power"),
|
|
xfail("linalg.multi_dot"),
|
|
xfail("linalg.pinv"),
|
|
xfail("linalg.pinv", "hermitian"),
|
|
xfail("linalg.slogdet"),
|
|
xfail("linalg.solve"),
|
|
xfail("linalg.solve_ex"),
|
|
xfail("linalg.solve_triangular"),
|
|
xfail("linalg.tensorinv"),
|
|
xfail("linalg.tensorsolve"),
|
|
xfail("linalg.vander"),
|
|
xfail("linalg.vecdot"),
|
|
xfail("linspace"),
|
|
xfail("linspace", "tensor_overload"),
|
|
xfail("log_normal"),
|
|
xfail("logcumsumexp"),
|
|
xfail("logdet"),
|
|
xfail("logspace"),
|
|
xfail("logspace", "tensor_overload"),
|
|
xfail("logsumexp"),
|
|
xfail("lu"),
|
|
xfail("lu_solve"),
|
|
xfail("lu_unpack"),
|
|
xfail("masked_fill"),
|
|
xfail("masked_scatter"),
|
|
xfail("masked_select"),
|
|
xfail("masked.argmax"),
|
|
xfail("masked.argmin"),
|
|
xfail("masked.cumprod"),
|
|
xfail("masked.logsumexp"),
|
|
xfail("masked.median"),
|
|
xfail("matrix_exp"),
|
|
xfail("max", "reduction_with_dim"),
|
|
xfail("median"),
|
|
xfail("min", "reduction_with_dim"),
|
|
xfail("mode"),
|
|
xfail("msort"),
|
|
xfail("multinomial"),
|
|
xfail("mv"),
|
|
xfail("max_pool2d_with_indices_backward", ""),
|
|
xfail("nanmean"),
|
|
xfail("nanmedian"),
|
|
xfail("nanquantile"),
|
|
xfail("nansum"),
|
|
xfail("native_batch_norm"),
|
|
xfail("native_dropout_backward"),
|
|
xfail("narrow_copy"),
|
|
xfail("ne"),
|
|
xfail("new_empty"),
|
|
xfail("new_empty_strided"),
|
|
xfail("transpose"),
|
|
xfail("nn.functional.adaptive_avg_pool1d"),
|
|
xfail("nn.functional.adaptive_avg_pool2d"),
|
|
xfail("nn.functional.adaptive_avg_pool3d"),
|
|
xfail("nn.functional.adaptive_max_pool1d"),
|
|
xfail("nn.functional.adaptive_max_pool2d"),
|
|
xfail("nn.functional.adaptive_max_pool3d"),
|
|
xfail("nn.functional.alpha_dropout"),
|
|
xfail("nn.functional.avg_pool1d"),
|
|
xfail("nn.functional.avg_pool2d"),
|
|
xfail("nn.functional.avg_pool3d"),
|
|
xfail("nn.functional.batch_norm"),
|
|
xfail("nn.functional.batch_norm", "without_cudnn"),
|
|
xfail("nn.functional.bilinear"),
|
|
xfail("nn.functional.binary_cross_entropy"),
|
|
xfail("nn.functional.binary_cross_entropy_with_logits"),
|
|
xfail("nn.functional.celu"),
|
|
xfail("nn.functional.conv1d"),
|
|
xfail("nn.functional.conv2d"),
|
|
xfail("nn.functional.conv3d"),
|
|
xfail("nn.functional.conv_transpose1d"),
|
|
xfail("nn.functional.conv_transpose2d"),
|
|
xfail("nn.functional.conv_transpose3d"),
|
|
xfail("nn.functional.cosine_similarity"),
|
|
xfail("nn.functional.ctc_loss"),
|
|
xfail("nn.functional.dropout"),
|
|
xfail("nn.functional.dropout2d"),
|
|
xfail("nn.functional.dropout3d"),
|
|
xfail("nn.functional.elu"),
|
|
xfail("nn.functional.fractional_max_pool2d"),
|
|
xfail("nn.functional.fractional_max_pool3d"),
|
|
xfail("nn.functional.glu"),
|
|
xfail("nn.functional.grid_sample"),
|
|
xfail("nn.functional.group_norm"),
|
|
xfail("nn.functional.hardshrink"),
|
|
xfail("nn.functional.hardsigmoid"),
|
|
xfail("nn.functional.hardswish"),
|
|
xfail("nn.functional.hardtanh"),
|
|
xfail("nn.functional.huber_loss"),
|
|
xfail("nn.functional.instance_norm"),
|
|
xfail("nn.functional.interpolate", "area"),
|
|
xfail("nn.functional.interpolate", "nearest"),
|
|
xfail("nn.functional.interpolate", "nearest-exact"),
|
|
xfail("nn.functional.leaky_relu"),
|
|
xfail("nn.functional.linear"),
|
|
xfail("nn.functional.local_response_norm"),
|
|
xfail("nn.functional.logsigmoid"),
|
|
xfail("nn.functional.margin_ranking_loss"),
|
|
xfail("nn.functional.max_pool1d"),
|
|
xfail("nn.functional.max_pool2d"),
|
|
xfail("nn.functional.max_pool3d"),
|
|
xfail("nn.functional.max_unpool1d"),
|
|
xfail("nn.functional.max_unpool1d", "grad"),
|
|
xfail("nn.functional.max_unpool2d"),
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail("nn.functional.max_unpool3d"),
|
|
xfail("nn.functional.max_unpool3d", "grad"),
|
|
xfail("nn.functional.mish"),
|
|
xfail("nn.functional.mse_loss"),
|
|
xfail("nn.functional.multi_margin_loss"),
|
|
xfail("nn.functional.multi_head_attention_forward"),
|
|
xfail("nn.functional.multilabel_margin_loss"),
|
|
xfail("nn.functional.multilabel_soft_margin_loss"),
|
|
xfail("nn.functional.normalize"),
|
|
xfail("nn.functional.pad", "constant"),
|
|
xfail("nn.functional.pad", "reflect"),
|
|
xfail("nn.functional.pad", "replicate"),
|
|
xfail("nn.functional.pad", "replicate_negative"),
|
|
xfail("nn.functional.pairwise_distance"),
|
|
xfail("nn.functional.pdist"),
|
|
xfail("nn.functional.pixel_shuffle"),
|
|
xfail("nn.functional.pixel_unshuffle"),
|
|
xfail("nn.functional.prelu"),
|
|
xfail("nn.functional.relu6"),
|
|
xfail("nn.functional.rrelu"),
|
|
xfail("nn.functional.selu"),
|
|
xfail("nn.functional.smooth_l1_loss"),
|
|
xfail("nn.functional.soft_margin_loss"),
|
|
xfail("nn.functional.softplus"),
|
|
xfail("nn.functional.softshrink"),
|
|
xfail("nn.functional.threshold"),
|
|
xfail("nn.functional.triplet_margin_loss"),
|
|
xfail("nn.functional.triplet_margin_with_distance_loss"),
|
|
xfail("nn.functional.unfold"),
|
|
xfail("nn.functional.upsample_nearest"),
|
|
xfail("nonzero"),
|
|
xfail("normal"),
|
|
xfail("normal", "number_mean"),
|
|
xfail("normal", "in_place"),
|
|
xfail("ormqr"),
|
|
xfail("ones"),
|
|
xfail("pca_lowrank"),
|
|
xfail("permute_copy"),
|
|
xfail("pinverse"),
|
|
xfail("polar"),
|
|
xfail("put"),
|
|
xfail("quantile"),
|
|
xfail("rand_like"),
|
|
xfail("randint_like"),
|
|
xfail("randint"),
|
|
xfail("randn"),
|
|
xfail("randn_like"),
|
|
xfail("ravel"),
|
|
xfail("renorm"),
|
|
xfail("repeat_interleave"),
|
|
xfail("resize_"),
|
|
xfail("resize_as_"),
|
|
xfail("reshape"),
|
|
xfail("reshape_as"),
|
|
xfail("roll"),
|
|
xfail("rot90"),
|
|
xfail("rsub"),
|
|
xfail("scalar_tensor"),
|
|
xfail("scatter_add"),
|
|
xfail("scatter_reduce", "amax"),
|
|
xfail("scatter_reduce", "amin"),
|
|
xfail("scatter_reduce", "mean"),
|
|
xfail("scatter_reduce", "prod"),
|
|
xfail("scatter_reduce", "sum"),
|
|
xfail("searchsorted"),
|
|
xfail("select_scatter"),
|
|
xfail("sort"),
|
|
xfail("sparse.sampled_addmm"),
|
|
xfail("sparse.mm", "reduce"),
|
|
xfail("special.airy_ai"),
|
|
xfail("special.bessel_j0"),
|
|
xfail("special.bessel_j1"),
|
|
xfail("special.bessel_y0"),
|
|
xfail("special.bessel_y1"),
|
|
xfail("special.chebyshev_polynomial_t"),
|
|
xfail("special.chebyshev_polynomial_u"),
|
|
xfail("special.entr"),
|
|
xfail("special.erfcx"),
|
|
xfail("special.hermite_polynomial_h"),
|
|
xfail("special.hermite_polynomial_he"),
|
|
xfail("special.i0e"),
|
|
xfail("special.i1"),
|
|
xfail("special.i1e"),
|
|
xfail("special.laguerre_polynomial_l"),
|
|
xfail("special.log_ndtr"),
|
|
xfail("special.modified_bessel_i0"),
|
|
xfail("special.modified_bessel_i1"),
|
|
xfail("special.modified_bessel_k0"),
|
|
xfail("special.modified_bessel_k1"),
|
|
xfail("special.ndtri"),
|
|
xfail("special.scaled_modified_bessel_k0"),
|
|
xfail("special.scaled_modified_bessel_k1"),
|
|
xfail("special.spherical_bessel_j0"),
|
|
xfail("special.xlog1py"),
|
|
xfail("special.zeta"),
|
|
xfail("squeeze", "multiple"),
|
|
xfail("squeeze_copy"),
|
|
xfail("signal.windows.bartlett"),
|
|
xfail("signal.windows.blackman"),
|
|
xfail("signal.windows.cosine"),
|
|
xfail("signal.windows.exponential"),
|
|
xfail("signal.windows.gaussian"),
|
|
xfail("signal.windows.general_cosine"),
|
|
xfail("signal.windows.general_hamming"),
|
|
xfail("signal.windows.hamming"),
|
|
xfail("signal.windows.hann"),
|
|
xfail("signal.windows.nuttall"),
|
|
xfail("signal.windows.kaiser"),
|
|
xfail("stack"),
|
|
xfail("std"),
|
|
xfail("std", "unbiased"),
|
|
xfail("std_mean"),
|
|
xfail("std_mean", "unbiased"),
|
|
xfail("stft"),
|
|
xfail("svd_lowrank"),
|
|
xfail("t_copy"),
|
|
xfail("take"),
|
|
xfail("take_along_dim"),
|
|
xfail("tensor_split"),
|
|
xfail("to_sparse"),
|
|
xfail("trace"),
|
|
xfail("trapezoid"),
|
|
xfail("trapz"),
|
|
xfail("triangular_solve"),
|
|
xfail("unbind"),
|
|
xfail("unbind_copy"),
|
|
xfail("unfold"),
|
|
xfail("unfold_copy"),
|
|
xfail("uniform"),
|
|
xfail("unflatten"),
|
|
xfail("unique_consecutive"),
|
|
xfail("unique"),
|
|
xfail("unsafe_split"),
|
|
xfail("unsafe_chunk"),
|
|
xfail("_unsafe_masked_index"),
|
|
xfail("_unsafe_masked_index_put_accumulate"),
|
|
xfail("var_mean"),
|
|
xfail("var_mean", "unbiased"),
|
|
xfail("vdot"),
|
|
xfail("view"),
|
|
xfail("view_as"),
|
|
xfail("view_copy"),
|
|
xfail("zeros"),
    # TODO(whc): debug/triage
    # ops in this section might fail even without dtensor tests, because we
    # rescale the op db common test size factors (i.e. L, M, S), which can
    # make the original sample input generation invalid; we skip them for
    # now but should enable them later.
    # TODO: need to clean this list and remove all cases
    skip("argwhere"),
    skip("cumprod"),
    skip("__rmatmul__"),
    skip("meshgrid", "list_of_tensors"),
    skip("meshgrid", "variadic_tensors"),
    skip("nn.functional.scaled_dot_product_attention"),
    skip("nn.functional.softmin"),
    skip("nn.functional.embedding"),
    skip("nn.functional.embedding_bag"),
    skip("nn.functional.feature_alpha_dropout", "with_train"),
    skip("nn.functional.feature_alpha_dropout", "without_train"),
    skip("nn.functional.hinge_embedding_loss"),
    skip("nn.functional.cosine_embedding_loss"),
    skip("fft.hfft"),
    skip("fft.hfft2"),
    skip("fft.hfftn"),
    skip("fft.ifftn"),
    skip("fft.irfft"),
    skip("istft"),
    skip("isclose"),
    skip("isreal"),
    skip("matmul"),
    skip("masked.mean"),
    skip("masked.var"),
    skip("masked.std"),
    skip("masked.normalize"),
    skip("prod"),
    skip("_segment_reduce", "lengths"),
    skip("_segment_reduce", "offsets"),
    # TODO: fix the following ops
    skip("squeeze"),
}


# A list of ops that are currently failing the BW (backward) pass
skip_bw = [
    None,  # corresponds to the transpose ops 'H' and 'T'
    "torch.bucketize",
    "torch.conj_physical",
    "torch.eq",
    "torch.isfinite",
    "torch.isnan",
]


OP_DB_WORLD_SIZE = 4
# DEVICE_TYPE = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() >= OP_DB_WORLD_SIZE else "cpu"
# TODO: debug cuda illegal memory access issue and re-enable cuda tests
DEVICE_TYPE = "cpu"


class TestDTensorOps(DTensorOpTestBase):
    @property
    def world_size(self) -> int:
        return OP_DB_WORLD_SIZE

    # only allow float dtype for now; we can relax this constraint
    # when necessary later (i.e. when adding quantization support).
    @suppress_warnings
    @ops(op_db, allowed_dtypes=(torch.float,))
    @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails)
    def test_dtensor_op_db(self, dtype, op):
        self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size))

        # test each op with dist tensor inputs and normal inputs
        def test():
            samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True)
            for sample_input in samples:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs

                self.run_dtensor_crossref(op.op, args, kwargs)
                # we need to figure out a way to test the out variant; out
                # variant testing is tricky, as we need to pre-allocate the
                # dtensor out, and some ops rely on sharding placements being
                # known in advance (i.e. mm.out)
                # if isinstance(expected, torch.Tensor) and op.supports_out:
                #     func(*args, **kwargs, out=expected)

        self.check_dtensor_func(test, op)

    def assert_ref_dtensor_equal(self, dtensor_rs, rs):
        flat_dtensor_rs = pytree.tree_leaves(dtensor_rs)
        flat_rs = pytree.tree_leaves(rs)
        self.assertEqual(len(flat_dtensor_rs), len(flat_rs))
        for dtensor_r, r in zip(flat_dtensor_rs, flat_rs):
            if not isinstance(r, torch.Tensor):
                continue

            self.assertIsInstance(dtensor_r, torch.Tensor)
            self.assertEqualOnRank(
                dtensor_r.shape,
                r.shape,
                f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}",
            )
            self.assertEqualOnRank(
                dtensor_r.requires_grad,
                r.requires_grad,
                "op result requires_grad mismatch! "
                f"original requires_grad: {r.requires_grad}, "
                f"dtensor requires_grad: {dtensor_r.requires_grad}",
            )

            self.assertEqualOnRank(dtensor_r, r)

    def run_dtensor_crossref(self, func, args, kwargs):
        to_dtensor = DTensorConverter(self.mesh, args, kwargs)

        def concat_res_if_necessary(func, res: object) -> object:
            # concat the result on the corresponding dim for ops like
            # split, so that we can call backward on a single tensor
            if (resolve_name(func) is not None) and ("split" in resolve_name(func)):
                dim = args[2] if len(args) == 3 else 0
                return torch.cat(res, dim=dim)
            else:
                return res

        # TODO: also handle cases where func raises an exception
        rs = func(*args, **kwargs)
        rs = concat_res_if_necessary(func, rs)

        def to_replicate(e: object) -> object:
            return e.full_tensor() if isinstance(e, DTensor) else e

        # Suppress warnings; this doesn't matter for test_meta.py
        # but it does matter if you want to use this decorator
        # for cross-ref testing, as some tests may be looking at
        # errors
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # for every combination of sharding choices, we test if it works
            for dtensor_args, dtensor_kwargs in to_dtensor:
                # Only attempt if we managed to convert all tensors to DTensor
                # (if any of them failed, we're in a mixed tensor situation and
                # this is not allowed in DTensor)
                try:
                    if to_dtensor.successful():
                        # Handle special cases first if there are any
                        dtensor_rs = func(*dtensor_args, **dtensor_kwargs)

                        # we need to skip tests containing tensors of zero elements for now.
                        # see issue: https://github.com/pytorch/PiPPy/issues/470
                        # TODO: remove this once the issue above is fixed.
                        flat_args = pytree.tree_leaves(dtensor_rs)
                        if any(
                            isinstance(e, torch.Tensor) and e.numel() == 0
                            for e in flat_args
                        ):
                            continue

                        # redistribute/all_gather the results to compare with normal output
                        dtensor_rs = tree_map(to_replicate, dtensor_rs)
                        dtensor_rs = concat_res_if_necessary(func, dtensor_rs)
                        try:
                            if resolve_name(func) not in skip_bw:
                                if isinstance(dtensor_rs, DTensor):
                                    dtensor_rs.to_local().sum().backward()
                                elif isinstance(dtensor_rs, tuple):
                                    dtensor_rs[0].to_local().sum().backward()

                        except Exception as e:
                            # TODO(anj): Remove this guard exception after gaining more confidence.
                            if torch.distributed.get_rank() == 0:
                                print(
                                    f"failed to run BW: {resolve_name(func)}, {func}, {str(e)}"
                                )
                        self.assert_ref_dtensor_equal(dtensor_rs, rs)
                    else:
                        raise RuntimeError(
                            f"failed to convert args to DTensor; "
                            f"originally (*{args}, **{kwargs})"
                        )
                except Exception as e:
                    raise RuntimeError(
                        f"failed to run: {resolve_name(func)}, with (*{dtensor_args}, **{dtensor_kwargs})"
                    ) from e
        return rs

    def check_dtensor_func(self, test_func, opinfo, dry_run=False):
        try:
            test_func()
        except Exception:
            if not dry_run:
                raise
            if dist.get_rank() == 0:
                if opinfo.variant_test_name:
                    print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
                else:
                    print(f"xfail('{opinfo.name}'),")


# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU)
instantiate_device_type_tests(TestDTensorOps, globals(), only_for=(DEVICE_TYPE,))


if __name__ == "__main__":
    run_tests()