Compare commits

..

2 Commits

Author SHA1 Message Date
94b993d7e4 linting 2025-11-13 17:18:35 -08:00
0c76b784d1 MPS: Fix clamp scalar cache key to store floats in hex representation 2025-11-13 15:58:58 -08:00
26 changed files with 102 additions and 220 deletions

View File

@ -337,6 +337,10 @@ Tensor _convolution_out(
TORCH_CHECK(
3 == ndim || 4 == ndim || 5 == ndim,
"convolution only supports 3D, 4D, 5D tensor");
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested =
use_channels_last_for_conv(input_r, weight_r);
Tensor input = input_r, weight = weight_r;
// PyTorch does not support ChannelsLast1D case,
// thus we need the transformation here
@ -344,8 +348,13 @@ Tensor _convolution_out(
input = view4d(input_r);
weight = view4d(weight_r);
}
// get computation format for Conv/TransposedConv
bool is_channels_last_suggested = use_channels_last_for_conv(input, weight);
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
auto k = weight.ndimension();
if (k == input.ndimension() + 1) {
@ -379,14 +388,6 @@ Tensor _convolution_out(
expand_param_if_needed(output_padding_, "output_padding", dim);
params.groups = groups_;
}
// ensure the input/weight/bias/output are congituous in desired format
at::MemoryFormat mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(input.ndimension())
: at::MemoryFormat::Contiguous;
auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r;
input = input.contiguous(mfmt);
weight = weight.contiguous(mfmt);
check_shape_forward(input, weight, bias, params, true);
Tensor output;
@ -513,9 +514,18 @@ Tensor convolution_overrideable(
at::borrow_from_optional_tensor(bias_r_opt);
const Tensor& bias_r = *bias_r_maybe_owned;
auto k = weight_r.ndimension();
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
if (xpu_conv_use_channels_last(input_r, weight_r)) {
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d
: at::MemoryFormat::ChannelsLast;
}
Tensor input_c = input_r.contiguous(backend_memory_format);
Tensor weight_c = weight_r.contiguous(backend_memory_format);
return _convolution(
input_r,
weight_r,
input_c,
weight_c,
bias_r,
stride_,
padding_,

View File

@ -82,6 +82,7 @@ NSArray<NSNumber*>* getTensorAxes(const TensorBase& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false);
std::string to_hex_key(float);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);

View File

@ -301,6 +301,10 @@ std::string getArrayRefString(const IntArrayRef s) {
return fmt::to_string(fmt::join(s, ","));
}
std::string to_hex_key(float f) {
return fmt::format("{:a}", f);
}
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype, bool exclude_shape) {
fmt::basic_memory_buffer<char, 100> buffer;
auto buf_iterator = std::back_inserter(buffer);

View File

@ -244,8 +244,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
std::string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
std::string key = op_name + (has_min ? ("_min:" + to_hex_key(min_scalar)) : "") +
(has_max ? ("_max:" + to_hex_key(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph constantWithScalar:min_scalar

View File

@ -743,19 +743,16 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
self.assertTrue(tail_log.stopped())
def test_binary_duplicate_log_filters(self):
envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
logs_specs = DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
)
logs_dest = logs_specs.reify(envs)
pc = start_processes(
name="trainer",
entrypoint=bin("echo1.py"),
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs=envs,
logs_specs=logs_specs,
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
@ -765,18 +762,12 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
result = pc.wait()
self.assertFalse(result.is_failed())
self.assert_in_file(
["[rank0]:helloA stdout from 0"], logs_dest.filtered_stdout
)
self.assert_in_file(["[rank0]:helloA stdout from 0"], pc.filtered_stdout)
self.assert_not_in_file(
["[rank0]:helloB stdout from 0"], logs_dest.filtered_stdout
)
self.assert_in_file(
["[rank1]:worldA stderr from 1"], logs_dest.filtered_stderr
)
self.assert_in_file(
["[rank1]:worldB stderr from 1"], logs_dest.filtered_stderr
["[rank0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(["[rank1]:worldA stderr from 1"], pc.filtered_stderr)
self.assert_in_file(["[rank1]:worldB stderr from 1"], pc.filtered_stderr)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())
@ -847,19 +838,16 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
def test_function_duplicate_log_filters(self):
for start_method in self._start_methods:
with self.subTest(start_method=start_method):
envs = {0: {"RANK": "0"}, 1: {"RANK": "1"}}
logs_specs = DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
)
logs_dest = logs_specs.reify(envs)
pc = start_processes(
name="trainer",
entrypoint=echo1,
args={0: ("helloA,helloB",), 1: ("worldA,worldB",)},
envs=envs,
logs_specs=logs_specs,
envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
logs_specs=DefaultLogsSpecs(
log_dir=self.log_dir(),
redirects={0: Std.ERR, 1: Std.NONE},
tee={0: Std.OUT, 1: Std.ERR},
),
duplicate_stdout_filters=["helloA"],
duplicate_stderr_filters=["worldA", "B"],
start_method="spawn",
@ -869,16 +857,16 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
self.assertFalse(result.is_failed())
self.assert_in_file(
["[trainer0]:helloA stdout from 0"], logs_dest.filtered_stdout
["[trainer0]:helloA stdout from 0"], pc.filtered_stdout
)
self.assert_not_in_file(
["[trainer0]:helloB stdout from 0"], logs_dest.filtered_stdout
["[trainer0]:helloB stdout from 0"], pc.filtered_stdout
)
self.assert_in_file(
["[trainer1]:worldA stderr from 1"], logs_dest.filtered_stderr
["[trainer1]:worldA stderr from 1"], pc.filtered_stderr
)
self.assert_in_file(
["[trainer1]:worldB stderr from 1"], logs_dest.filtered_stderr
["[trainer1]:worldB stderr from 1"], pc.filtered_stderr
)
for tail_log in pc._tail_logs:
self.assertTrue(tail_log.stopped())

View File

@ -100,9 +100,8 @@ class TailLogTest(unittest.TestCase):
}
dst = os.path.join(self.test_dir, "tailed_stdout.log")
dst_file = open(dst, "w", buffering=1)
tail = TailLog(
name="writer", log_files=log_files, dst=dst_file, interval_sec=interval_sec
name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
).start()
# sleep here is intentional to ensure that the log tail
# can gracefully handle and wait for non-existent log files
@ -118,11 +117,10 @@ class TailLogTest(unittest.TestCase):
wait(futs, return_when=ALL_COMPLETED)
self.assertFalse(tail.stopped())
tail.stop()
dst_file.close()
actual: dict[int, set[int]] = {}
with open(dst) as read_dst_file:
for line in read_dst_file:
with open(dst) as dst_file:
for line in dst_file:
header, num = line.split(":")
nums = actual.setdefault(header, set())
nums.add(int(num))
@ -258,4 +256,4 @@ class TailLogTest(unittest.TestCase):
tail = TailLog("writer", log_files={0: self.test_dir}, dst=sys.stdout).start()
tail.stop()
mock_logger.exception.assert_called_once()
mock_logger.error.assert_called_once()

View File

@ -16271,10 +16271,8 @@ for test in criterion_tests:
add_nn_module_test(**test)
if __name__ == '__main__':
if sys.version_info < (3, 14):
TestCase._default_dtype_check_enabled = True
run_tests()
import jit.test_module_interface
suite = unittest.findTestCases(jit.test_module_interface)
unittest.TextTestRunner().run(suite)
TestCase._default_dtype_check_enabled = True
run_tests()
import jit.test_module_interface
suite = unittest.findTestCases(jit.test_module_interface)
unittest.TextTestRunner().run(suite)

View File

@ -4,7 +4,6 @@ import torch
from torch.cuda.amp import autocast
from typing import Optional
import sys
import unittest
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
@ -962,5 +961,4 @@ class TestJitTraceAutocast(JitTestCase):
if __name__ == "__main__":
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -88,5 +88,4 @@ print("Didn't throw exception")
self.compare_enabled_disabled(_program_string)
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -995,5 +995,4 @@ class TestFuser(JitTestCase):
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -13,5 +13,4 @@ if __name__ == "__main__":
from test_jit_fuser import * # noqa: F403
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -5,7 +5,6 @@ import contextlib
import math
import operator
import os
import sys
import unittest
import warnings
@ -3048,5 +3047,4 @@ instantiate_device_type_tests(TestLoopnestRandomization, globals(), only_for=("c
if __name__ == "__main__":
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -12,5 +12,4 @@ if __name__ == '__main__':
from test_jit import * # noqa: F403, F401
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -854,5 +854,4 @@ instantiate_device_type_tests(TestFusionPattern, globals())
instantiate_device_type_tests(TestOp, globals())
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -5,5 +5,4 @@ sys.argv.append("--jit-executor=profiling")
from test_jit import * # noqa: F403
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -5,5 +5,4 @@ sys.argv.append("--jit-executor=simple")
from test_jit import * # noqa: F403
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -1,6 +1,5 @@
# Owner(s): ["oncall: jit"]
import sys
from test_jit import JitTestCase
from torch.testing._internal.common_utils import run_tests
@ -330,5 +329,4 @@ class TestScript(JitTestCase):
self.checkScript(test_slice, ("hellotest",))
if __name__ == '__main__':
if sys.version_info < (3, 14):
run_tests()
run_tests()

View File

@ -367,31 +367,6 @@ class TestConvolutionNNDeviceType(NNTestCase):
no_weight,
)
@dtypes(torch.float)
def test_conv1d_large_input(self, device, dtype):
N, C_in, L = 4, 512, 441
C_out, K, P = 512, 3, 1
torch.manual_seed(42)
conv_cpu = (
nn.Conv1d(C_in, C_out, kernel_size=K, padding=P, bias=True)
.to(torch.float32)
.requires_grad_()
)
x_cpu = torch.randn(N, C_in, L, dtype=torch.float32)
out_cpu = conv_cpu(x_cpu)
conv_dev = nn.Conv1d(C_in, C_out, kernel_size=K, padding=P, bias=True).to(
device, dtype
)
conv_dev.weight.data.copy_(conv_cpu.weight.data.to(dtype))
conv_dev.bias.data.copy_(conv_cpu.bias.data.to(dtype))
x_dev = x_cpu.to(device, dtype).requires_grad_()
out_dev = conv_dev(x_dev)
self.assertEqual(out_cpu, out_dev, atol=1e-5, rtol=1e-5, exact_device=False)
@dtypes(torch.float)
def test_conv1d_same_padding(self, device, dtype):
test_args = [

2
third_party/xpu.txt vendored
View File

@ -1 +1 @@
1e69f40b3c03492eb3dd7e03462a5566f29674d3
9aac5a1ddf50d75f929d572df51bb368b32da14e

View File

@ -25,7 +25,7 @@ from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Optional, TextIO, Union
from typing import Any, Optional, Union
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
@ -491,8 +491,8 @@ class PContext(abc.ABC):
self.stderrs = logs_dest.stderrs
self.error_files = logs_dest.error_files
self.nprocs = nprocs
self.filtered_stdout: Optional[TextIO] = None
self.filtered_stderr: Optional[TextIO] = None
self.filtered_stdout = logs_dest.filtered_stdout
self.filtered_stderr = logs_dest.filtered_stderr
self._tail_logs = [
TailLog(name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes),
@ -500,9 +500,6 @@ class PContext(abc.ABC):
]
if duplicate_stdout_filters:
self.filtered_stdout = open(
logs_dest.filtered_stdout, mode="w", errors="replace", buffering=1
)
self._tail_logs.append(
TailLog(
name,
@ -516,9 +513,6 @@ class PContext(abc.ABC):
)
if duplicate_stderr_filters:
self.filtered_stderr = open(
logs_dest.filtered_stderr, mode="w", errors="replace", buffering=1
)
self._tail_logs.append(
TailLog(
name,
@ -663,10 +657,6 @@ class PContext(abc.ABC):
self._close(death_sig=death_sig, timeout=timeout)
for tail_log in self._tail_logs:
tail_log.stop()
if self.filtered_stdout:
self.filtered_stdout.close()
if self.filtered_stderr:
self.filtered_stderr.close()
def get_std_cm(std_rd: str, redirect_fn):

View File

@ -13,11 +13,12 @@ import time
from collections.abc import Callable
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Optional, TextIO, TYPE_CHECKING
from typing import Optional, TextIO, TYPE_CHECKING, Union
if TYPE_CHECKING:
from concurrent.futures._base import Future
from io import TextIOWrapper
__all__ = ["tail_logfile", "TailLog"]
@ -97,7 +98,7 @@ class TailLog:
self,
name: str,
log_files: dict[int, str],
dst: TextIO,
dst: Union[TextIO, str],
log_line_prefixes: Optional[dict[int, str]] = None,
interval_sec: float = 0.1,
log_line_filter: Callable[[str], bool] = (lambda _: True),
@ -112,7 +113,19 @@ class TailLog:
)
self._name = name
self._dst = dst
self._dst_file: Optional[TextIOWrapper] = None
self._dst: Optional[Union[TextIO, TextIOWrapper]] = None
if isinstance(dst, str):
try:
self._dst_file = open(dst, mode="w", errors="replace")
self._dst = self._dst_file
except Exception:
logger.exception("error opening dst file %s.", dst)
self._dst = None
self._dst_file = None
else:
self._dst = dst
self._log_files = log_files
self._log_line_prefixes = log_line_prefixes
self._log_line_filter = log_line_filter
@ -162,6 +175,9 @@ class TailLog:
if self._threadpool:
self._threadpool.shutdown(wait=True)
if self._dst_file:
self._dst_file.close()
self._stopped = True
def stopped(self) -> bool:

View File

@ -7337,14 +7337,19 @@ class ShapeEnv:
if insts[cur].opname in ("TO_BOOL", "COMPARE_OP"):
# Peek 1 instruction further.
cur += 1
inst = insts[cur]
assert_insts = torch._dynamo.symbolic_convert.get_assert_bytecode_sequence(
False
)
if inst.opname == "POP_JUMP_IF_TRUE" and inst.arg is not None:
first = insts[cur + 1]
cur_insts = insts[cur + 1 : cur + 1 + len(assert_insts)]
cur_insts = [inst.opname for inst in cur_insts]
return cur_insts == assert_insts
starts_with_assert = (
first.opname == "LOAD_GLOBAL"
and first.argval == "AssertionError"
or first.opname == "LOAD_ASSERTION_ERROR"
)
if starts_with_assert and insts[cur + 2].opname == "RAISE_VARARGS":
return True
return False
def _log_real_tensor_propagation(
self, orig_expr: sympy.Basic, unsound_result: sympy.Basic

View File

@ -13,7 +13,6 @@ import enum
import functools
import inspect
import pickle
import sys
import warnings
from collections.abc import Callable
from typing import Any, Union
@ -352,17 +351,6 @@ class ScriptWarning(Warning):
def script_method(fn):
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.script_method` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
if not _enabled:
return fn
# NOTE: we need to traverse two frames here because the meta-class frame
@ -1470,17 +1458,6 @@ def script(
# Run the scripted_model with actual inputs
print(scripted_model([20]))
"""
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.script` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
if not _enabled:
return obj
try:

View File

@ -10,8 +10,6 @@ functionalities in `torch.jit`.
"""
import os
import sys
import warnings
import torch
from torch._jit_internal import _get_model_id
@ -79,17 +77,6 @@ def save(m, f, _extra_files=None) -> None:
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
"""
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.save` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.save` is deprecated. Please switch to `torch.export`.",
DeprecationWarning,
)
log_torchscript_usage("save", model_id=_get_model_id(m))
if _extra_files is None:
_extra_files = {}
@ -166,17 +153,6 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
import os
os.remove("scriptmodule.pt")
"""
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.load` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.load` is deprecated. Please switch to `torch.export`.",
DeprecationWarning,
)
if isinstance(f, (str, os.PathLike)):
if not os.path.exists(f):
raise ValueError(f"The provided filename {f} does not exist")

View File

@ -15,7 +15,6 @@ import functools
import inspect
import os
import re
import sys
import warnings
from collections.abc import Callable
from enum import Enum
@ -990,17 +989,6 @@ def trace(
module = torch.jit.trace(n, example_forward_input)
"""
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.trace` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.trace` is deprecated. Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
if not _enabled:
return func
if optimize is not None:
@ -1129,17 +1117,6 @@ def trace_module(
module = torch.jit.trace_module(n, inputs)
"""
if sys.version_info >= (3, 14):
warnings.warn(
"`torch.jit.trace_method` is not supported in Python 3.14+ and may break. "
"Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
else:
warnings.warn(
"`torch.jit.trace_method` is deprecated. Please switch to `torch.compile` or `torch.export`.",
DeprecationWarning,
)
if not _enabled:
return mod
if optimize is not None:

View File

@ -648,16 +648,6 @@ op_db: list[OpInfo] = [
dtypes=all_types_and(torch.bool),
ref=lambda x: scipy.special.spherical_jn(0, x) if TEST_SCIPY else None,
supports_autograd=False,
skips=(
DecorateInfo(
unittest.skip(
"Scipy doesn't support bool inputs to spherical_bessel_j0"
),
"TestUnaryUfuncs",
"test_reference_numerics_normal",
dtypes=(torch.bool,),
),
),
),
]
@ -778,16 +768,6 @@ python_ref_db: list[OpInfo] = [
}
),
),
skips=(
DecorateInfo(
unittest.skip(
"Scipy doesn't support bool inputs to spherical_bessel_j0"
),
"TestUnaryUfuncs",
"test_reference_numerics_normal",
dtypes=(torch.bool,),
),
),
),
#
# Elementwise Binary Special OpInfos