Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 19:15:00 +08:00)

Compare commits: ciflow/tru... → mlazos/use... (8 commits)
| SHA1 |
|---|
| f1fc0f62e0 |
| 75e545f81d |
| 6e6378784f |
| abe6cbecca |
| 20e0aff9e5 |
| 9113637f2d |
| 075a01ec48 |
| f640d0c072 |
@@ -230,7 +230,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
         res = opt_fn(x)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 1)
-        self.assertEqual(cnts.op_count, 20)
+        self.assertExpectedInline(str(cnts.op_count), """9""")

     @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/118204
     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@@ -335,7 +335,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
         res = opt_fn(x)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 1)
-        self.assertEqual(cnts.op_count, 37)
+        self.assertExpectedInline(str(cnts.op_count), """15""")

     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_cuda_stream_compared_with_constant(self):
@@ -517,7 +517,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
         res = opt_fn(x, cur_stream, new_stream)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 1)
-        self.assertEqual(cnts.op_count, 27)
+        self.assertExpectedInline(str(cnts.op_count), """16""")

     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_cuda_event_method(self):
@@ -537,7 +537,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
             with torch.cuda.stream(new_stream):
                 x = torch.add(x, 4)

-            new_event = torch.cuda.Event()
+            new_event = torch.Event()
             new_event.record(new_stream)

             new_event.wait(cur_stream)
@@ -557,7 +557,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
         res = opt_fn(x)
         self.assertEqual(ref, res)
         self.assertEqual(cnts.frame_count, 1)
-        self.assertEqual(cnts.op_count, 27)
+        self.assertExpectedInline(str(cnts.op_count), """16""")

     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_cuda_device(self):
@@ -1540,7 +1540,7 @@ cannot resume from torch._dynamo.step_unsupported()

 Developer debug context:

-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0283.html
+For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0284.html

 from user code:
   File "test_error_messages.py", line N, in fn
@@ -1,12 +1,20 @@
 # Owner(s): ["module: dynamo"]
+import functools
 import unittest
+import weakref

 import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
+from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_utils import requires_cuda


+requires_multigpu = functools.partial(
+    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
+)
+
+
 class TestStreams(torch._dynamo.test_case.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -26,6 +34,144 @@ class TestStreams(torch._dynamo.test_case.TestCase):
         e = torch.Event()
         weakref.ref(e)

+    @requires_cuda
+    def test_stream_enter_exit(self):
+        def fn(x, y):
+            s2 = torch.Stream()
+            s1 = torch.Stream()
+            with s1:
+                z1 = torch.add(x, y)
+            with s2:
+                z = torch.add(x, y)
+                y = z + 2 + z1
+
+            return y
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
+        expected = fn(*inp)
+        fn_opt = torch.compile(fn, fullgraph=True)
+        actual = fn_opt(*inp)
+        self.assertEqual(expected, actual)
+
+    @requires_cuda
+    def test_stream_context_graph_break(self):
+        def fn(x, y):
+            s2 = torch.Stream()
+            s1 = torch.Stream()
+            with s1:
+                z1 = torch.add(x, y)
+            with s2:
+                z = torch.add(x, y)
+                y = z + 2 + z1
+                torch._dynamo.graph_break()
+                y = y + 1
+
+            return y
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
+        expected = fn(*inp)
+        fn_opt = torch.compile(fn)
+        actual = fn_opt(*inp)
+        self.assertEqual(expected, actual)
+
+    @requires_cuda
+    def test_stream_input(self):
+        def fn(x, y, s):
+            z = torch.add(x, y)
+            y = z + 2
+            return y, s
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda"))
+        expected = fn(*inp)
+        fn_opt = torch.compile(fn, fullgraph=True)
+        actual = fn_opt(*inp)
+        self.assertEqual(expected, actual)
+
+    @requires_cuda
+    def test_local_stream_return(self):
+        def fn(x, y):
+            s = torch.Stream()
+            z = torch.add(x, y)
+            y = z + 2
+            return y, s
+
+        inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
+        fn_opt = torch.compile(fn, fullgraph=True)
+        _, s0 = fn_opt(*inp)
+        _, s1 = fn_opt(*inp)
+        # Streams will be different values for each invocation
+        # so don't check for equality
+        self.assertIsInstance(s0, torch.Stream)
+        # Stream should be newly allocated on each call
+        self.assertNotEqual(s0, s1)
+
+    @requires_cuda
+    def test_get_current_stream_return(self):
+        def fn(x, s):
+            with s:
+                s0 = torch.accelerator.current_stream()
+            return x, s0
+
+        s_inp = torch.Stream(device="cuda")
+        inp = (torch.ones(2, 2) + 1, s_inp)
+        fn_opt = torch.compile(fn, fullgraph=True)
+        _, s0 = fn_opt(*inp)
+        _, s1 = fn_opt(*inp)
+        self.assertEqual(s_inp, s0)
+        self.assertEqual(s0, s1)
+
+    @requires_cuda
+    @requires_multigpu()
+    def test_get_current_stream_return_different_device(self):
+        def fn(x, s0, s1):
+            with s1:
+                with s0:
+                    s = torch.accelerator.current_stream(torch.device("cuda:1"))
+            return s
+
+        s0 = torch.Stream(device="cuda:0")
+        s1 = torch.Stream(device="cuda:1")
+        inp = (torch.ones(2, 2) + 1, s0, s1)
+        fn_opt = torch.compile(fn, fullgraph=True)
+        s_act = fn_opt(*inp)
+        s_exp = fn(*inp)
+        self.assertEqual(s_act, s_exp)
+
+    @requires_cuda
+    @requires_multigpu()
+    def test_get_current_stream_return_no_index(self):
+        def fn(x, s0, s1):
+            with s1:
+                with s0:
+                    s = torch.accelerator.current_stream(torch.device("cuda"))
+            return s
+
+        s0 = torch.Stream(device="cuda:0")
+        s1 = torch.Stream(device="cuda:1")
+        inp = (torch.ones(2, 2) + 1, s0, s1)
+        fn_opt = torch.compile(fn, fullgraph=True)
+        s_act = fn_opt(*inp)
+        s_exp = fn(*inp)
+        self.assertEqual(s_act, s_exp)
+
+    def test_nested_stream_enter_exit(self):
+        pass
+
+    def test_stream_enter_exit_graph_break(self):
+        pass
+
+    def test_nested_stream_enter_exit_graph_break(self):
+        pass
+
+    def test_local_stream_enter_exit(self):
+        pass
+
+    def test_local_stream_nested_enter_exit(self):
+        pass
+
+    def test_stream_with_mutation(self):
+        pass
+
     @requires_cuda
     def test_run_opcheck(self):
         from torch._dynamo.variables.streams import fork_stream, join_stream
@@ -466,6 +466,7 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase):
             "handle_cudnn_is_acceptable", # No global state
             "handle_assert", # No global state (constant)
             "handle_nested_tensor", # No global state
+            "handle_current_stream", # Safely implemented
         )
         for fn in handlers:
             if isinstance(fn, staticmethod) or inspect.ismethod(fn):
@@ -2,7 +2,7 @@
 import sys

 import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


 class DummyPrivateUse1Module:
@@ -60,6 +60,9 @@ class TestExtensionUtils(TestCase):
         with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
             torch._register_device_module("privateuseone", DummyPrivateUse1Module)

+    @skipIfTorchDynamo(
+        "accelerator doesn't compose with privateuse1 : https://github.com/pytorch/pytorch/issues/166696"
+    )
     def test_external_module_register_with_renamed_backend(self):
         torch.utils.rename_privateuse1_backend("foo")
         with self.assertRaisesRegex(RuntimeError, "has already been set"):
@@ -1,7 +1,7 @@
 # Owner(s): ["module: PrivateUse1"]

 import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


 class DummyPrivateUse1Module:
@@ -31,6 +31,9 @@ class DummyPrivateUse1Module:


 class TestRenamePrivateuseoneToExistingBackend(TestCase):
+    @skipIfTorchDynamo(
+        "TorchDynamo exposes https://github.com/pytorch/pytorch/issues/166696"
+    )
     def test_external_module_register_with_existing_backend(self):
         torch.utils.rename_privateuse1_backend("maia")
         with self.assertRaisesRegex(RuntimeError, "has already been set"):
@@ -155,7 +155,6 @@ def reset() -> None:
     GenerationTracker.clear()
     TensorifyState.clear()
     torch._dynamo.utils.warn_once_cache.clear()
-    torch._dynamo.utils.user_obj_id_to_weakref.clear()
     torch._C._autograd._saved_tensors_hooks_set_tracing(False)
@@ -2495,6 +2495,30 @@
     }
   ],
   "GB0249": [
     {
       "Gb_type": "bad device argument to torch.accelerator.current_stream",
       "Context": "args={args}, kwargs={kwargs}",
       "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
       "Hints": [
         "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
       ]
     },
     {
       "Gb_type": "bad device argument to torch.get_device_module",
       "Context": "args={args}, kwargs={kwargs}",
       "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
       "Hints": [
         "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
       ]
     },
     {
       "Gb_type": "bad device argument to torch.accelerator.current_stream",
       "Context": "args={args}, kwargs={kwargs}",
       "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
       "Hints": [
         "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
       ]
     },
     {
       "Gb_type": "bad device argument to torch.get_device_module",
       "Context": "args={args}, kwargs={kwargs}",
@@ -2853,6 +2877,14 @@
     }
   ],
+  "GB0283": [
+    {
+      "Gb_type": "Failed to make weakref to graph-created external object",
+      "Context": "user_object: {example_value}",
+      "Explanation": "Object does not allow us to make a weakref to it",
+      "Hints": []
+    }
+  ],
   "GB0284": [
     {
       "Gb_type": "cannot resume from torch._dynamo.step_unsupported()",
       "Context": "",
@@ -2863,5 +2895,25 @@
       "This is likely to be a Dynamo bug. Please report an issue to PyTorch."
     ]
   }
 ],
+  "GB0285": [
+    {
+      "Gb_type": "unsupported arguments to torch.accelerator.current_stream",
+      "Context": "args={args}, kwargs={kwargs}",
+      "Explanation": "torch.accelerator.current_stream accepts one optional argument `device`",
+      "Hints": [
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ],
+  "GB0286": [
+    {
+      "Gb_type": "bad device argument to torch.get_device_module",
+      "Context": "args={args}, kwargs={kwargs}",
+      "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
+      "Hints": [
+        "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled."
+      ]
+    }
+  ]
 }
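Each registry entry above corresponds to an unimplemented_v2() call site in Dynamo. For example, GB0285 is produced by a call shaped like the sketch below (mirroring the handler added later in this PR; the wrapper function name is illustrative, not a real Dynamo function):

    from torch._dynamo import graph_break_hints
    from torch._dynamo.exc import unimplemented_v2

    def _reject_bad_args(args, kwargs):  # hypothetical wrapper for illustration
        # Raises the graph break recorded as GB0285 above.
        unimplemented_v2(
            gb_type="unsupported arguments to torch.accelerator.current_stream",
            context=f"args={args}, kwargs={kwargs}",
            explanation="torch.accelerator.current_stream accepts one optional argument `device`",
            hints=[*graph_break_hints.USER_ERROR],
        )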
@@ -1,9 +1,11 @@
 import weakref
-from typing import Any
+from typing import Any, Callable

 from torch._dynamo.source import Source


+PyCodegen = Any
+
 # This file is to handle types that we don't want to support
 # as explicit FX graph inputs. This uses a sidetable which
 # we populate in bytecode and is loaded during graph execution
@@ -11,44 +13,70 @@ from torch._dynamo.source import Source
 # We use a dynamo-generated index as a level of indirection
 # this allows us to register objects externally in pre-graph bytecode that we want
 # to pass to the graph, but not support their types as graph inputs
-index_to_source: dict[int, Source] = {}
+index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}

-index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
+index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}

+keep_alive: list[Any] = []


 def has_user_objects() -> bool:
-    return bool(index_to_source)
+    return bool(index_to_bytecode_constructor)


-def get_user_object_by_index(index: int) -> Any:
-    assert index in index_to_user_object_weakref, (
+def get_external_object_by_index(index: int) -> Any:
+    assert index in index_to_external_object_weakref, (
         "Index not registered in index_to_user_object_weakref"
     )
-    obj = index_to_user_object_weakref[index]()
+    obj = index_to_external_object_weakref[index]()
     assert obj is not None, "User object is no longer alive"
-    return index_to_user_object_weakref[index]()
+    return index_to_external_object_weakref[index]()


 def store_user_object_weakrefs(*args: Any) -> None:
-    global index_to_user_object_weakref
-    index_to_user_object_weakref.clear()
-    index_to_user_object_weakref.update(
+    global index_to_external_object_weakref
+    index_to_external_object_weakref.clear()
+    index_to_external_object_weakref.update(
         {i: weakref.ref(arg) for i, arg in enumerate(args)}
     )


 def reset_user_object_tracking() -> None:
-    index_to_source.clear()
-    index_to_user_object_weakref.clear()
+    index_to_bytecode_constructor.clear()
+    index_to_external_object_weakref.clear()
+    keep_alive.clear()


+def register_graph_created_object(
+    example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
+) -> int:
+    global index_to_bytecode_constructor
+    global keep_alive
+    keep_alive.append(example_value)
+    index = len(index_to_bytecode_constructor)
+    index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
+    try:
+        index_to_external_object_weakref[index] = weakref.ref(example_value)
+    except TypeError as e:
+        from .exc import unimplemented_v2
+
+        unimplemented_v2(
+            gb_type="Failed to make weakref to graph-created external object",
+            context=f"user_object: {example_value}",
+            explanation="Object does not allow us to make a weakref to it",
+            hints=[],
+            from_exc=e,
+        )
+    return index
+
+
 # Register a user object to be used in the graph
 def register_user_object(value: Any, source: Source) -> int:
-    global index_to_source
-    index = len(index_to_source)
-    index_to_source[index] = source
+    global index_to_bytecode_constructor
+    index = len(index_to_bytecode_constructor)
+    index_to_bytecode_constructor[index] = lambda cg: cg(source)
     try:
-        index_to_user_object_weakref[index] = weakref.ref(value)
+        index_to_external_object_weakref[index] = weakref.ref(value)
     except TypeError as e:
         from .exc import unimplemented_v2
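For orientation, a minimal self-contained sketch of the sidetable indirection this file implements, using only the standard library; the names here are illustrative stand-ins, not the Dynamo API:

    import weakref
    from typing import Any

    _index_to_weakref: dict[int, weakref.ReferenceType] = {}

    def register(obj: Any) -> int:
        # hand out a small integer index instead of making obj a graph input
        index = len(_index_to_weakref)
        _index_to_weakref[index] = weakref.ref(obj)
        return index

    def get_by_index(index: int) -> Any:
        obj = _index_to_weakref[index]()
        assert obj is not None, "object is no longer alive"
        return obj

    class Opaque:  # placeholder object; torch.Stream plays this role in the PR
        pass

    o = Opaque()
    assert get_by_index(register(o)) is o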
@@ -132,6 +132,7 @@ from .source import (
     CodeSource,
     ConstantSource,
     ConstDictKeySource,
+    CurrentStreamSource,
     DataclassFieldsSource,
     DefaultsSource,
     DictGetItemSource,
@@ -181,6 +182,7 @@ from .utils import (
     common_constant_types,
     dataclass_fields,
     dict_keys,
+    get_current_stream,
     get_custom_getattr,
     get_torch_function_mode_stack,
     get_torch_function_mode_stack_at,
@@ -759,6 +761,7 @@ def _get_closure_vars() -> dict[str, object]:
         "___dataclass_fields": dataclass_fields,
         "___namedtuple_fields": lambda x: x._fields,
         "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
+        "___get_current_stream": get_current_stream,
         "__math_isnan": math.isnan,
         "__numpy_isnan": None if np is None else np.isnan,
         "inf": float("inf"),
@@ -1448,6 +1451,13 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
+        elif istype(source, CurrentStreamSource):
+            out = root_guard_manager.lambda_manager(
+                python_lambda=lambda _: get_current_stream(source.device),
+                source=source_name,
+                example_value=example_value,
+                guard_manager_enum=guard_manager_enum,
+            )
         elif istype(source, GradSource):
             assert base_guard_manager # to make mypy happy
             out = base_guard_manager.grad_manager(
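Conceptually, the new guard branch re-queries the current stream for a fixed device each time the guard runs; a sketch of the equivalent check (an assumption about lambda_manager's semantics: the real machinery compares the lambda's result against example_value inside the C++ guard tree):

    import torch

    def make_current_stream_check(device: torch.device, example_stream):
        # mirrors python_lambda=lambda _: get_current_stream(source.device)
        def check(_ignored=None) -> bool:
            return torch.accelerator.current_stream(device) == example_stream
        return check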
@@ -101,7 +101,7 @@ from .exc import (
     unimplemented_v2,
     unimplemented_v2_with_warning,
 )
-from .graph_bytecode_inputs import has_user_objects, index_to_source
+from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
 from .graph_deduplication import apply_graph_deduplication
 from .graph_region_tracker import GraphRegionTracker
 from .guards import GuardBuilder, install_guard
@@ -1541,9 +1541,19 @@ class OutputGraph(OutputGraphCommon):
                     "store_user_object_weakrefs",
                 )
             )
-            for source in reversed(index_to_source.values()):
-                codegen(source)
-            codegen.call_function(len(index_to_source), False)
+            tmp_vars = []
+            for constructor in reversed(index_to_bytecode_constructor.values()):
+                constructor(codegen)
+                var_name = (
+                    self.new_var()
+                ) # keep alive any temp objects for the rest of the frame
+                codegen.store(var_name)
+                tmp_vars.append(var_name)
+
+            for var_name in tmp_vars:
+                codegen.append_output(codegen.create_load(var_name))
+
+            codegen.call_function(len(index_to_bytecode_constructor), False)
             codegen.pop_top()
             self.add_output_instructions(codegen.get_instructions())
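Schematically, the bytecode emitted above behaves like this plain-Python toy (illustrative names only; the real code emits LOAD/STORE instructions through PyCodegen, and the temps exist to keep the constructed objects alive for the rest of the frame):

    stored: list = []

    def store_user_object_weakrefs(*objs):
        stored[:] = objs  # stand-in for the real sidetable update

    # hypothetical registered constructors, keyed by index
    index_to_bytecode_constructor = {0: lambda: object(), 1: lambda: object()}

    tmp_vars = []
    for constructor in reversed(list(index_to_bytecode_constructor.values())):
        tmp_vars.append(constructor())  # each temp stays referenced for the frame

    store_user_object_weakrefs(*tmp_vars)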
@@ -23,6 +23,7 @@ import functools
 from collections.abc import Callable
 from typing import Any, Optional, TYPE_CHECKING, Union

+from torch import device as device_type
 from torch._guards import ChainedSource, Guard, GuardSource, Source

 from . import utils
@@ -1082,6 +1083,30 @@ class ShapeEnvSource(Source):
         return GuardSource.SHAPE_ENV


+@dataclasses.dataclass(frozen=True)
+class CurrentStreamSource(Source):
+    device: device_type
+
+    def name(self) -> str:
+        return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))"
+
+    def reconstruct(self, codegen: "PyCodegen") -> None:
+        num_args = 1
+        codegen.add_push_null(
+            lambda: codegen.load_import_from(utils.__name__, "get_current_stream")
+        )
+        codegen.add_push_null(lambda: codegen.load_import_from("torch", "device"))
+        codegen.extend_output([codegen.create_load_const(self.device.type)])
+        if self.device.index is not None:
+            num_args += 1
+            codegen.extend_output([codegen.create_load_const(self.device.index)])
+        codegen.extend_output(create_call_function(num_args, False))
+        codegen.extend_output(create_call_function(1, False))
+
+    def guard_source(self) -> GuardSource:
+        return GuardSource.GLOBAL
+
+
 @dataclasses.dataclass(frozen=True)
 class BackwardStateSource(Source):
     def name(self) -> str:
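The bytecode emitted by reconstruct() above evaluates, roughly, to the call below; a sketch guarded behind the availability check, since it needs an accelerator backend at runtime:

    import torch
    from torch._dynamo.utils import get_current_stream  # added by this PR

    if torch.accelerator.is_available():
        dev = torch.accelerator.current_stream().device
        # matches the name() string: ___get_current_stream(torch.device(...))
        s = get_current_stream(torch.device(dev.type, dev.index))
        assert isinstance(s, torch.Stream)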
@@ -171,6 +171,7 @@ from .variables.misc import (
     UnknownVariable,
 )
 from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
+from .variables.streams import SymbolicStreamState
 from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
 from .variables.torch_function import (
     SymbolicTorchFunctionState,
@@ -1104,6 +1105,7 @@ class InstructionTranslatorBase(
     symbolic_locals: dict[str, VariableTracker]
     symbolic_globals: dict[str, VariableTracker]
     symbolic_torch_function_state: SymbolicTorchFunctionState
+    symbolic_stream_state: SymbolicStreamState
     post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
     stack: list[VariableTracker]
     instruction_pointer: Optional[int]
@@ -4243,6 +4245,7 @@ class InstructionTranslatorBase(
         symbolic_locals: dict[str, VariableTracker],
         symbolic_globals: dict[str, VariableTracker],
         symbolic_torch_function_state: SymbolicTorchFunctionState,
+        symbolic_stream_state: SymbolicStreamState,
         f_code: types.CodeType,
         export: bool,
         inline_depth: int,
@@ -4262,6 +4265,7 @@ class InstructionTranslatorBase(
         self.symbolic_locals = symbolic_locals
         self.symbolic_globals = symbolic_globals
         self.symbolic_torch_function_state = symbolic_torch_function_state
+        self.symbolic_stream_state = symbolic_stream_state
         # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
         # in order to generate any nested closures
         self.post_prune_cell_and_freevars = None
@@ -4416,6 +4420,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             # A global var is inserted only after a STORE_GLOBAL happens to it
             symbolic_globals={},
             symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
+            symbolic_stream_state=None, # type: ignore[arg-type] # set below
             f_code=f_code,
             export=export,
             inline_depth=0,
@@ -4520,6 +4525,8 @@ class InstructionTranslator(InstructionTranslatorBase):
                 torch_function_mode_stack
             )

+            self.symbolic_stream_state = SymbolicStreamState()
+
             if export:
                 # export gets confused if we never realize unused inputs
                 # in export mode just eagerly realize everything
@@ -4846,6 +4853,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                 sub_locals,
                 parent.symbolic_globals,
                 parent.symbolic_torch_function_state,
+                parent.symbolic_stream_state,
                 func,
             )
         else:
@@ -4857,6 +4865,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                 sub_locals,
                 parent.symbolic_globals,
                 parent.symbolic_torch_function_state,
+                parent.symbolic_stream_state,
                 # pyrefly: ignore [bad-argument-type]
                 func,
             )
@@ -4940,6 +4949,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         symbolic_locals: dict[str, VariableTracker],
         symbolic_globals: dict[str, VariableTracker],
         symbolic_torch_function_state: SymbolicTorchFunctionState,
+        symbolic_stream_state: SymbolicStreamState,
         funcvar: BaseUserFunctionVariable,
     ) -> None:
         f_globals = funcvar.get_globals() # type: ignore[attr-defined]
@@ -4973,6 +4983,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
             symbolic_locals=symbolic_locals,
             symbolic_globals=symbolic_globals,
             symbolic_torch_function_state=symbolic_torch_function_state,
+            symbolic_stream_state=symbolic_stream_state,
             instructions=instructions,
             code_options={k: getattr(code, k) for k in get_code_keys()},
             f_code=code,
@@ -4695,6 +4695,10 @@ def clear_torch_function_mode_stack() -> None:
         _pop_torch_function_stack()


+def get_current_stream(device: torch.device) -> torch.Stream:
+    return torch.accelerator.current_stream(device)
+
+
 # call from C dynamo in order to inspect values in pdb
 def _breakpoint_for_c_dynamo(*args: Any) -> None:
     breakpoint()
@@ -4759,33 +4763,8 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
     return tensor_dict


-# This is useful for reconstructing within the Dynamo graph the non-graph-input objects
-# whose lifetime is governed by the user.
-# e.g. torch.cuda.Event is a prime example.
-user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {}
-
-
-# TODO: mlazos to remove after replacing w/ above API
-def get_user_object_from_id(obj_id: int) -> Any:
-    obj = user_obj_id_to_weakref[obj_id]()
-    assert obj is not None, "User object is no longer alive"
-    return obj
-
-
-def store_user_object_weakref(obj: object) -> None:
-    obj_id = id(obj)
-    try:
-        user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
-    except TypeError as e:
-        from .exc import unimplemented_v2
-
-        unimplemented_v2(
-            gb_type="Failed to make weakref to User Object when storing by ID",
-            context=f"user_objected: {obj}",
-            explanation="Object does not allow us to make a weakref to it",
-            hints=[],
-            from_exc=e,
-        )
+def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream:
+    return torch._C.Stream(*args, **kwargs)


 class CompileTimeInstructionCounter:
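build_stream gives generated bytecode a single importable entry point for re-creating a stream from saved args/kwargs. A usage sketch, assuming an accelerator build and that torch.Stream accepts a device keyword:

    import torch
    from torch._dynamo.utils import build_stream  # added by this PR

    if torch.accelerator.is_available():
        s = build_stream((), {"device": torch.accelerator.current_stream().device})
        assert isinstance(s, torch.Stream)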
@@ -46,7 +46,7 @@ import torch
 from torch import SymInt
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.graph_bytecode_inputs import (
-    get_user_object_by_index,
+    get_external_object_by_index,
     register_user_object,
 )
 from torch._dynamo.utils import (
@@ -1057,13 +1057,12 @@ class VariableBuilder:
             self.install_guards(GuardBuilder.TYPE_MATCH)
             index = register_user_object(value, self.source)
             stream_proxy = self.tx.output.create_proxy(
-                "call_function", get_user_object_by_index, (index,), {}
+                "call_function", get_external_object_by_index, (index,), {}
             )
             set_example_value(stream_proxy.node, value)
             var = StreamVariable(
                 stream_proxy,
                 value,
-                value.device,
                 source=self.source,
             )
             return self.tx.output.side_effects.track_object_existing(value, var)
@@ -1078,7 +1077,7 @@ class VariableBuilder:
             index = register_user_object(value, self.source)
             event_proxy = self.tx.output.create_proxy(
                 "call_function",
-                get_user_object_by_index,
+                get_external_object_by_index,
                 (index,),
                 {},
             )
@@ -3006,14 +3005,15 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         set_example_value(proxy.node, example_value)
         return SymNodeVariable(proxy, example_value, **options)
     elif (
-        inspect.isclass(proxy.node.target)
-        and issubclass(proxy.node.target, torch.Stream)
+        isinstance(example_value, torch.Stream)
+        and proxy.node.target
+        in (get_external_object_by_index, torch.accelerator.current_stream)
     ) or proxy.node.target in [
         device_interface.current_stream
         for _, device_interface in get_registered_device_interfaces()
     ]:
         set_example_value(proxy.node, example_value)
-        return StreamVariable(proxy, example_value, example_value.device, **options)
+        return StreamVariable(proxy, example_value, **options)
     elif (
         inspect.isclass(proxy.node.target)
         and issubclass(proxy.node.target, torch.Event)
@@ -1,15 +1,18 @@
-from typing import Any, Optional
+import collections
+from typing import Any, Callable, Optional

 import torch
+from torch._dynamo.variables.dicts import ConstDictVariable
+from torch._dynamo.variables.lists import TupleVariable
 from torch.fx import Proxy

 from .. import graph_break_hints
-from ..device_interface import get_interface_for_device
+from ..bytecode_transformation import create_call_function
 from ..exc import TYPE_CHECKING, unimplemented_v2
 from .base import VariableTracker
 from .constant import ConstantVariable
-from .ctx_manager import ContextWrappingVariable
-from .misc import GetAttrVariable
+from .ctx_manager import FxTracebackAnnotateVariable
+from .lazy import LazyVariableTracker


 if TYPE_CHECKING:
@@ -63,104 +66,89 @@ def _(
     pass


-class StreamContextVariable(ContextWrappingVariable):
+class SymbolicStreamState:
+    """Track the currently entered stream if any"""
+
+    def __init__(self) -> None:
+        from ..source import CurrentStreamSource
+
+        cur_stack: list[StreamVariable] = []
+        if torch.accelerator.is_available():
+            stream_var = LazyVariableTracker.create(
+                torch.accelerator.current_stream(),
+                source=CurrentStreamSource(torch.accelerator.current_stream().device),
+            )
+            cur_stack = [stream_var] # type: ignore[list-item]
+
+        self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque(
+            cur_stack
+        )
+
+    def enter_stream(self, stream: "StreamVariable") -> None:
+        self.cur_stream_stack.append(stream)
+
+    def exit_stream(self) -> None:
+        self.cur_stream_stack.pop()
+
+    def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable":
+        if device is not None:
+            for stream in reversed(self.cur_stream_stack):
+                if stream.device == device:
+                    return stream
+
+        return self.cur_stream_stack[-1]
+
+    def in_stream_context(self) -> bool:
+        return len(self.cur_stream_stack) > 0
+
+
+class StreamContextVariable(FxTracebackAnnotateVariable):
     """This represents torch.cuda.StreamContext"""

     @staticmethod
     def create(
         tx: "InstructionTranslator",
-        target_value: "StreamVariable",
+        stream_to_enter: "StreamVariable",
         **kwargs: dict[str, Any],
     ) -> "StreamContextVariable":
         return StreamContextVariable(
-            target_values=[target_value],
-            initial_values=[
-                StreamContextVariable._get_current_stream(target_value.device, tx)
-            ],
-            device=target_value.device,
+            stream_to_enter,
             **kwargs,
         )

     def __init__(
         self,
-        target_values: list["StreamVariable"],
-        device: torch.device,
-        initial_values: Optional[list["StreamVariable"]] = None,
+        stream: Optional["StreamVariable"],
         **kwargs: dict[str, Any],
     ) -> None:
+        self.stream = stream
         super().__init__(
-            target_values=target_values, initial_values=initial_values, **kwargs
+            target_values={"stream": self.get_stream().user_object_index},
+            initial_values=None,
+            **kwargs,
         )
-        # pyrefly: ignore [read-only]
-        self.device = device

-    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
+    def enter(
+        self, tx: "InstructionTranslator", *args: tuple[Any]
+    ) -> "VariableTracker":
-        # to stream, from stream is the order of the arguments
-        # we are entering the target, and leaving the initial stream
-        tx.output.create_proxy(
-            "call_function",
-            torch.ops.streams.fork.default,
-            self._target_stream_proxies() + self._initial_stream_proxies(),
-            {},
-        )
-        return ConstantVariable.create(None)
+        tx.symbolic_stream_state.enter_stream(self.get_stream())
+        return super().enter(tx)

     def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
-        # to stream, from stream is the order of the arguments
-        # we are leaving the target, and entering the initial stream
-        tx.output.create_proxy(
-            "call_function",
-            torch.ops.streams.join.default,
-            self._initial_stream_proxies() + self._target_stream_proxies(),
-            {},
-        )
-        return ConstantVariable.create(None)
-
-    def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
-        assert self.initial_values, "No initial stream to move from"
-        return StreamContextVariable._extract_stream_properties(
-            self.initial_values[0].as_proxy()
-        )
-
-    def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
-        return StreamContextVariable._extract_stream_properties(
-            self._get_target_values()[0].as_proxy()
-        )
-
-    @staticmethod
-    def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
-        stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
-        stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
-        return stream_index, stream_device
-
-    @staticmethod
-    def _get_current_stream(
-        device: torch.device, tx: "InstructionTranslator"
-    ) -> "StreamVariable":
-        from .builder import wrap_fx_proxy_cls
-
-        current_stream_method = get_interface_for_device(device).current_stream
-        current_stream = wrap_fx_proxy_cls(
-            StreamVariable,
-            tx,
-            tx.output.create_proxy(
-                "call_function",
-                current_stream_method,
-                (None,),
-                {},
-            ),
-        )
-        return current_stream
-
-    def _get_target_values(self) -> list["StreamVariable"]:
-        # We need this to be overridable, since StreamVariable does
-        # not store target values (it does not require any arguments)
-        # and captures the current stream at the time of entering the context
-        return self.target_values
+        tx.symbolic_stream_state.exit_stream()
+        return super().exit(tx, *args)
+
+    def supports_graph_breaks(self) -> bool:
+        return True
+
+    def get_stream(self) -> "StreamVariable":
+        assert self.stream, "Stream context should have a separate stream"
+        return self.stream


 class StreamVariable(StreamContextVariable):
     """Represents the device-agnostic torch.Stream class"""
@@ -169,19 +157,21 @@ class StreamVariable(StreamContextVariable):
         self,
         proxy: Proxy,
         value: torch.Stream,
-        device: torch.device,
         **kwargs: Any,
     ) -> None:
+        # Index into the user object table
+        # used to pass arbitrary objects to the graph
+        user_object_index = kwargs.pop("user_obj_index", None)
         if proxy is not None and "example_value" in proxy.node.meta:
             assert proxy.node.meta["example_value"] == value
-        assert value.device.type == device.type, (
-            "stream value is not equal to the passed device"
-        )
-        super().__init__(target_values=[], initial_values=None, device=device, **kwargs)

         self.proxy = proxy
         self.value = value
         # pyrefly: ignore [read-only]
-        self.device = device
+        self.device = value.device
+        # pyrefly: ignore [read-only]
+        self.user_object_index = user_object_index
+        super().__init__(None, **kwargs)

     def python_type(self) -> type:
         return torch.Stream
@@ -240,15 +230,6 @@ class StreamVariable(StreamContextVariable):

         return super().call_method(tx, name, args, kwargs)

-    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
-        # NB: Set initial values when we enter
-        # Don't do this at object creation, as we need to record the current stream
-        # at the time the context is entered.
-        self.initial_values = [
-            StreamContextVariable._get_current_stream(self.device, tx)
-        ]
-        return super().enter(tx)
-
     def as_proxy(self) -> Proxy:
         return self.proxy

@@ -262,18 +243,39 @@ class StreamVariable(StreamContextVariable):
         # If we got here, this stream is fully subsumed by the graph - this means it is
         # not an input or global
         assert not self.source
-        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
-        # is fine and sound according to dynamo principles of treating collectives. However,
-        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
-        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
-        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
-        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
-        prefix = f"_stream_{self.device}"
-        name = codegen.tx.output.install_global_by_id(prefix, self.value)
-        codegen.append_output(codegen.create_load_global(name, add=True))
+        if self.user_object_index is not None:
+            codegen.add_push_null(
+                lambda: codegen.load_import_from(
+                    torch._dynamo.graph_bytecode_inputs.__name__,
+                    "get_external_object_by_index",
+                )
+            )
+            codegen.append_output(codegen.create_load_const(self.user_object_index))
+            codegen.extend_output(create_call_function(1, False))
+        else:
+            # TODO mlazos: evaluate if we still need this
+            prefix = f"_stream_{self.device}"
+            name = codegen.tx.output.install_global_by_id(prefix, self.value)
+            codegen.append_output(codegen.create_load_global(name, add=True))

-    def _get_target_values(self) -> list["StreamVariable"]:
-        return [self]
+    def get_stream(self) -> "StreamVariable":
+        return self
+
+    @staticmethod
+    def make_construct_in_graph_stream_fn(
+        args: TupleVariable, kwargs: ConstDictVariable
+    ) -> Callable[[int, "PyCodegen"], None]:
+        def fn(index: int, codegen: "PyCodegen") -> None:
+            codegen.add_push_null(
+                lambda: codegen.load_import_from(
+                    torch._dynamo.utils.__name__, "build_stream"
+                )
+            )
+            codegen(args)
+            codegen(kwargs)
+            codegen.extend_output(create_call_function(2, False))
+
+        return fn


 class EventVariable(VariableTracker):
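A toy model (plain Python, no Dynamo types) of the stack discipline SymbolicStreamState implements: enter pushes, exit pops, and a device-specific query walks the stack from the top for the nearest match, falling back to the innermost stream:

    import collections

    class ToyStreamStack:
        def __init__(self):
            self.stack = collections.deque()

        def enter(self, stream) -> None:  # stream objects expose .device
            self.stack.append(stream)

        def exit(self) -> None:
            self.stack.pop()

        def current(self, device=None):
            if device is not None:
                for s in reversed(self.stack):
                    if s.device == device:
                        return s
            return self.stack[-1]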
@@ -1268,6 +1268,35 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             # pyrefly: ignore [unbound-name]
             return VariableTracker.build(tx, module, new_source)

+        @register(torch.accelerator.current_stream)
+        def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs):
+            if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs):
+                unimplemented_v2(
+                    gb_type="unsupported arguments to torch.accelerator.current_stream",
+                    context=f"args={args}, kwargs={kwargs}",
+                    explanation="torch.accelerator.current_stream accepts one optional argument `device`",
+                    hints=[
+                        *graph_break_hints.USER_ERROR,
+                    ],
+                )
+            try:
+                if kwargs:
+                    device = torch.device(kwargs["device"].as_python_constant())
+                elif args:
+                    device = torch.device(args[0].as_python_constant())
+                else:
+                    device = None
+
+                return tx.symbolic_stream_state.cur_stream(device)
+            except Exception as e:
+                unimplemented_v2(
+                    gb_type="bad device argument to torch.accelerator.current_stream",
+                    context=f"args={args}, kwargs={kwargs}",
+                    explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)",
+                    hints=[*graph_break_hints.USER_ERROR],
+                    from_exc=e,
+                )
+
         @register(torch.set_default_device)
         def handle_set_default_device(
             self, tx: "InstructionTranslator", *args, **kwargs
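The handler mirrors eager semantics: at most one optional device argument, and an invalid device raises. An eager illustration (guarded, since it needs an accelerator backend):

    import torch

    if torch.accelerator.is_available():
        s0 = torch.accelerator.current_stream()           # zero-argument form
        s1 = torch.accelerator.current_stream(s0.device)  # one optional `device`
        assert s0 == s1
        try:
            # invalid device string; under torch.compile this is graph break GB0249
            torch.accelerator.current_stream("not-a-device")
        except RuntimeError:
            pass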
@@ -58,6 +58,7 @@ from ..exc import (
     raise_observed_exception,
     unimplemented_v2,
 )
+from ..graph_bytecode_inputs import get_external_object_by_index
 from ..guards import GuardBuilder, install_guard
 from ..source import (
     AttrSource,
@@ -94,7 +95,7 @@ from ..utils import (
     unpatched_nn_module_getattr,
 )
 from .base import raise_type_error_exc, ValueMutationNew, VariableTracker
-from .dicts import DefaultDictVariable
+from .dicts import ConstDictVariable, DefaultDictVariable
 from .lists import SizeVariable


@@ -809,14 +810,44 @@ class UserDefinedClassVariable(UserDefinedVariable):
                 )
                 args = [stacked]

-            tensor_variable = wrap_fx_proxy(
-                tx=tx,
-                proxy=tx.output.create_proxy(
-                    "call_function",
-                    self.value,
-                    *proxy_args_kwargs(args, kwargs),
-                ),
-            )
+            if issubclass(self.value, torch.Stream):
+                from .constant import ConstantVariable
+                from .lists import TupleVariable
+
+                # Register newly created stream for reconstruction
+                var_kwargs = ConstDictVariable(
+                    {ConstantVariable(k): v for k, v in kwargs.items()}
+                )
+                var_args = TupleVariable(list(args))
+                stream = self.value(
+                    *(var_args.as_python_constant()),
+                    **(var_kwargs.as_python_constant()),
+                )
+                from ..graph_bytecode_inputs import register_graph_created_object
+                from .streams import StreamVariable
+
+                ind = register_graph_created_object(
+                    stream,
+                    StreamVariable.make_construct_in_graph_stream_fn(
+                        var_args, var_kwargs
+                    ),
+                )
+                tensor_variable = wrap_fx_proxy(
+                    tx=tx,
+                    proxy=tx.output.create_proxy(
+                        "call_function", get_external_object_by_index, (ind,), {}
+                    ),
+                    user_obj_index=ind,
+                )
+            else:
+                tensor_variable = wrap_fx_proxy(
+                    tx=tx,
+                    proxy=tx.output.create_proxy(
+                        "call_function",
+                        self.value,
+                        *proxy_args_kwargs(args, kwargs),
+                    ),
+                )

             return tensor_variable
         elif self.value is random.Random: