Compare commits

...

8 Commits

Author SHA1 Message Date
5b6cc8215f Change python doc push script to print the undocumented modules 2025-10-21 12:30:49 -07:00
1c43c9cfd0 Update 2025-10-21 12:30:49 -07:00
102e0d5437 Test 2025-10-21 12:30:49 -07:00
0bd12c1168 [CI] Extend test_transfomers to MPS (#165960)
Just skip grad_checks as they need float64
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165960
Approved by: https://github.com/Skylion007
2025-10-21 19:27:44 +00:00
ce8a7764e2 Revert "[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)"
This reverts commit 1290b077f26543a34262587137ef64ca9ca5e17d.

Reverted https://github.com/pytorch/pytorch/pull/165707 on behalf of https://github.com/clee2000 due to failing internal tests D85160820 ([comment](https://github.com/pytorch/pytorch/pull/165707#issuecomment-3429084393))
2025-10-21 19:25:03 +00:00
d1269a0434 update fr trace analysis (#165994)
Summary:
- allow empty entries from ranks
- allow not all ranks to provide dump

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/165994).
* #165638
* #165640
* #165642
* __->__ #165994
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165994
Approved by: https://github.com/fduwjj
2025-10-21 19:14:33 +00:00
c87cf1be32 Update workaround to old CUDA bug (#164354) (#165984)
The workaround cannot be removed because of BC. Here we'll
update PyTorch code base to not use the workaround.

See https://github.com/pytorch/pytorch/pull/164354 for the BC breakage issue.

Resolves https://github.com/pytorch/pytorch/issues/164348.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165984
Approved by: https://github.com/janeyx99
2025-10-21 19:09:43 +00:00
2fc5e45a41 better error message when there is no pytree impl (#165955)
Differential Revision: [D85117597](https://our.internmc.facebook.com/intern/diff/D85117597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165955
Approved by: https://github.com/avikchaudhuri
2025-10-21 18:49:22 +00:00
13 changed files with 237 additions and 174 deletions

View File

@ -1,15 +1,11 @@
sphinx==5.3.0
sphinx==7.2.6
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
#Pinned versions: 7.2.6
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
pytorch_sphinx_theme2==0.1.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.1.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.34.0
breathe==4.36.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.34.0
#Pinned versions: 4.36.0
exhale==0.2.3
exhale==0.3.7
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.2.3
#Pinned versions: 0.3.7
docutils==0.16
docutils==0.20
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.16
#Pinned versions: 0.20
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -56,13 +52,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==0.17.2
myst-nb==1.3.0
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2
#Pinned versions: 1.3.0
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.4.0
sphinx-design==0.6.1
sphinxcontrib-mermaid==1.0.0
myst-parser==0.18.1
myst-parser==4.0.1

View File

@ -102,8 +102,18 @@ if [ "$is_main_doc" = true ]; then
echo coverage output not found
exit 1
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
echo "======================================"
echo "ERROR: $undocumented undocumented objects found!"
echo "======================================"
echo ""
echo "Full coverage report:"
cat build/coverage/python.txt
echo ""
echo "======================================"
echo "Undocumented modules/objects (lines after TOTAL):"
tail -n +$((lines - undocumented + 1)) build/coverage/python.txt
echo "======================================"
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1

View File

@ -120,7 +120,7 @@ static void pow_tensor_scalar_kernel(
} else if (dtype == ScalarType::Half) {
[&]() {
using scalar_t =
decltype(c10::impl::ScalarTypeToCPPType<ScalarType::Half>::t);
c10::impl::ScalarTypeToCPPTypeT<ScalarType::Half>;
const auto exp = exp_scalar.to<scalar_t>();
using Vec = Vectorized<scalar_t>;
cpu_kernel_vec(iter,

View File

@ -856,9 +856,13 @@ struct type_specialized_kernel_launcher {
out_calc_t output_offset_calculator,
loader_t loader,
storer_t storer) {
if (ret_t == rt_binary_specializations[arg_index][0] &&
arg0_t == rt_binary_specializations[arg_index][1] &&
arg1_t == rt_binary_specializations[arg_index][2])
constexpr ScalarType sret_t = rt_binary_specializations[arg_index][0];
constexpr ScalarType sarg0_t = rt_binary_specializations[arg_index][1];
constexpr ScalarType sarg1_t = rt_binary_specializations[arg_index][2];
if (ret_t == sret_t && arg0_t == sarg0_t && arg1_t == sarg1_t) {
using cret_t = c10::impl::ScalarTypeToCPPTypeT<sret_t>;
using carg0_t = c10::impl::ScalarTypeToCPPTypeT<sarg0_t>;
using carg1_t = c10::impl::ScalarTypeToCPPTypeT<sarg1_t>;
launch_vectorized_templated_kernel<
func_t,
array_t,
@ -866,12 +870,9 @@ struct type_specialized_kernel_launcher {
out_calc_t,
loader_t,
storer_t,
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][0]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][1]>::t),
decltype(c10::impl::ScalarTypeToCPPType<
rt_binary_specializations[arg_index][2]>::t)>(
cret_t,
carg0_t,
carg1_t>(
numel,
f,
data,
@ -879,6 +880,7 @@ struct type_specialized_kernel_launcher {
output_offset_calculator,
loader,
storer);
}
}
};

View File

@ -207,6 +207,42 @@ templates_path = [
]
# TODO: document these and remove them from here.
# Fixes the duplicated
autosummary_filename_map = {
"torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function",
"torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class",
"torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function",
"torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class",
"torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function",
"torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class",
"torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function",
"torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class",
"torch.optim.radam.radam": "torch.optim.radam.radam_function",
"torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class",
"torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function",
"torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class",
"torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function",
"torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class",
"torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function",
"torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class",
"torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function",
"torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class",
"torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function",
"torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class",
"torch.optim.adam.adam": "torch.optim.adam.adam_function",
"torch.optim.adam.Adam": "torch.optim.adam.Adam_class",
"torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function",
"torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class",
"torch.mtia.stream": "torch.mtia.stream_function",
"torch.mtia.Stream": "torch.mtia.Stream_class",
"torch.cpu.stream": "torch.cpu.stream_function",
"torch.cpu.Stream": "torch.cpu.Stream_class",
"torch.cuda.stream": "torch.cuda.stream_function",
"torch.cuda.Stream": "torch.cuda.Stream_class",
"torch.xpu.stream": "torch.xpu.stream_function",
"torch.xpu.Stream": "torch.xpu.Stream_class",
}
coverage_ignore_functions = [
# torch
"typename",
@ -3193,6 +3229,11 @@ autodoc_type_aliases = {
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings
autodoc_default_options = {
"exclude-members": "from_bytes, to_bytes",
}
# -- katex javascript in header
#
# def setup(app):

View File

@ -253,7 +253,6 @@ regular full-precision tensor.
.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst
view
as_strided

View File

@ -15192,6 +15192,25 @@ graph():
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
)
def test_invalid_pytree_dynamo_graph_capture(self):
class Block:
def __init__(self, a, b):
self.a = a
self.b = b
class Foo(torch.nn.Module):
def forward(self, block):
return block.a + block.b
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
with self.assertRaisesRegex(
torch._dynamo.exc.UserError, "It looks like one of the inputs with type"
):
_dynamo_graph_capture_for_export(Foo())(
Block(torch.randn(4, 4), torch.randn(4, 4))
)
def test_enum_str(self):
class TensorDim(str, enum.Enum):
DDP = "ddp"

View File

@ -17,7 +17,7 @@ from unittest.mock import patch, MagicMock, ANY
import math
import itertools
import torch.optim as optim
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, largeTensorTest
from torch.testing._internal.common_device_type import expectedFailureMPS, instantiate_device_type_tests, onlyCUDA, largeTensorTest
from typing import Optional
import torch.utils.cpp_extension
from torch.testing._internal.common_nn import NNTestCase
@ -2022,6 +2022,7 @@ class TestSDPA(NNTestCase):
for both cpu and cuda. If you're test is only applicable to cuda,
add it to TestSDPACudaOnly.
"""
@expectedFailureMPS # No double support
@parametrize("contiguous_inputs", [True, False])
def test_sdp_math_gradcheck(self, device, contiguous_inputs: bool):
@ -4625,13 +4626,13 @@ class TestAttnBias(NNTestCase):
scaled_dot_product_attention(query, key, value, attn_mask=attn_bias, is_causal=True, dropout_p=0.0)
if NOTEST_CPU:
device_types = ("cuda", )
device_types = ("cuda", "mps")
else:
device_types = ("cpu", "cuda")
device_types = ("cpu", "cuda", "mps")
instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types, allow_mps=True)
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)

View File

@ -754,6 +754,10 @@ def align_trace_from_beginning(
# Rank 3: [0, 1, 2, 3, 4, 5, None]
# Then we should start from collective 2 not 0 because any collective before,
# we don't have complete records from all ranks so we need to ignore them.
# If we don't have any trace from some ranks, ignore them
# as well.
if len(entries[rank]) == 0:
continue
first_record_id = entries[rank][0]["record_id"]
maximum_starting_record_id = max(maximum_starting_record_id, first_record_id)

View File

@ -1707,6 +1707,39 @@ def check_signature_rewritable(graph: torch.fx.GraphModule) -> None:
)
def check_user_input_output(flat_values: list[Any], error_type: UserErrorType) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)
def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))
value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue
raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)
def rewrite_signature(
f_sig: inspect.Signature,
graph: torch.fx.GraphModule,
@ -1721,40 +1754,6 @@ def rewrite_signature(
) -> torch.fx.GraphModule:
orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
def check_user_input_output(
flat_values: list[Any], error_type: UserErrorType
) -> None:
supported_types = [
torch.Tensor,
torch.SymInt,
torch.SymFloat,
torch.SymBool,
torch._C.ScriptObject,
_IntWrapper,
] + list(common_constant_types)
def is_supported_type(val: Any) -> bool:
return isinstance(val, tuple(supported_types))
value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output"
# We only check that the outputs are not None. Inputs can be None.
for v in flat_values:
if not is_supported_type(v):
if error_type == UserErrorType.INVALID_INPUT and v is None:
continue
raise UserError(
error_type,
f"It looks like one of the {value_type}s with type `{type(v)}` "
"is not supported or pytree-flattenable. \n"
f"Exported graphs {value_type}s can only contain the "
f"following supported types: {supported_types}. \n"
"If you are using a custom class object, "
"please register a pytree_flatten/unflatten function "
"using `torch.utils._pytree.register_pytree_node` or "
"`torch.export.register_dataclass`.",
)
check_user_input_output(flat_args, UserErrorType.INVALID_INPUT)
flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT)

View File

@ -10,7 +10,8 @@ import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
from torch._dynamo.eval_frame import argument_names, check_user_input_output
from torch._dynamo.exc import UserErrorType
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._export.utils import _compiling_state_context
from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint
@ -479,6 +480,7 @@ def _dynamo_graph_capture_for_export(
# This sets the is_exporting flag when building guards.
with _compiling_state_context():
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
check_user_input_output(flat_inputs, UserErrorType.INVALID_INPUT)
module_to_trace = ModuleToTrace(mod, in_spec)
orig_callable = mod.forward if isinstance(mod, torch.nn.Module) else mod

View File

@ -200,10 +200,9 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@ -231,8 +230,9 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)
fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@ -574,8 +574,10 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
from .functions import UserFunctionVariable
return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
)
def call_function(
@ -769,8 +771,9 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:] # Don't use context
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@ -796,8 +799,9 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
fn_source = AttrSource(self.source, "backward")
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)
def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@ -1022,12 +1026,10 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
fn_vt = VariableTracker.build(
tx,
return variables.UserFunctionVariable(
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
)
return fn_vt.call_function(
).call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,

View File

@ -63,15 +63,15 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn)
// This macro controls many of our C++ APIs, including constructors
// for Scalar as well as the data() and item() accessors on Tensor
@ -81,19 +81,19 @@ struct dummy_int1_7_t {};
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(c10::Half, Half) \
_(float, Float) \
_(double, Double) \
_(c10::complex<c10::Half>, ComplexHalf) \
_(c10::complex<float>, ComplexFloat) \
_(c10::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
_(c10::BFloat16, BFloat16) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
@ -103,7 +103,7 @@ struct dummy_int1_7_t {};
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(c10::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
@ -113,7 +113,7 @@ struct dummy_int1_7_t {};
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
@ -176,24 +176,19 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE>::t), \
SCALARTYPE)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE>, SCALARTYPE)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, SCALARTYPE2)
#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \
_(uint8_t, Byte) \
@ -203,53 +198,41 @@ struct dummy_int1_7_t {};
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3)
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, SCALARTYPE3)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE1>::t), \
SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE2>::t), \
SCALARTYPE2) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE3>::t), \
SCALARTYPE3) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE4>::t), \
SCALARTYPE4) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE5>::t), \
SCALARTYPE5) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE6>::t), \
SCALARTYPE6) \
_(decltype(::c10::impl::ScalarTypeToCPPType< \
::c10::ScalarType::SCALARTYPE7>::t), \
SCALARTYPE7)
#define AT_FORALL_SCALAR_TYPES_AND7( \
SCALARTYPE1, \
SCALARTYPE2, \
SCALARTYPE3, \
SCALARTYPE4, \
SCALARTYPE5, \
SCALARTYPE6, \
SCALARTYPE7, \
_) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE1>, \
SCALARTYPE1) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE2>, \
SCALARTYPE2) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE3>, \
SCALARTYPE3) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE4>, \
SCALARTYPE4) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE5>, \
SCALARTYPE5) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE6>, \
SCALARTYPE6) \
_(c10::impl::ScalarTypeToCPPTypeT<c10::ScalarType::SCALARTYPE7>, SCALARTYPE7)
#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
@ -258,12 +241,12 @@ struct dummy_int1_7_t {};
_(c10::quint4x2, QUInt4x2) \
_(c10::quint2x4, QUInt2x4)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(at::Float8_e5m2, Float8_e5m2) \
_(at::Float8_e4m3fn, Float8_e4m3fn) \
_(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_FLOAT8_TYPES(_) \
_(c10::Float8_e5m2, Float8_e5m2) \
_(c10::Float8_e4m3fn, Float8_e4m3fn) \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu)
#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
@ -298,7 +281,12 @@ struct ScalarTypeToCPPType;
/* can't pick between at::detail and at::cuda::detail. */ \
/* For repro example, please see: */ \
/* https://gist.github.com/izdeby/952ae7cf256ddb740a73776d39a7e7ba */ \
/* TODO: remove once the bug is fixed. */ \
/* UPDATE: while the CUDA bug is fixed, we cannot remove the */ \
/* workaround as it is BC breaking. However, it is recommended to */ \
/* update any code that contains */ \
/* decltype(ScalarTypeToCPPType<T>::t) */ \
/* with */ \
/* ScalarTypeToCPPTypeT<T> */ \
static type t; \
};