mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-13 07:04:52 +08:00

Compare commits: ciflow/tru... → csl/manual (13 commits)

| SHA1 |
|---|
| 9b1953d29e |
| 874ba0a6b4 |
| 1868ca8b6a |
| 2228067938 |
| 56ea1aec79 |
| c8324d30f6 |
| 26568b469d |
| 3e7188c5e7 |
| 6a3694adb0 |
| 24c259cd6c |
| 9cc873e582 |
| cb574c7adb |
| 44cfeec317 |
@@ -308,12 +308,16 @@ class StepcurrentPlugin:
        self.report_status = ""
        assert config.cache is not None
        self.cache: pytest.Cache = config.cache
        self.directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
        self.lastrun: Optional[str] = self.cache.get(self.directory, None)
        directory = f"{STEPCURRENT_CACHE_DIR}/{config.getoption('stepcurrent')}"
        self.lastrun_location = f"{directory}/lastrun"
        self.lastrun: Optional[str] = self.cache.get(self.lastrun_location, None)
        self.initial_val = self.lastrun
        self.skip: bool = config.getoption("stepcurrent_skip")
        self.run_single: bool = config.getoption("run_single")

        self.made_failing_xml_location = f"{directory}/made_failing_xml"
        self.cache.set(self.made_failing_xml_location, False)

    def pytest_collection_modifyitems(self, config: Config, items: list[Any]) -> None:
        if not self.lastrun:
            self.report_status = "Cannot find last run test, not skipping"

@@ -349,8 +353,10 @@ class StepcurrentPlugin:

    def pytest_runtest_protocol(self, item, nextitem) -> None:
        self.lastrun = item.nodeid
        self.cache.set(self.directory, self.lastrun)
        self.cache.set(self.lastrun_location, self.lastrun)

    def pytest_sessionfinish(self, session, exitstatus):
        if exitstatus == 0:
            self.cache.set(self.directory, self.initial_val)
            self.cache.set(self.lastrun_location, self.initial_val)
        if exitstatus != 0:
            self.cache.set(self.made_failing_xml_location, True)
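The two cache keys written above land as small JSON files under pytest's cache directory. A minimal sketch (not part of this diff; it assumes the default `.pytest_cache/v` layout and a hypothetical stepcurrent key name) of how a caller can read them back:

```python
# Sketch only: read the keys StepcurrentPlugin stores via pytest's cache.
import json
from pathlib import Path

def read_stepcurrent_key(repo_root: Path, sc_key: str, name: str):
    # pytest persists each cache key as a JSON file under .pytest_cache/v/
    cache_file = repo_root / ".pytest_cache/v/cache/stepcurrent" / sc_key / name
    try:
        return json.loads(cache_file.read_text())
    except FileNotFoundError:
        return None

# read_stepcurrent_key(Path("."), "example", "lastrun") -> last test nodeid or None
# read_stepcurrent_key(Path("."), "example", "made_failing_xml") -> True/False
```

The `run_test.py` changes later in this diff do essentially the same thing with a plain `open()`, comparing the raw text against `"null"`.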
@@ -2,7 +2,6 @@
import contextlib
import copy
import functools
import logging
import random
import unittest
from contextlib import contextmanager
@@ -52,9 +51,6 @@ from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton


log = logging.getLogger(__name__)


def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
@@ -1204,116 +1200,6 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
        for r in res[1:]:
            self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
    @patch.object(torch._inductor.config, "max_autotune_gemm", True)
    @patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
    def test_multiproc_autotune(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(a, b, c):
                res = (
                    torch.sum((a @ b) + 1.0)
                    + torch.sum(torch.relu(b @ c))
                    + torch.sum(c @ a)
                )

                return res

            a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
            b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
            c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)

            try:
                f(a, b, c)
            except Exception:
                log.exception("Caught exception running f")
                raise

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

            print(f"Result from {self.rank} is {f(a, b, c)}")

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._dynamo.config, "enable_compiler_collectives", True)
    @patch.object(torch._inductor.config, "max_autotune_gemm", True)
    @patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True)
    def test_multiproc_autotune_dynamic_shapes(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(a, b, c):
                res = (
                    torch.sum((a @ b) + 1.0)
                    + torch.sum(torch.relu(b @ c))
                    + torch.sum(c @ a)
                )

                return res

            a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16)
            b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16)
            c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16)

            # Mark tensors as dynamic on both dimensions
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(a, 1)
            torch._dynamo.mark_dynamic(b, 0)
            torch._dynamo.mark_dynamic(b, 1)
            torch._dynamo.mark_dynamic(c, 0)
            torch._dynamo.mark_dynamic(c, 1)

            try:
                f(a, b, c)
            except Exception:
                log.exception("Caught exception running f")
                raise

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

            print(f"Result from {self.rank} is {f(a, b, c)}")

            # Store the initial compilation count
            initial_compile_count = len(metrics)

            # Test with different sizes to ensure dynamic shapes work without recompilation
            a2 = torch.randn(512, 512, device=self.rank, dtype=torch.bfloat16)
            b2 = torch.randn(512, 2048, device=self.rank, dtype=torch.bfloat16)
            c2 = torch.randn(2048, 512, device=self.rank, dtype=torch.bfloat16)

            try:
                result2 = f(a2, b2, c2)
                print(f"Result2 from {self.rank} is {result2}")
            except Exception:
                log.exception("Caught exception running f with different sizes")
                raise

            # Verify no recompilation occurred
            metrics_after = torch._dynamo.utils.get_compilation_metrics()
            final_compile_count = len(metrics_after)
            self.assertEqual(
                initial_compile_count,
                final_compile_count,
                "Expected no recompilation with dynamic shapes",
            )

            # Verify all ranks have the same compilation count
            res_after = [None] * self.world_size
            torch.distributed.all_gather_object(res_after, final_compile_count)
            for r in res_after[1:]:
                self.assertEqual(res_after[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_get_pg_attr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
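Outside the test harness, the same knobs the decorators above patch in can be set directly. A rough, assumption-laden sketch (flag availability depends on the build being used; this is not part of the diff):

```python
# Sketch: enable compiler collectives plus (distributed) GEMM max-autotune,
# mirroring the patch.object decorators in the tests above.
import torch

torch._dynamo.config.enable_compiler_collectives = True
torch._inductor.config.max_autotune_gemm = True
torch._inductor.config.distributed_max_autotune_gemm = True

@torch.compile()
def f(a, b):
    return torch.sum(a @ b)

# With torch.distributed initialized (one rank per GPU), each rank autotunes a
# slice of the GEMMs and the tuning results are shared across the process group.
```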
@@ -1428,170 +1428,6 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):

        self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3])))

    def test_check_compiles_when_predicate_true_and_message_has_no_closure(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
            return x + 1

        x = torch.randn(4)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        y = f(x)
        self.assertEqual(y.shape, x.shape)

    def test_check_compiles_when_predicate_true_constant_and_message_has_no_closure(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
            return x + 1

        x = torch.randn(4)

        y = f(x)
        self.assertEqual(y.shape, x.shape)

    def test_check_compiles_when_predicate_true_constant_and_message_None(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3)
            return x + 1

        x = torch.randn(4)

        y = f(x)
        self.assertEqual(y.shape, x.shape)

    def test_check_compiles_when_predicate_true_and_message_None(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3)
            return x + 1

        x = torch.randn(4)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        y = f(x)
        self.assertEqual(y.shape, x.shape)

    def test_check_compiles_when_predicate_true_and_message_has_global(self):
        global GLOBAL_INT
        GLOBAL_INT = 1

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
            return x + 1

        x = torch.randn(4)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        y = f(x)
        self.assertEqual(y.shape, x.shape)

    def test_check_raises_at_runtime_when_predicate_false_and_message_has_global(self):
        global GLOBAL_INT
        GLOBAL_INT = 1

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3")
            return x + 1

        x = torch.randn(3)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        with self.assertRaisesRegex(
            RuntimeError, f"{GLOBAL_INT} is not greater than 3"
        ):
            f(x)

    def test_check_raises_at_runtime_when_predicate_false_and_message_None(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3)
            return x + 1

        x = torch.randn(3)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        with self.assertRaisesRegex(RuntimeError, None):
            f(x)

    def test_check_raises_at_runtime_when_predicate_false_constant_and_message_None(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3)
            return x + 1

        x = torch.randn(3)

        with self.assertRaisesRegex(RuntimeError, None):
            f(x)

    def test_check_raises_at_runtime_when_predicate_false_and_message_has_no_closure(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
            return x + 1

        x = torch.randn(3)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
            f(x)

    def test_check_raises_at_runtime_when_predicate_false_constant_and_message_has_no_closure(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3")
            return x + 1

        x = torch.randn(3)

        with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"):
            f(x)

    def test_check_assert_error_at_runtime_when_predicate_false_and_message_has_closure(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
            return x + 1

        x = torch.randn(3)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
        ):
            f(x)

    def test_check_assert_error_at_runtime_when_predicate_true_and_message_has_closure(
        self,
    ):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3")
            return x + 1

        x = torch.randn(4)
        torch._dynamo.maybe_mark_dynamic(x, 0)

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()"
        ):
            f(x)

    def test_assert(self):
        @torch.compile
        def fn1(x):
@@ -78,6 +78,7 @@ from tools.testing.test_selections import (
try:
    from tools.testing.upload_artifacts import (
        parse_xml_and_upload_json,
        upload_adhoc_failure_json,
        zip_and_upload_artifacts,
    )
except ImportError:
@@ -87,7 +88,10 @@ except ImportError:
    def parse_xml_and_upload_json():
        pass

    def zip_and_upload_artifacts(failed: bool):
    def zip_and_upload_artifacts(*args, **kwargs):
        pass

    def upload_adhoc_failure_json(*args, **kwargs):
        pass


@@ -642,6 +646,7 @@ def run_test(
            output,
            options.continue_through_error,
            test_file,
            options,
        )
    else:
        command.extend([f"--sc={stepcurrent_key}", "--print-items"])
@@ -728,6 +733,7 @@ def run_test_retries(
    output,
    continue_through_error,
    test_file,
    options,
):
    # Run the test with -x to stop at first failure. Rerun the test by itself.
    # If it succeeds, move on to the rest of the tests in a new process. If it
@@ -746,6 +752,16 @@ def run_test_retries(

    num_failures = defaultdict(int)

    def read_pytest_cache(key: str) -> Any:
        cache_file = (
            REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key / key
        )
        try:
            with open(cache_file) as f:
                return f.read()
        except FileNotFoundError:
            return None

    print_items = ["--print-items"]
    sc_command = f"--sc={stepcurrent_key}"
    while True:
@@ -766,12 +782,11 @@ def run_test_retries(

        # Read what just failed/ran
        try:
            with open(
                REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
            ) as f:
                current_failure = f.read()
            if current_failure == "null":
                current_failure = f"'{test_file}'"
            current_failure = read_pytest_cache("lastrun")
            if current_failure is None:
                raise FileNotFoundError
            if current_failure == "null":
                current_failure = f"'{test_file}'"
        except FileNotFoundError:
            print_to_file(
                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
@@ -794,6 +809,13 @@ def run_test_retries(
            # This is for log classifier so it can prioritize consistently
            # failing tests instead of reruns. [1:-1] to remove quotes
            print_to_file(f"FAILED CONSISTENTLY: {current_failure[1:-1]}")
            if (
                read_pytest_cache("made_failing_xml") == "false"
                and IS_CI
                and options.upload_artifacts_while_running
            ):
                upload_adhoc_failure_json(test_file, current_failure[1:-1])

            if not continue_through_error:
                print_to_file("Stopping at first consistent failure")
                break
@@ -4465,54 +4465,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
        res = f(x, start, 0)
        self.assertEqual(res.shape, torch.Size([0]))

    @skipIfTorchDynamo()
    @torch.fx.experimental._config.patch("backed_size_oblivious", True)
    def test_backed_size_oblivious_broadcast(self):
        cnt = CompileCounterWithBackend("inductor")
        torch._dynamo.reset()

        def func(a, b):
            torch.broadcast_shapes(a.size(), b.size())
            return a + b

        compiled = torch.compile(func, fullgraph=True, backend=cnt, dynamic=True)

        def run(a, b):
            self.assertEqual(compiled(a, b), func(a, b))

        # No 0/1 specializations and no broadcasts,
        # but a[0] == b[0] and a[1] == b[1] are asserted.
        run(torch.rand(1, 10), torch.rand(1, 10))
        run(torch.rand(1, 1), torch.rand(1, 1))
        run(torch.rand(10, 10), torch.rand(10, 10))

        self.assertEqual(cnt.frame_count, 1)
        run(torch.rand(10, 10), torch.rand(1, 10))
        self.assertEqual(cnt.frame_count, 2)

        cnt.clear()
        torch._dynamo.reset()

        # specialize a[0] == 1. b[0] not specialized.
        run(torch.rand(1, 10), torch.rand(9, 10))
        run(torch.rand(1, 10), torch.rand(1, 10))
        self.assertEqual(cnt.frame_count, 1)
        # if we change a[0] we get recompilation.
        run(torch.rand(10, 10), torch.rand(10, 10))
        self.assertEqual(cnt.frame_count, 2)

        cnt.clear()
        torch._dynamo.reset()

        # TODO: duck sizing should probably be disabled when
        # backed_size_oblivious is on.
        # specialize b[0] == 1. a[0] not specialized.
        run(torch.rand(10, 11), torch.rand(1, 11))
        run(torch.rand(1, 10), torch.rand(1, 10))
        self.assertEqual(cnt.frame_count, 1)
        run(torch.rand(2, 10), torch.rand(2, 10))
        self.assertEqual(cnt.frame_count, 2)


instantiate_parametrized_tests(TestUnbacked)
@@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_stack_trace_augmentation(self):
        """
        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_multiple_modules(self):
        """
        Test that multiple compiled modules under the same profiler session
@@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    @skipIfRocm
    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
    @torch._dynamo.config.patch("enrich_profiler_metadata", True)
    def test_profiler_nested_graph_modules(self):
        """
        Test that nested graph modules (e.g., graph modules calling subgraphs)
@@ -208,3 +208,45 @@ def parse_xml_and_upload_json() -> None:
            lock.release()
    except Exception as e:
        print(f"Failed to parse and upload json test reports: {e}")


def upload_adhoc_failure_json(invoking_file: str, current_failure: str) -> None:
    """
    Manually upload a json to s3 indicating that the entire test file failed,
    since xml was probably not generated in this case.
    """
    try:
        job_id = int(os.environ["JOB_ID"])
        workflow_id = int(os.environ["GITHUB_RUN_ID"])
    except Exception as e:
        print(f"Failed to get job_id or workflow_id: {e}")
        return

    split_failure = current_failure.split("::")
    if len(split_failure) >= 2:
        className = split_failure[-2]
        testName = split_failure[-1]
    else:
        testName = current_failure
        className = ""

    message = "The test file failed but pytest did not generate xml. The most likely cause is a segfault"
    j = {
        "invoking_file": invoking_file,
        "file": f"{invoking_file}.py",
        "name": testName,
        "classname": className,
        "workflow_id": workflow_id,
        "workflow_run_attempt": os.environ.get("GITHUB_RUN_ATTEMPT"),
        "job_id": job_id,
        "failure": {"message": message, "text": message},
    }
    gzipped = gzip.compress(json.dumps(j).encode("utf-8"))
    s3_key = f"{invoking_file.replace('/', '_')}_{os.urandom(8).hex()}.json"
    get_s3_resource().put_object(
        Body=gzipped,
        Bucket="gha-artifacts",
        Key=f"test_jsons_while_running/{workflow_id}/{job_id}/{s3_key}",
        ContentType="application/json",
        ContentEncoding="gzip",
    )
@@ -739,8 +739,11 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None

# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)

if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403
@@ -47,7 +47,7 @@ from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir, warn_once
from .utils import clone_inputs, get_debug_dir


if TYPE_CHECKING:
@@ -617,7 +617,7 @@ class InputReader:
            # way would be very mysterious! Would have been better
            # not to store device in the serialized format...
            return storage
        warn_once(f"could not load {storage_hash}, generating random data instead")
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()
@@ -2937,18 +2937,5 @@
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ],
    "GB0288": [
        {
            "Gb_type": "Can't extract message from torch._check()",
            "Context": "str(message_vt)",
            "Explanation": "The second argument of torch._check() must be a function defined within the torch.compile region that does not reference a non-local variable.",
            "Hints": [
                "Make sure the message function is defined in the torch.compile region.",
                "Remove any closure variables, e.g. remove references to closure variable `x` in `lambda: f'{x} failed check'`",
                "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues."
            ]
        }
    ]
}
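As an illustration of the GB0288 entry above (a sketch, not part of the registry; it mirrors the tests earlier in this diff): a closure-free `torch._check` message lambda traces fine under `fullgraph=True`, while one that closes over a local variable hits this graph break.

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def ok(x):
    # No closure in the message lambda: compiles cleanly.
    torch._check(x.shape[0] > 3, lambda: "shape must be greater than 3")
    return x + 1

@torch.compile(backend="eager", fullgraph=True)
def graph_breaks(x):
    n = x.shape[0]
    # The lambda closes over `n`, so Dynamo raises Unsupported (GB0288).
    torch._check(n > 3, lambda: f"{n} is not greater than 3")
    return x + 1
```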
@@ -180,7 +180,6 @@ manual_torch_name_rule_map: dict[
    "torch.compiler.is_exporting": TorchInGraphFunctionVariable,
    "torch._C._to_dlpack": SkipFunctionVariable,
    "torch.to_dlpack": SkipFunctionVariable,
    "torch._check": TorchInGraphFunctionVariable,
    # We graph break on RNG state setters or getters like
    # `torch.get_rng_state` or `torch.set_rng_state`. These functions
    # are not aten operations and therefore they are completely ignored
@@ -2344,6 +2343,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
        "torch._check_type",
        "torch._check_value",
        "torch._check_with",
        "torch._check",
        "torch._compile._disable_dynamo",
        "torch._functorch.apis.chunk_vmap",
        "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats",
@@ -78,7 +78,7 @@ from .ctx_manager import (
)
from .dicts import ConstDictVariable
from .distributed import DistributedVariable, ProcessGroupVariable
from .functions import bind_args_cached, NestedUserFunctionVariable
from .functions import bind_args_cached
from .lists import ListVariable, TupleVariable
from .torch_function import (
    can_dispatch_torch_function,
@@ -1318,86 +1318,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):

            return ConstantVariable.create(None)

        @register(torch._check)
        def handle_check(self, tx: "InstructionTranslator", *args, **kwargs):
            predicate_vt = None
            message_vt = None

            if args:
                predicate_vt = args[0]
                rest_args = args[1:]
            else:
                rest_args = ()

            if predicate_vt is None and "cond" in kwargs:
                predicate_vt = kwargs.pop("cond")

            if rest_args:
                message_vt = rest_args[0]
            elif "message" in kwargs:
                message_vt = kwargs.pop("message")

            if predicate_vt is None:
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.value,
                        (),
                        {},
                    ),
                )

            message_eager = None
            message_graph_proxy = None
            if message_vt is not None:
                if (
                    not isinstance(message_vt, NestedUserFunctionVariable)
                    or message_vt.has_closure()
                ):
                    unimplemented_v2(
                        gb_type="Can't extract message from torch._check()",
                        context=str(message_vt),
                        explanation=(
                            "The second argument of torch._check() must be a function "
                            "defined within the torch.compile region "
                            "that does not reference a non-local variable."
                        ),
                        hints=[
                            "Make sure the message function is defined in the torch.compile region.",
                            "Remove any closure variables, e.g. "
                            "remove references to closure variable `x` in `lambda: f'{x} failed check'`",
                            *graph_break_hints.SUPPORTABLE,
                        ],
                    )
                message_eager = message_vt.get_function()

                message_graph_proxy = tx.output.register_static_attr_and_return_proxy(
                    "_check_message", message_eager
                )

            if predicate_vt.is_python_constant():
                self.value(predicate_vt.as_python_constant(), message_eager)
                return ConstantVariable.create(None)

            predicate_proxy = predicate_vt.as_proxy()

            proxy_args: tuple[Any, ...]
            if message_graph_proxy is None:
                proxy_args = (predicate_proxy,)
            else:
                proxy_args = (predicate_proxy, message_graph_proxy)

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    proxy_args,
                    {},
                ),
            )

        return handlers

    def call_function(
@@ -104,7 +104,7 @@ from .._dynamo.exc import ShortenTraceback, SkipFrame
from ..fx._lazy_graph_module import _use_lazy_graph_module
from ..fx.graph import _PyTreeCodeGen
from ..utils._triton import has_triton
from . import config, distributed_autotune, metrics
from . import config, metrics
from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration
from .debug import DebugContext
from .decomposition import select_decomp_table
@@ -1431,11 +1431,7 @@ class _InProcessFxCompile(FxCompile):
            # We are going to start code generating runtime asserts, so make sure
            # you don't start adding new ones in the lowering process
            graph.freeze_runtime_asserts()
            with (
                V.set_graph_handler(graph),
                V.set_extern_kernel_nodes([]),
                distributed_autotune.graph_context(),
            ):
            with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]):
                graph.run(*example_inputs)
                output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
                if graph.graph_outputs is not None:
@@ -447,14 +447,6 @@ use_experimental_benchmarker: bool = Config(
    justknob="pytorch/inductor:use_experimental_benchmarker",
)

# Enable distributed autotuning. When this is enabled we will distribute the
# autotuning across distributed ranks in the same process group - so instead of
# each rank autotuning every kernel, each only autotunes 1/world_size of the
# kernels and then shares the results.
distributed_max_autotune_gemm = (
    os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1"
)

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
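A sketch only (the flag name comes from the block above; the launcher command and script name are illustrative): the environment variable is read once when the config module is imported, so it is usually set before the job starts, or the config is patched in-process as the tests in this diff do.

```python
# Illustrative only:
#   TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM=1 torchrun --nproc-per-node=8 train.py
# or flip it in-process before compiling:
import torch._inductor.config as inductor_config

inductor_config.distributed_max_autotune_gemm = True
inductor_config.max_autotune_gemm = True  # distributed autotune piggybacks on max-autotune
```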
@@ -1,386 +0,0 @@
from __future__ import annotations

import contextlib
import dataclasses
from typing import Any, TYPE_CHECKING, Union
from unittest.mock import patch

import sympy

import torch._logging
import torch.distributed as dist
import torch.fx
from torch.utils._ordered_set import OrderedSet

from . import config, select_algorithm
from .ir import (
    Buffer,
    ChoiceCaller,
    Layout,
    MultiTemplateBuffer,
    OperationBuffer,
    ShapeAsConstantBuffer,
    StorageBox,
    TensorBox,
)
from .kernel_inputs import KernelInputs, MMKernelInputs
from .scheduler import SchedulerNode
from .virtualized import NullHandler, V


if TYPE_CHECKING:
    from collections.abc import Generator, Sequence


_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune"

_AUTOTUNE_PG: dist.ProcessGroup | None = None


@dataclasses.dataclass
class _DistributedAutotuneState:
    """
    State used to track autotuning during a graph_context()
    """

    # This is the next operator index. Used to figure out which rank should do
    # the autotuning.
    autotuned_index: int = 0

    # For debugging - used to make sure that we autotune the same number of
    # local operators that we expected to.
    autotuned_local_count: int = 0


@dataclasses.dataclass
class _DistributedAutotuneInfo:
    index: int
    local: bool


def get_autotune_pg() -> dist.ProcessGroup | None:
    if dist.is_available() and dist.is_initialized():
        global _AUTOTUNE_PG
        if _AUTOTUNE_PG is None:
            _AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag(
                pg_tag="pt2_distributed_autotune_pg"
            )
        return _AUTOTUNE_PG

    return None


def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None:
    """
    Finish the distributed autotuning by propagating the autotuning results
    between the ranks and then replacing the placeholder with the real Buffer.
    """
    assert config.distributed_max_autotune_gemm
    autotune_results = _autotune_local_nodes(scheduler)
    choices_by_index = _sync(autotune_results)
    _autotune_remote_nodes(scheduler, choices_by_index)


@contextlib.contextmanager
def graph_context() -> Generator[None, None, None]:
    """
    Wrapped around processing a graph, sets up figuring out which ranks tune
    which shapes.
    """
    assert not isinstance(
        V.get_distributed_autotune_state(check_poisoned=False),  # type: ignore[call-arg]
        _DistributedAutotuneState,
    )
    V.set_distributed_autotune_state(_DistributedAutotuneState())
    try:
        yield
    finally:
        V.set_distributed_autotune_state(NullHandler())


def maybe_autotune_remote(
    name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout
) -> TensorBox | ShapeAsConstantBuffer | None:
    """
    Used by an op (like `mm`) to determine if the op should be autotuned
    locally (returns None) or remotely (returns a placeholder Buffer).
    """
    if not config.distributed_max_autotune_gemm:
        return None

    if not (autotune_pg := get_autotune_pg()):
        return None

    if len(choices) <= 1:
        return None

    state = V.distributed_autotune_state
    index = state.autotuned_index
    state.autotuned_index += 1
    local = index % autotune_pg.size() == autotune_pg.rank()

    V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo(
        index, local
    )
    if local:
        state.autotuned_local_count += 1
        return None

    return torch._inductor.ir.TensorBox.create(
        _DistributedAutotuneBuffer(name, inputs, layout)
    )


class _DistributedAutotuneBuffer(MultiTemplateBuffer):
    """
    A MultiTemplateBuffer which represents a kernel being autotuned on a
    different rank. When `schedule` is called this will be replaced by the
    "real" buffer.
    """

    # Name of the kernel being autotuned.
    _kernel_name: str

    def __init__(
        self,
        kernel_name: str,
        inputs: list[Buffer],
        layout: Layout,
    ) -> None:
        super().__init__(
            layout,
            inputs,
            choice_timings_fn=self._dummy_choice_timings,
            unfiltered_choices=[],
            allowed_prologue_inps=OrderedSet({}),
        )

        self._kernel_name = kernel_name

    def _dummy_choice_timings(
        self, _hint_override: int | None
    ) -> dict[ChoiceCaller, float]:
        # This should never get called. It means that a remote autotune was
        # scheduled but never filled in.
        raise NotImplementedError

    def autotune(self, ser_choice: _SerializedChoice) -> TensorBox:
        """
        Given a _SerializedChoice (autotune results from another rank)
        compute the final TensorBox.
        """

        from .select_algorithm import autotune_select_algorithm

        with patch.object(V.graph, "scheduler", None):
            kernel_inputs = MMKernelInputs([*self.original_inputs])
            assert isinstance(self.layout, Layout)
            choice = ser_choice.get_choice(self.layout, kernel_inputs)
            buffer = autotune_select_algorithm(
                self._kernel_name,
                [choice],
                kernel_inputs.nodes(),
                self.layout,
            )
        assert isinstance(buffer, TensorBox)
        return buffer


# Can we make this async?
def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]:
    """
    Perform the all_gather to collect the autotune results from all the ranks.
    """

    autotune_pg = get_autotune_pg()
    assert autotune_pg

    # Perform allgather
    all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size()  # type: ignore[list-item]
    torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg)

    node_count = sum(len(x) for x in all_states)
    # It's faster to briefly lie about the type than to unzip the results and append.
    choices_by_index: list[_SerializedChoice] = [None] * node_count  # type: ignore[list-item]

    check_count = 0
    for i, other_results in enumerate(all_states):
        for choice in other_results:
            assert isinstance(choice, _SerializedChoice)
            assert choices_by_index[choice.index] is None
            choices_by_index[choice.index] = choice
            check_count += 1

    assert node_count == check_count, f"count mismatch: {node_count} != {check_count}"
    return choices_by_index


class _SerializedChoice:
    """
    This is a serializer for the autotune choice. KernelTemplateChoice can't
    be serialized directly (the template and inputs prevent this) so we need to
    serialize it by parts and reconstruct later on.
    """

    def __init__(self, index: int, choice: ChoiceCaller) -> None:
        self.index = index
        self.template_uid = _SerializedChoice._template_uid_from_choice(choice)
        self.kwargs = self._compute_kwargs(choice.description)

    def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None:
        """
        Deserialize the ChoiceCaller and return it.
        """

        template = self._template_from_uid()

        kwargs = {**self.kwargs}
        if "BLOCK_K" in kwargs:
            # TODO: Do we really need to externally compute this value? If it's
            # needed I'm surprised it's not just part of the original template
            # description.
            # This needs the actual 'k' to figure out the value.
            k = inputs.nodes()[0].get_size()[1]
            kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"]

        extra_kwargs: dict[str, Any] = {}
        from .kernel_template_choice import (
            DictKernelTemplateParams,
            KernelTemplateChoice,
        )

        params = DictKernelTemplateParams(kwargs)
        ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs)
        return ktc.choice

    @staticmethod
    def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]:
        """
        Given a template description turn it into input kwargs.
        """
        if not description:
            return {}

        # TODO: It seems like it would be better if the template could provide
        # this directly instead of having to parse a string.
        kwargs: dict[str, Union[int, str, bool]] = {}
        for cfg in description.split(","):
            key, val = cfg.split("=", 1)
            key, val = key.strip(), val.strip()
            if val == "True":
                kwargs[key] = True
            elif val == "False":
                kwargs[key] = False
            elif val.isdigit():
                kwargs[key] = int(val)
            else:
                assert val.startswith("'") and val.endswith("'")
                kwargs[key] = val[1:-1]
        return kwargs

    @staticmethod
    def _template_uid_from_choice(choice: ChoiceCaller) -> str:
        """
        Given a ChoiceCaller figure out which template represents it. This
        is reversed by _template_from_uid().
        """

        # We need a better way to do this - right now we need to add each
        # supported template directly.
        if isinstance(choice, select_algorithm.ExternKernelCaller):
            if choice.choice.name == "mm":
                return "torch._inductor.kernel.mm.aten_mm"
            else:
                raise RuntimeError(f"TODO: kernel {choice.choice.name!r}")
        elif isinstance(choice, select_algorithm.TritonTemplateCaller):
            return "torch._inductor.kernel.mm.mm_template"
        else:
            raise RuntimeError(f"TODO: {type(choice)}")

    def _template_from_uid(self) -> Any:
        """
        See _template_uid_from_choice().
        """
        parts = self.template_uid.split(".")
        obj = globals()[parts[0]]
        for k in parts[1:]:
            obj = getattr(obj, k)
        return obj


def _autotune_local_nodes(
    scheduler: torch._inductor.scheduler.Scheduler,
) -> list[_SerializedChoice]:
    """
    Go through the nodes in the scheduler and autotune the kernels which
    should be autotuned by this rank.
    """

    autotune_results: list[_SerializedChoice] = []

    for node in scheduler.nodes:
        if not isinstance(node, SchedulerNode):
            continue

        if (inner_node := node.node) is None:
            continue

        if isinstance(inner_node, _DistributedAutotuneBuffer):
            # This is marked for remote autotuning.
            continue

        if not isinstance(inner_node, MultiTemplateBuffer):
            continue

        if (origin_node := inner_node.origin_node) is None:
            continue

        if (meta := origin_node.meta) is None:
            continue

        info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY)
        if info is None:
            continue

        assert info.local

        # We force autotuning here
        # Still takes advantage of async precompile
        # We need all the configs before fusion
        min_choice, _ = inner_node.get_min_choice()

        choice = _SerializedChoice(info.index, min_choice)
        autotune_results.append(choice)

    state = V.distributed_autotune_state
    assert len(autotune_results) == state.autotuned_local_count, (
        f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})"
    )
    return autotune_results


def _autotune_remote_nodes(
    scheduler: torch._inductor.scheduler.Scheduler,
    choices_by_index: Sequence[_SerializedChoice],
) -> None:
    """
    Go through the nodes in the scheduler and autotune the nodes that were
    autotuned on remote ranks.
    """

    for i, node in enumerate(scheduler.nodes):
        if isinstance(node, SchedulerNode) and isinstance(
            (dist_node := node.node), _DistributedAutotuneBuffer
        ):
            assert dist_node.origin_node is not None
            info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY]
            out_tensorbox = dist_node.autotune(choices_by_index[info.index])

            out_storage = out_tensorbox.data
            assert isinstance(out_storage, StorageBox)
            out_buffer = out_storage.data
            assert isinstance(out_buffer, OperationBuffer)

            assert out_buffer.layout == dist_node.layout

            scheduler._replace_node(out_buffer, dist_node, i, node)
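For context on `_SerializedChoice._compute_kwargs` above: template descriptions are comma-separated `KEY=value` pairs with bool, int, or quoted-string values. A standalone sketch of that format (the exact keys below are illustrative, not taken from this diff):

```python
# Sketch only: parse a template description string the way _compute_kwargs does.
desc = "BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, EVEN_K=True, ACC_TYPE='tl.float32'"

kwargs = {}
for cfg in desc.split(","):
    key, val = (s.strip() for s in cfg.split("=", 1))
    if val in ("True", "False"):
        kwargs[key] = val == "True"
    elif val.isdigit():
        kwargs[key] = int(val)
    else:
        kwargs[key] = val.strip("'")

# kwargs == {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'EVEN_K': True, 'ACC_TYPE': 'tl.float32'}
```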
@@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.functional import ScalingType  # type: ignore[attr-defined]
from torch.torch_version import TorchVersion

from .. import config as inductor_config, distributed_autotune
from .. import config as inductor_config
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
@@ -1315,11 +1315,6 @@ def tuned_mm(mat1, mat2, out_dtype=None, *, layout=None):
        # The future will be awaited at scheduling time in select_algorithm.py
        best_config_future = gen_best_config(mat1, mat2)

    if box := distributed_autotune.maybe_autotune_remote(
        name, choices, kernel_inputs.nodes(), layout
    ):
        return box

    return autotune_select_algorithm(
        name,
        choices,
@@ -449,6 +449,7 @@ class SchedulerDonatedBuffer(SchedulerBuffer):

class BaseSchedulerNode:
    ancestors: OrderedSet[str]
    debug_device_str: Callable[[BaseSchedulerNode], list[str]]
    group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
    last_usage: OrderedSet[str]
    # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
@@ -460,26 +461,21 @@ class BaseSchedulerNode:
    max_order: int
    mpi_node: MemoryPlanningInfoForNode
    mutation_renames: dict[str, str]
    node: Optional[ir.Operation] = None
    node: Optional[ir.Operation]
    outputs: list[SchedulerBuffer]
    outputs_by_name: dict[str, SchedulerBuffer]
    override_estimated_runtime: Optional[float] = None
    read_writes: dependencies.ReadWrites
    unmet_dependencies: OrderedSet[Dep]
    written: bool = False

    def __init__(self, scheduler: Scheduler) -> None:
        self.scheduler: Scheduler = scheduler
        self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
            lambda *args, **kwargs: []
        )
        self.scheduler = scheduler
        self.debug_device_str = lambda *args, **kwargs: []

    def _init_from_node(self, node: ir.Operation) -> None:
        self.node = node
        self.ancestors = OrderedSet()
        self.last_usage = OrderedSet[
            str
        ]()  # buffers that won't be used after this kernel
        self.last_usage = OrderedSet()  # buffers that won't be used after this kernel
        self.written = False
        self.outputs = [
            SchedulerBuffer(
@@ -2647,12 +2643,6 @@ class Scheduler:
        if config._pre_fusion_custom_pass is not None:
            self.nodes = config._pre_fusion_custom_pass(self.nodes)

        if config.distributed_max_autotune_gemm:
            from . import distributed_autotune

            distributed_autotune.schedule(self)
        self.compute_ancestors()

        self.nodes = self.fuse_nodes(self.nodes)
        if config._post_fusion_custom_pass is not None:
            self.nodes = config._post_fusion_custom_pass(self.nodes)
@@ -3525,7 +3515,6 @@ class Scheduler:

        new_scheduler_node.min_order = node.min_order
        new_scheduler_node.max_order = node.max_order
        new_scheduler_node.ancestors = node.ancestors
        new_scheduler_node.last_usage = node.last_usage

    def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
@@ -86,8 +86,6 @@ if TYPE_CHECKING:
    from torch._inductor.loop_body import InterpreterShim
    from torch._subclasses import FakeTensorMode

    from .distributed_autotune import _DistributedAutotuneState

threadlocal = local()

T = TypeVar("T")
@@ -203,9 +201,6 @@ _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHand
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
    "local_buffer_context", NullHandler
)
_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized(
    "distributed_autotune_state", NullHandler
)


def _choices_default():
@@ -375,12 +370,6 @@ class _V:
    set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
    get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
    set_choices_handler: Callable[[Any], Any] = _choices._set_handler
    set_distributed_autotune_state: Callable[[Any], Any] = (
        _distributed_autotune_state._set_handler
    )
    get_distributed_autotune_state: Callable[[], Any] = (
        _distributed_autotune_state._get_handler
    )

    @property
    def ops(self) -> OpsHandler[Any]:
@@ -440,9 +429,5 @@ class _V:
    def choices(self) -> InductorChoices:
        return _choices._get_handler()

    @property
    def distributed_autotune_state(self):
        return _distributed_autotune_state._get_handler()


V = _V()
@@ -385,13 +385,7 @@ def handle_noncontiguous_outputs(input_tlist, output):


def _broadcast_shapes(*_shapes):
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        is_nested_int,
        size_hint,
    )

    backed_so = torch.fx.experimental._config.backed_size_oblivious
    from torch.fx.experimental.symbolic_shapes import guard_or_false, is_nested_int

    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x
@@ -424,22 +418,6 @@ def _broadcast_shapes(*_shapes):
            ):
                continue
            else:
                # When backed_size_oblivious is used, we specialize for broadcasting
                # only if it's the only way to compile the example input, i.e.:
                #   s0:1, s1:1 ==> assert s0 == s1; no specialization on ==1 or !=1,
                #                  the non-broadcast path is picked.
                #   s0:1, s1:4 ==> specialize s0 to be 1.
                #   s0:4, s1:1 ==> specialize s1 to be 1.
                if backed_so:
                    a = size_hint(shape[idx], allow_none=True)
                    b = size_hint(common_shape[idx], allow_none=True)
                    if a == 1 and b != 1:
                        torch._check(shape[idx] == 1)
                    if b == 1 and a != 1:
                        torch._check(common_shape[idx] == 1)
                if guard_or_false(shape[idx] == common_shape[idx]):
                    continue
@@ -2,8 +2,6 @@ import os
import sys
from typing import Optional

from torch.utils._config_module import Config, install_config_module


# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@@ -102,11 +100,7 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that do their own dtype checking.
skip_dtype_check_in_meta_registrations = False

# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
from torch.utils._config_module import install_config_module


install_config_module(sys.modules[__name__])
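A short sketch of how the profiler tests earlier in this diff flip this flag on through the fx experimental config (the corresponding torch._dynamo.config entry is marked deprecated above); the traced module here is just an illustration:

```python
import torch

# Patch the experimental flag for the duration of the block (or set the
# env var named in env_name_default above before starting the process).
with torch.fx.experimental._config.patch("enrich_profiler_metadata", True):
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    gm.recompile()  # registers fx metadata used for profiler stack-trace augmentation
```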
@@ -131,7 +131,6 @@ class PendingUnbackedSymbolNotFound(RuntimeError):
aten = torch._ops.ops.aten  # type: ignore[has-type]

__all__ = [
    "size_hint",
    "guard_or_false",
    "guard_or_true",
    "has_symbolic_sizes_strides",
@@ -256,17 +255,6 @@ def _nested_int_aware_sort(
    )


def size_hint(x: int | torch.SymInt, *, allow_none: bool = False) -> int | None:
    """Gets a size hint for a given expression from the underlying shapes we had.
    Does not introduce a guard, so only use this when you can guarantee that
    your code is still valid for arbitrary shapes (such as optimization decisions).
    """
    if isinstance(x, int):
        return x
    assert isinstance(x, torch.SymInt)
    return x.node.shape_env.size_hint(x.node.expr, allow_none=allow_none)


# Wrapper on lru_cache that reports statistics at process end
def lru_cache(
    maxsize: Optional[int],
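A hedged usage sketch for the `size_hint` helper above (it peeks at the example value of a SymInt without installing a guard, so it is only safe for heuristics such as optimization decisions; the tiling heuristic below is invented for illustration):

```python
import torch
from torch.fx.experimental.symbolic_shapes import size_hint  # exported in this diff

def pick_tile(dim) -> int:
    # None means the symbol has no hint (e.g. an unbacked SymInt).
    hint = size_hint(dim, allow_none=True)
    return 128 if hint is not None and hint >= 128 else 32
```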
@@ -20,7 +20,6 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
    _BoxedCodeGen,
    _custom_builtins,
@@ -859,15 +858,14 @@ class {module_name}(torch.nn.Module):
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        # Do not import anything inside recompile; it might slow down the
        # function and cause a perf regression. Import outside of the method instead.
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec

        from torch._dynamo import config as dynamo_config

        python_code = self._graph.python_code(
            root_module="self",
            record_func=fx_experimental_config.enrich_profiler_metadata,
            root_module="self", record_func=dynamo_config.enrich_profiler_metadata
        )
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
@@ -876,7 +874,7 @@ class {module_name}(torch.nn.Module):
        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}

        if fx_experimental_config.enrich_profiler_metadata:
        if dynamo_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation
            node_metadata: dict[int, dict[str, Any]] = {}
            for i, node in enumerate(self._graph.nodes):