mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 02:04:53 +08:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 207c2248a8 | |||
| a206dcc79e | |||
| f2d7f235a6 | |||
| 402b289f3b | |||
| a32157c67c | |||
| 24e7f29099 | |||
| 5b5d269d34 | |||
| fa88f390a0 | |||
| fe39c07826 | |||
| cba195c8ed | |||
| 16e67be7f1 | |||
| 7afffdf48b | |||
| ca45649eb5 |
@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,9
|
yolov3,fail_accuracy,8
|
||||||
|
|||||||
|
@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
pyhpc_isoneutral_mixing,fail_to_run,0
|
pyhpc_isoneutral_mixing,pass,0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,fail_to_run,0
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -338,4 +338,4 @@ vision_maskrcnn,pass,28
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -242,7 +242,7 @@ pyhpc_equation_of_state,pass,0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
pyhpc_isoneutral_mixing,fail_to_run,0
|
pyhpc_isoneutral_mixing,pass,0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -350,4 +350,4 @@ vision_maskrcnn,fail_to_run,0
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,fail_to_run,0
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,9
|
yolov3,fail_accuracy,8
|
||||||
|
|||||||
|
@ -298,4 +298,4 @@ vision_maskrcnn,pass,28
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -374,4 +374,4 @@ vision_maskrcnn,pass,17
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -282,4 +282,4 @@ vision_maskrcnn,pass,34
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,9
|
yolov3,pass,8
|
||||||
|
|||||||
|
@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,9
|
yolov3,pass,8
|
||||||
|
|||||||
|
@ -378,4 +378,4 @@ vision_maskrcnn,pass,17
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,2
|
yolov3,pass,0
|
||||||
|
|||||||
|
@ -286,4 +286,4 @@ vision_maskrcnn,pass,34
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
yolov3,pass,9
|
yolov3,pass,8
|
||||||
|
|||||||
|
@ -4,12 +4,11 @@ phlippe_densenet,float32,static,default,1.3988316
|
|||||||
basic_gnn_gcn,float32,dynamic,default,1.074576405
|
basic_gnn_gcn,float32,dynamic,default,1.074576405
|
||||||
llama_v2_7b_16h,float32,dynamic,default,1.211740245
|
llama_v2_7b_16h,float32,dynamic,default,1.211740245
|
||||||
resnet50,float32,dynamic,default,1.65984261
|
resnet50,float32,dynamic,default,1.65984261
|
||||||
timm_efficientnet,float32,static,cpp,2.271561735
|
#timm_efficientnet,float32,static,cpp,2.1938112
|
||||||
mobilenet_v3_large,float32,static,cpp,2.63375628
|
mobilenet_v3_large,float32,static,cpp,2.63375628
|
||||||
timm_resnest,float32,dynamic,cpp,1.67998548
|
timm_resnest,float32,dynamic,cpp,1.67998548
|
||||||
pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
|
pyhpc_turbulent_kinetic_energy,float32,dynamic,cpp,1.59968463
|
||||||
#hf_GPT2,float32,dynamic,cpp,
|
#hf_GPT2,float32,dynamic,cpp,1.292704418
|
||||||
hf_GPT2,float32,dynamic,cpp,1.379885175
|
|
||||||
resnext50_32x4d,amp,static,default,1.461687045
|
resnext50_32x4d,amp,static,default,1.461687045
|
||||||
vgg16,amp,static,default,1.267194285
|
vgg16,amp,static,default,1.267194285
|
||||||
hf_Longformer,amp,dynamic,default,0.997006035
|
hf_Longformer,amp,dynamic,default,0.997006035
|
||||||
@ -17,6 +16,6 @@ hf_Bert_large,amp,dynamic,default,0.99391146
|
|||||||
llama,amp,static,default,1.32950568
|
llama,amp,static,default,1.32950568
|
||||||
timm_regnet,amp,static,cpp,1.157188305
|
timm_regnet,amp,static,cpp,1.157188305
|
||||||
lennard_jones,amp,static,cpp,2.240104485
|
lennard_jones,amp,static,cpp,2.240104485
|
||||||
hf_T5_generate,amp,dynamic,cpp,1.447656135
|
#hf_T5_generate,amp,dynamic,cpp,1.29339502
|
||||||
timm_vovnet,amp,dynamic,cpp,1.07856471
|
timm_vovnet,amp,dynamic,cpp,1.07856471
|
||||||
mobilenet_v2,amp,dynamic,cpp,2.27774577
|
mobilenet_v2,amp,dynamic,cpp,2.27774577
|
||||||
|
|||||||
|
@ -25,10 +25,6 @@ from torch._dynamo.utils import clone_inputs
|
|||||||
# We are primarily interested in tf32 datatype
|
# We are primarily interested in tf32 datatype
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
|
||||||
# Enable FX graph caching
|
|
||||||
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
|
|
||||||
torch._inductor.config.fx_graph_cache = True
|
|
||||||
|
|
||||||
|
|
||||||
def _reassign_parameters(model):
|
def _reassign_parameters(model):
|
||||||
# torch_geometric models register parameter as tensors due to
|
# torch_geometric models register parameter as tensors due to
|
||||||
|
|||||||
@ -827,6 +827,7 @@ libtorch_python_core_sources = [
|
|||||||
"torch/csrc/dynamo/guards.cpp",
|
"torch/csrc/dynamo/guards.cpp",
|
||||||
"torch/csrc/dynamo/init.cpp",
|
"torch/csrc/dynamo/init.cpp",
|
||||||
"torch/csrc/functorch/init.cpp",
|
"torch/csrc/functorch/init.cpp",
|
||||||
|
"torch/csrc/fx/node.cpp",
|
||||||
"torch/csrc/mps/Module.cpp",
|
"torch/csrc/mps/Module.cpp",
|
||||||
"torch/csrc/mtia/Module.cpp",
|
"torch/csrc/mtia/Module.cpp",
|
||||||
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
"torch/csrc/inductor/aoti_runner/pybind.cpp",
|
||||||
|
|||||||
@ -62,8 +62,8 @@ Overall, the ``pipelining`` package provides the following features:
|
|||||||
application on the Llama model.
|
application on the Llama model.
|
||||||
|
|
||||||
|
|
||||||
Step 1: build ``PipelineStage`` for execution
|
Step 1: build ``PipelineStage``
|
||||||
*********************************************
|
*******************************
|
||||||
|
|
||||||
Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
|
Before we can use a ``PipelineSchedule``, we need to create ``PipelineStage``
|
||||||
objects that wrap the part of the model running in that stage. The
|
objects that wrap the part of the model running in that stage. The
|
||||||
|
|||||||
@ -304,13 +304,12 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||||||
self.assertEqual(cnt.frame_count, 0)
|
self.assertEqual(cnt.frame_count, 0)
|
||||||
|
|
||||||
def test_torch_guards_stack_frame_register_inlining_disable(self):
|
def test_torch_guards_stack_frame_register_inlining_disable(self):
|
||||||
y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
|
|
||||||
x = torch.tensor([0.5, 0.5])
|
x = torch.tensor([0.5, 0.5])
|
||||||
|
|
||||||
class encoder(torch.nn.Module):
|
class encoder(torch.nn.Module):
|
||||||
def __init__(self, y):
|
def __init__(self, y):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_parameter("param", y)
|
self.a = y
|
||||||
|
|
||||||
@torch._dynamo.disable
|
@torch._dynamo.disable
|
||||||
def helper(self, x, y):
|
def helper(self, x, y):
|
||||||
@ -318,9 +317,9 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
|||||||
|
|
||||||
def forward(self, a, *args):
|
def forward(self, a, *args):
|
||||||
x = a + a
|
x = a + a
|
||||||
return self.helper(x, self.param)
|
return self.helper(x, self.a)
|
||||||
|
|
||||||
e = encoder(y)
|
e = encoder(2.0)
|
||||||
|
|
||||||
seen_frames = []
|
seen_frames = []
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class BinaryFoldingTemplate(TestCase):
|
|||||||
self.use_scalar = scalar
|
self.use_scalar = scalar
|
||||||
tensor_size = [1 for _ in range(self.conv.weight.ndim)]
|
tensor_size = [1 for _ in range(self.conv.weight.ndim)]
|
||||||
tensor_size[1] = self.conv.weight.size(0)
|
tensor_size[1] = self.conv.weight.size(0)
|
||||||
self.tensor = (
|
self.tensor = torch.nn.Parameter(
|
||||||
add_tensor
|
add_tensor
|
||||||
if add_tensor is not None
|
if add_tensor is not None
|
||||||
else torch.rand(tensor_size).to(device)
|
else torch.rand(tensor_size).to(device)
|
||||||
@ -136,7 +136,11 @@ class BinaryFoldingTemplate(TestCase):
|
|||||||
nn.Conv2d,
|
nn.Conv2d,
|
||||||
pytorch_op,
|
pytorch_op,
|
||||||
False,
|
False,
|
||||||
add_tensor=torch.rand(32, 1, 32).to(self.device),
|
add_tensor=torch.rand(
|
||||||
|
32,
|
||||||
|
1,
|
||||||
|
32,
|
||||||
|
).to(self.device),
|
||||||
expect_success=False,
|
expect_success=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -156,7 +160,7 @@ class BinaryFoldingTemplate(TestCase):
|
|||||||
nn.Conv2d,
|
nn.Conv2d,
|
||||||
pytorch_op,
|
pytorch_op,
|
||||||
False,
|
False,
|
||||||
add_tensor=torch.tensor([2]).to(torch.int).to(self.device),
|
add_tensor=torch.tensor([2]).to(torch.float64).to(self.device),
|
||||||
expect_success=False,
|
expect_success=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -195,7 +195,7 @@ class TestFxGraphCache(TestCase):
|
|||||||
num_put += 1
|
num_put += 1
|
||||||
|
|
||||||
cache_module = (
|
cache_module = (
|
||||||
"triton.runtime.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend"
|
"triton.fb.fb_memcache.FbMemcacheRemoteFxGraphCacheBackend"
|
||||||
if config.is_fbcode()
|
if config.is_fbcode()
|
||||||
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
|
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -267,7 +267,7 @@ class TestMaxAutotune(TestCase):
|
|||||||
num_put += 1
|
num_put += 1
|
||||||
|
|
||||||
cache_module = (
|
cache_module = (
|
||||||
"triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
|
"triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
|
||||||
if config.is_fbcode()
|
if config.is_fbcode()
|
||||||
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
|
else "torch._inductor.remote_cache.RedisRemoteCacheBackend"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -233,6 +233,23 @@ def run_fw_bw_and_get_code(fn):
|
|||||||
return run_and_get_code(run_with_backward)
|
return run_and_get_code(run_with_backward)
|
||||||
|
|
||||||
|
|
||||||
|
def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl):
|
||||||
|
for _op_name in op_set:
|
||||||
|
qualified_op_name = f"{ns}::{_op_name}"
|
||||||
|
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
||||||
|
for overload_name in overload_names:
|
||||||
|
try:
|
||||||
|
reg_op_name = qualified_op_name
|
||||||
|
schema = torch._C._get_schema(qualified_op_name, overload_name)
|
||||||
|
if schema.overload_name:
|
||||||
|
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
|
||||||
|
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
||||||
|
reg_op_name, dispatch_key
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
class TestCase(InductorTestCase):
|
class TestCase(InductorTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@ -751,6 +768,58 @@ class CommonTemplate:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||||
|
def test_eager_aoti_support_out(self):
|
||||||
|
ns = "aten"
|
||||||
|
op_name = "clamp"
|
||||||
|
dispatch_key = "CPU"
|
||||||
|
device = "cpu"
|
||||||
|
if self.device.lower() == "cuda":
|
||||||
|
dispatch_key = "CUDA"
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0)
|
||||||
|
min_tensor = inp_tensor - 0.05
|
||||||
|
max_tensor = inp_tensor + 0.05
|
||||||
|
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||||
|
ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
ref_tensor = torch.clamp(
|
||||||
|
max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
ref_tensor1 = torch.clamp(
|
||||||
|
max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
register_ops_with_aoti_compile(
|
||||||
|
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||||
|
)
|
||||||
|
|
||||||
|
res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
res_tensor = torch.clamp(
|
||||||
|
max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(ref_tensor, res_tensor)
|
||||||
|
self.assertEqual(ref_out_tensor, res_out_tensor)
|
||||||
|
|
||||||
|
res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
res_tensor1 = torch.clamp(
|
||||||
|
max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(ref_tensor1, res_tensor1)
|
||||||
|
self.assertEqual(ref_out_tensor1, res_out_tensor1)
|
||||||
|
|
||||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||||
def test_eager_aoti_cache_hit(self):
|
def test_eager_aoti_cache_hit(self):
|
||||||
ns = "aten"
|
ns = "aten"
|
||||||
@ -779,24 +848,13 @@ class CommonTemplate:
|
|||||||
with mock.patch(
|
with mock.patch(
|
||||||
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
|
"torch._inductor.utils.aoti_compile_with_persistent_cache", None
|
||||||
):
|
):
|
||||||
qualified_op_name = f"{ns}::{op_name}"
|
|
||||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
|
||||||
|
|
||||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||||
# Get ref result from eager
|
# Get ref result from eager
|
||||||
ref_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
ref_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
||||||
|
|
||||||
for overload_name in overload_names:
|
register_ops_with_aoti_compile(
|
||||||
try:
|
ns, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||||
reg_op_name = qualified_op_name
|
)
|
||||||
schema = torch._C._get_schema(qualified_op_name, overload_name)
|
|
||||||
if schema.overload_name:
|
|
||||||
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
|
|
||||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
|
||||||
reg_op_name, dispatch_key
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Invoke the pre-compiled kernel and get result.
|
# Invoke the pre-compiled kernel and get result.
|
||||||
res_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
res_value = getattr(torch.ops.aten, op_name)(input_tensor)
|
||||||
@ -804,7 +862,7 @@ class CommonTemplate:
|
|||||||
self.assertEqual(ref_value, res_value)
|
self.assertEqual(ref_value, res_value)
|
||||||
|
|
||||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||||
def test_aoti_compile_with_persistent_cache(self):
|
def test_eager_aoti_with_persistent_cache(self):
|
||||||
def fn(a):
|
def fn(a):
|
||||||
return torch.abs(a)
|
return torch.abs(a)
|
||||||
|
|
||||||
@ -906,19 +964,9 @@ class CommonTemplate:
|
|||||||
for scalar_value in scalar_values:
|
for scalar_value in scalar_values:
|
||||||
ref_values.append(torch.add(a, b, alpha=scalar_value))
|
ref_values.append(torch.add(a, b, alpha=scalar_value))
|
||||||
|
|
||||||
qualified_op_name = f"{namespace_name}::{op_name}"
|
register_ops_with_aoti_compile(
|
||||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl
|
||||||
for overload_name in overload_names:
|
)
|
||||||
try:
|
|
||||||
reg_op_name = qualified_op_name
|
|
||||||
schema = torch._C._get_schema(reg_op_name, overload_name)
|
|
||||||
if schema.overload_name:
|
|
||||||
reg_op_name = f"{reg_op_name}.{schema.overload_name}"
|
|
||||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
|
||||||
reg_op_name, dispatch_key
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
continue
|
|
||||||
|
|
||||||
res_values = []
|
res_values = []
|
||||||
for scalar_value in scalar_values:
|
for scalar_value in scalar_values:
|
||||||
@ -928,8 +976,7 @@ class CommonTemplate:
|
|||||||
self.assertEqual(ref_values, res_values)
|
self.assertEqual(ref_values, res_values)
|
||||||
|
|
||||||
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
@skipCUDAIf(not SM80OrLater, "Requires sm80")
|
||||||
def test_torch_compile_override_registration(self):
|
def test_eager_aoti_override_registration(self):
|
||||||
dynamic = False
|
|
||||||
namespace_name = "aten"
|
namespace_name = "aten"
|
||||||
dispatch_key = "CPU"
|
dispatch_key = "CPU"
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -951,24 +998,10 @@ class CommonTemplate:
|
|||||||
ref = opt_fn(x)
|
ref = opt_fn(x)
|
||||||
ref_array.append(ref)
|
ref_array.append(ref)
|
||||||
|
|
||||||
def register_ops(op_set, dispatch_key, torch_compile_op_lib_impl):
|
|
||||||
for _op_name in op_set:
|
|
||||||
qualified_op_name = f"{namespace_name}::{_op_name}"
|
|
||||||
_, overload_names = torch._C._jit_get_operation(qualified_op_name)
|
|
||||||
for overload_name in overload_names:
|
|
||||||
try:
|
|
||||||
reg_op_name = qualified_op_name
|
|
||||||
schema = torch._C._get_schema(qualified_op_name, overload_name)
|
|
||||||
if schema.overload_name:
|
|
||||||
reg_op_name = f"{qualified_op_name}.{schema.overload_name}"
|
|
||||||
torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821
|
|
||||||
reg_op_name, dispatch_key
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
continue
|
|
||||||
|
|
||||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||||
register_ops(unary_op_set, dispatch_key, torch_compile_op_lib_impl)
|
register_ops_with_aoti_compile(
|
||||||
|
namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl
|
||||||
|
)
|
||||||
|
|
||||||
res_array = []
|
res_array = []
|
||||||
for unary_op_name in unary_op_set:
|
for unary_op_name in unary_op_set:
|
||||||
@ -985,7 +1018,9 @@ class CommonTemplate:
|
|||||||
ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
||||||
|
|
||||||
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl:
|
||||||
register_ops(["clamp"], dispatch_key, torch_compile_op_lib_impl)
|
register_ops_with_aoti_compile(
|
||||||
|
namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl
|
||||||
|
)
|
||||||
res_with_min = torch.ops.aten.clamp(a, min_tensor)
|
res_with_min = torch.ops.aten.clamp(a, min_tensor)
|
||||||
res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor)
|
||||||
self.assertEqual(ref_with_min, res_with_min)
|
self.assertEqual(ref_with_min, res_with_min)
|
||||||
@ -5502,6 +5537,14 @@ class CommonTemplate:
|
|||||||
for dtype in all_types():
|
for dtype in all_types():
|
||||||
self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),))
|
self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),))
|
||||||
|
|
||||||
|
def test_full_boolean(self):
|
||||||
|
def fn(n):
|
||||||
|
x = torch.full((1,), n >= 1024, device=self.device)
|
||||||
|
return x, x + 1
|
||||||
|
|
||||||
|
self.common(fn, (1024,))
|
||||||
|
self.common(fn, (1023,))
|
||||||
|
|
||||||
def test_index1(self):
|
def test_index1(self):
|
||||||
def fn(a, b, c):
|
def fn(a, b, c):
|
||||||
return aten.index(a, [b, c])
|
return aten.index(a, [b, c])
|
||||||
@ -7816,6 +7859,95 @@ class CommonTemplate:
|
|||||||
)
|
)
|
||||||
assertGeneratedKernelCountEqual(self, 0)
|
assertGeneratedKernelCountEqual(self, 0)
|
||||||
|
|
||||||
|
def test_avg_pool3d_backward(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return aten.avg_pool3d_backward(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
[2, 2, 2],
|
||||||
|
[2, 2, 2],
|
||||||
|
[0, 0, 0],
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.common(
|
||||||
|
fn,
|
||||||
|
[
|
||||||
|
torch.randn([2, 4, 7, 7, 7]),
|
||||||
|
torch.randn([2, 4, 14, 14, 14]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_avg_pool3d_backward2(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return aten.avg_pool3d_backward(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
[3, 3, 3],
|
||||||
|
[1, 1, 1],
|
||||||
|
[1, 1, 1],
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.common(
|
||||||
|
fn,
|
||||||
|
[
|
||||||
|
torch.randn([1, 1, 20, 20, 15]),
|
||||||
|
torch.randn([1, 1, 20, 20, 15]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_avg_pool3d_backward3(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return aten.avg_pool3d_backward(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
[1, 1, 1],
|
||||||
|
[2, 2, 2],
|
||||||
|
[0, 0, 0],
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch._inductor.metrics.generated_kernel_count = 0
|
||||||
|
self.common(
|
||||||
|
fn,
|
||||||
|
[
|
||||||
|
torch.randn([1, 2016, 11, 11, 11]),
|
||||||
|
torch.randn([1, 2016, 21, 21, 21]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assertGeneratedKernelCountEqual(self, 1)
|
||||||
|
|
||||||
|
def test_avg_pool3d_backward4(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return aten.avg_pool3d_backward(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
[13, 13, 13],
|
||||||
|
[1, 1, 1],
|
||||||
|
[0, 0, 0],
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch._inductor.metrics.generated_kernel_count = 0
|
||||||
|
self.common(
|
||||||
|
fn,
|
||||||
|
[
|
||||||
|
torch.randn([1, 16, 12, 12, 12]),
|
||||||
|
torch.randn([1, 16, 24, 24, 24]),
|
||||||
|
],
|
||||||
|
check_lowp=False,
|
||||||
|
)
|
||||||
|
assertGeneratedKernelCountEqual(self, 0)
|
||||||
|
|
||||||
@config.patch(search_autotune_cache=False)
|
@config.patch(search_autotune_cache=False)
|
||||||
def test_mm_views(self):
|
def test_mm_views(self):
|
||||||
def fn(a, b):
|
def fn(a, b):
|
||||||
|
|||||||
@ -121,6 +121,7 @@ test_failures = {
|
|||||||
"test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
"test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
"test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
"test_conv3d_channels_last_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
"test_expand_dynamic_shapes": TestFailure(("cpu",)),
|
"test_expand_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
|
"test_full_boolean_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
"test_glu_dynamic_shapes": TestFailure(("cpu",)),
|
"test_glu_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
"test_isinf2_dynamic_shapes": TestFailure(("cpu",)),
|
"test_isinf2_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
"test_linspace1_dynamic_shapes": TestFailure(("cpu",)),
|
"test_linspace1_dynamic_shapes": TestFailure(("cpu",)),
|
||||||
@ -146,6 +147,7 @@ test_failures = {
|
|||||||
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
|
"test_avg_pool3d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
"test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_baddbmm_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
"test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_bmm2_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
"test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
"test_both_scalars_dynamic_shapes": TestFailure(("cpu", "cuda")),
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.library
|
import torch.library
|
||||||
@ -369,6 +370,47 @@ class TestInductorDynamic(TestCase):
|
|||||||
arg = torch.tensor(5, device=device)
|
arg = torch.tensor(5, device=device)
|
||||||
self.assertEqual(f(arg), cf(arg))
|
self.assertEqual(f(arg), cf(arg))
|
||||||
|
|
||||||
|
@torch._dynamo.config.patch(
|
||||||
|
capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
|
||||||
|
)
|
||||||
|
@torch._inductor.config.patch(implicit_fallbacks=True)
|
||||||
|
def test_unbacked_save_for_backwards(self, device) -> None:
|
||||||
|
@torch.library.custom_op("_test::_cat", mutates_args=())
|
||||||
|
def _cat(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
|
||||||
|
return t * t.new_ones([sum(ds)])
|
||||||
|
|
||||||
|
@torch.library.register_fake("_test::_cat")
|
||||||
|
def _cat_fake(t: torch.Tensor, ds: List[int]) -> torch.Tensor:
|
||||||
|
[torch._check_is_size(d) for d in ds]
|
||||||
|
return t.new_empty([sum(ds)])
|
||||||
|
|
||||||
|
def _cat_setup_context(ctx, inputs, output):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _cat_backward(ctx, grad):
|
||||||
|
return grad.sum(), None
|
||||||
|
|
||||||
|
torch.library.register_autograd(
|
||||||
|
"_test::_cat",
|
||||||
|
_cat_backward,
|
||||||
|
setup_context=_cat_setup_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn(t, sizes):
|
||||||
|
r = torch.ops._test._cat(t, sizes.tolist())
|
||||||
|
return r * t
|
||||||
|
|
||||||
|
t = torch.randn((), requires_grad=True, device=device)
|
||||||
|
sizes = torch.tensor([4, 8], dtype=torch.int64, device="cpu")
|
||||||
|
out = fn(t, sizes)
|
||||||
|
out.sum().backward()
|
||||||
|
expect = t.grad
|
||||||
|
t.grad = None
|
||||||
|
torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True)(
|
||||||
|
t, sizes
|
||||||
|
).sum().backward()
|
||||||
|
self.assertEqual(t.grad, expect)
|
||||||
|
|
||||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||||
def test_unbacked_reduction(self, device):
|
def test_unbacked_reduction(self, device):
|
||||||
expect_fail = device == "cpu" and not IS_ARM64
|
expect_fail = device == "cpu" and not IS_ARM64
|
||||||
|
|||||||
@ -2333,3 +2333,14 @@ def _save_pickle(obj: Any) -> bytes: ...
|
|||||||
# Defined in torch/csrc/jit/runtime/static/init.cpp
|
# Defined in torch/csrc/jit/runtime/static/init.cpp
|
||||||
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
|
def _jit_to_static_module(graph_or_module: Union[Graph,ScriptModule]) -> Any: ...
|
||||||
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
|
def _fuse_to_static_module(graph_or_module: Union[Graph,ScriptModule], min_size: _int) -> Any: ...
|
||||||
|
|
||||||
|
# Defined in torch/csrc/fx/node.cpp
|
||||||
|
class _NodeBase:
|
||||||
|
_erased: _bool
|
||||||
|
_prev: "_NodeBase"
|
||||||
|
_next: "_NodeBase"
|
||||||
|
|
||||||
|
class _NodeIter(Iterator):
|
||||||
|
def __init__(self, root: _NodeBase, reversed: _bool) -> None: ...
|
||||||
|
def __iter__(self) -> Iterator[_NodeBase]: ...
|
||||||
|
def __next__(self) -> _NodeBase: ...
|
||||||
|
|||||||
@ -752,7 +752,13 @@ class OutputGraph:
|
|||||||
**options,
|
**options,
|
||||||
):
|
):
|
||||||
if is_dynamic_nn_module(target, self.root_tx.export):
|
if is_dynamic_nn_module(target, self.root_tx.export):
|
||||||
return variables.UnspecializedNNModuleVariable(target, **options)
|
result = variables.UnspecializedNNModuleVariable(target, **options)
|
||||||
|
if not SideEffects.cls_supports_mutation_side_effects(type(target)):
|
||||||
|
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||||
|
return result
|
||||||
|
return self.root_tx.output.side_effects.track_object_existing(
|
||||||
|
target, result
|
||||||
|
)
|
||||||
|
|
||||||
options = dict(options)
|
options = dict(options)
|
||||||
assert "source" in options
|
assert "source" in options
|
||||||
|
|||||||
@ -1128,6 +1128,19 @@ class VariableBuilder:
|
|||||||
if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
|
if mutation_guard.is_dynamic_nn_module(value, self.tx.export):
|
||||||
# created dynamically, don't specialize on it
|
# created dynamically, don't specialize on it
|
||||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||||
|
if (
|
||||||
|
torch._dynamo.config.inline_inbuilt_nn_modules
|
||||||
|
and torch._inductor.config.freezing
|
||||||
|
and not torch.is_grad_enabled()
|
||||||
|
):
|
||||||
|
from ..decorators import mark_static_address
|
||||||
|
|
||||||
|
for p in value.parameters():
|
||||||
|
mark_static_address(p)
|
||||||
|
|
||||||
|
for b in value.buffers():
|
||||||
|
mark_static_address(b)
|
||||||
|
|
||||||
result = UnspecializedNNModuleVariable(value, source=self.source)
|
result = UnspecializedNNModuleVariable(value, source=self.source)
|
||||||
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
if not SideEffects.cls_supports_mutation_side_effects(type(value)):
|
||||||
# don't allow STORE_ATTR mutation with custom __setattr__
|
# don't allow STORE_ATTR mutation with custom __setattr__
|
||||||
|
|||||||
@ -1021,7 +1021,7 @@ class FxGraphCache:
|
|||||||
cache_id = "fx-graph-v1"
|
cache_id = "fx-graph-v1"
|
||||||
try:
|
try:
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
from triton.runtime.fb_memcache import (
|
from triton.fb.fb_memcache import (
|
||||||
FbMemcacheRemoteFxGraphCacheBackend,
|
FbMemcacheRemoteFxGraphCacheBackend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -201,11 +201,19 @@ def _unlift_graph(mod, gm, graph_signature):
|
|||||||
|
|
||||||
outputs = list(gm.graph.nodes)[-1].args[0]
|
outputs = list(gm.graph.nodes)[-1].args[0]
|
||||||
mutated_outputs = []
|
mutated_outputs = []
|
||||||
for out in outputs:
|
buffer_mutations = graph_signature.buffers_to_mutate
|
||||||
if out.name in graph_signature.buffers_to_mutate:
|
user_input_mutations = graph_signature.user_inputs_to_mutate
|
||||||
mutated_outputs.append(graph_signature.buffers_to_mutate[out.name])
|
output_tokens = graph_signature.output_tokens
|
||||||
else:
|
for idx, out in enumerate(outputs):
|
||||||
mutated_outputs.append(None)
|
value = None
|
||||||
|
|
||||||
|
if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
|
||||||
|
if out.name in buffer_mutations:
|
||||||
|
value = buffer_mutations[out.name]
|
||||||
|
elif out.name in user_input_mutations:
|
||||||
|
value = user_input_mutations[out.name]
|
||||||
|
|
||||||
|
mutated_outputs.append(value)
|
||||||
|
|
||||||
unlifted_gm = _unlift(
|
unlifted_gm = _unlift(
|
||||||
gm,
|
gm,
|
||||||
@ -392,7 +400,7 @@ def should_use_remote_fx_graph_cache():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from triton.runtime.fb_memcache import MEMCACHE_VERSION
|
from triton.fb.fb_memcache import MEMCACHE_VERSION
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -177,7 +177,7 @@ def is_boolean_type(x):
|
|||||||
|
|
||||||
def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
|
def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
|
||||||
def construct_input(inp):
|
def construct_input(inp):
|
||||||
if isinstance(inp, (Number, sympy.Expr)):
|
if isinstance(inp, (Number, sympy.Basic)):
|
||||||
return inp
|
return inp
|
||||||
else:
|
else:
|
||||||
assert hasattr(inp, "get_dtype")
|
assert hasattr(inp, "get_dtype")
|
||||||
@ -216,7 +216,7 @@ def transform_args(args, broadcast, type_promotion_kind, convert_input_to_bool):
|
|||||||
promoting_args = [
|
promoting_args = [
|
||||||
a
|
a
|
||||||
for a in args
|
for a in args
|
||||||
if isinstance(a, (Number, sympy.Expr))
|
if isinstance(a, (Number, sympy.Basic))
|
||||||
or getattr(a, "dtype", None) is not None
|
or getattr(a, "dtype", None) is not None
|
||||||
]
|
]
|
||||||
dtype = get_promoted_dtype(
|
dtype = get_promoted_dtype(
|
||||||
@ -368,15 +368,15 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No
|
|||||||
if override_return_dtype is None and type_promotion_kind is None:
|
if override_return_dtype is None and type_promotion_kind is None:
|
||||||
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||||
|
|
||||||
if not any(isinstance(x, (sympy.Expr, int, float)) for x in inputs):
|
if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs):
|
||||||
return inputs
|
return inputs
|
||||||
if all(isinstance(x, (int, float, sympy.Expr)) for x in inputs):
|
if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
|
||||||
dtype = override_return_dtype or get_promoted_dtype(
|
dtype = override_return_dtype or get_promoted_dtype(
|
||||||
*inputs, type_promotion_kind=type_promotion_kind
|
*inputs, type_promotion_kind=type_promotion_kind
|
||||||
)
|
)
|
||||||
|
|
||||||
def const_func(x):
|
def const_func(x):
|
||||||
if isinstance(x, sympy.Expr):
|
if isinstance(x, sympy.Basic):
|
||||||
return ir.IndexingConstant(x, dtype, decode_device(None))
|
return ir.IndexingConstant(x, dtype, decode_device(None))
|
||||||
else:
|
else:
|
||||||
return ir.Constant(x, dtype, decode_device(None))
|
return ir.Constant(x, dtype, decode_device(None))
|
||||||
@ -391,7 +391,7 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No
|
|||||||
ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
|
ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size())
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(x, sympy.Expr):
|
elif isinstance(x, sympy.Basic):
|
||||||
out.append(
|
out.append(
|
||||||
ExpandView.create(
|
ExpandView.create(
|
||||||
IndexingConstant(x, ex.get_dtype(), ex.get_device()),
|
IndexingConstant(x, ex.get_dtype(), ex.get_device()),
|
||||||
@ -2155,7 +2155,6 @@ make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
|
|||||||
|
|
||||||
|
|
||||||
# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
|
# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
|
||||||
make_fallback(aten.avg_pool3d_backward)
|
|
||||||
make_fallback(aten.max_pool3d_with_indices_backward)
|
make_fallback(aten.max_pool3d_with_indices_backward)
|
||||||
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
|
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
|
||||||
make_fallback(aten._adaptive_avg_pool3d_backward)
|
make_fallback(aten._adaptive_avg_pool3d_backward)
|
||||||
@ -2471,7 +2470,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
|
|||||||
|
|
||||||
ranges: List[sympy.Expr] = []
|
ranges: List[sympy.Expr] = []
|
||||||
|
|
||||||
if isinstance(data, sympy.Expr):
|
if isinstance(data, sympy.Basic):
|
||||||
|
|
||||||
def inner_fn(index):
|
def inner_fn(index):
|
||||||
return ops.index_expr(data, dtype)
|
return ops.index_expr(data, dtype)
|
||||||
@ -2597,7 +2596,7 @@ def _full(fill_value, device, dtype, size):
|
|||||||
def inner_fn(index):
|
def inner_fn(index):
|
||||||
return ops.constant(value, dtype)
|
return ops.constant(value, dtype)
|
||||||
|
|
||||||
elif isinstance(value, sympy.Expr):
|
elif isinstance(value, sympy.Basic):
|
||||||
|
|
||||||
def inner_fn(index):
|
def inner_fn(index):
|
||||||
return ops.index_expr(value, dtype)
|
return ops.index_expr(value, dtype)
|
||||||
@ -4034,11 +4033,32 @@ def pad_adaptive_loader(x, pad_val=0.0):
|
|||||||
return load
|
return load
|
||||||
|
|
||||||
|
|
||||||
def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
|
||||||
h_start_index_fn, w_start_index_fn = start_index_fns
|
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
||||||
h_end_index_fn, w_end_index_fn = end_index_fns
|
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
||||||
|
|
||||||
def fn_sum(idx, loader):
|
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
||||||
|
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
||||||
|
|
||||||
|
return h_start_index, h_end_index, w_start_index, w_end_index
|
||||||
|
|
||||||
|
|
||||||
|
def _adaptive_pooling_fn(
|
||||||
|
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
|
||||||
|
):
|
||||||
|
h_in, w_in = in_sizes
|
||||||
|
h_out, w_out = out_sizes
|
||||||
|
|
||||||
|
(
|
||||||
|
h_start_index_fn,
|
||||||
|
h_end_index_fn,
|
||||||
|
w_start_index_fn,
|
||||||
|
w_end_index_fn,
|
||||||
|
) = compute_indices_adaptive_pooling(
|
||||||
|
start_index, end_index, h_in, w_in, h_out, w_out
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn(idx, loader):
|
||||||
*prefix, bh, bw = idx
|
*prefix, bh, bw = idx
|
||||||
|
|
||||||
h_start_index = h_start_index_fn(bh)
|
h_start_index = h_start_index_fn(bh)
|
||||||
@ -4047,7 +4067,7 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
|||||||
w_start_index = w_start_index_fn(bw)
|
w_start_index = w_start_index_fn(bw)
|
||||||
w_end_index = w_end_index_fn(bw)
|
w_end_index = w_end_index_fn(bw)
|
||||||
|
|
||||||
total = None
|
result = None
|
||||||
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
||||||
val = loader(
|
val = loader(
|
||||||
prefix,
|
prefix,
|
||||||
@ -4055,13 +4075,66 @@ def _adaptive_pooling_idx_sum(kernel_maxes, start_index_fns, end_index_fns):
|
|||||||
[h_start_index, w_start_index],
|
[h_start_index, w_start_index],
|
||||||
[h_end_index, w_end_index],
|
[h_end_index, w_end_index],
|
||||||
)
|
)
|
||||||
if total is None:
|
if result is None:
|
||||||
total = val
|
result = val
|
||||||
else:
|
else:
|
||||||
total = ops.add(val, total)
|
result = pooling_fn(val, result)
|
||||||
return total
|
return result
|
||||||
|
|
||||||
return fn_sum
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def _adaptive_pooling_fn_with_idx(
|
||||||
|
start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
|
||||||
|
):
|
||||||
|
h_in, w_in = in_sizes
|
||||||
|
h_out, w_out = out_sizes
|
||||||
|
|
||||||
|
(
|
||||||
|
h_start_index_fn,
|
||||||
|
h_end_index_fn,
|
||||||
|
w_start_index_fn,
|
||||||
|
w_end_index_fn,
|
||||||
|
) = compute_indices_adaptive_pooling(
|
||||||
|
start_index, end_index, h_in, w_in, h_out, w_out
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn(idx, loader):
|
||||||
|
*prefix, bh, bw = idx
|
||||||
|
|
||||||
|
h_start_index = h_start_index_fn(bh)
|
||||||
|
h_end_index = h_end_index_fn(bh)
|
||||||
|
|
||||||
|
w_start_index = w_start_index_fn(bw)
|
||||||
|
w_end_index = w_end_index_fn(bw)
|
||||||
|
|
||||||
|
maxval = None
|
||||||
|
maxindex = None
|
||||||
|
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
||||||
|
val = loader(
|
||||||
|
prefix,
|
||||||
|
[ih, iw],
|
||||||
|
[h_start_index, w_start_index],
|
||||||
|
[h_end_index, w_end_index],
|
||||||
|
)
|
||||||
|
|
||||||
|
index = ops.index_expr(
|
||||||
|
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
|
||||||
|
)
|
||||||
|
|
||||||
|
if maxindex is None:
|
||||||
|
maxindex = index
|
||||||
|
else:
|
||||||
|
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
|
||||||
|
|
||||||
|
if maxval is None:
|
||||||
|
maxval = val
|
||||||
|
else:
|
||||||
|
maxval = pooling_fn(val, maxval)
|
||||||
|
|
||||||
|
return maxindex
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
fallback_adaptive_avg_pool2d = fallback_handler(
|
fallback_adaptive_avg_pool2d = fallback_handler(
|
||||||
@ -4099,27 +4172,24 @@ def _adaptive_avg_pool2d(x, output_size):
|
|||||||
new_size = list(batch) + [h_out, w_out]
|
new_size = list(batch) + [h_out, w_out]
|
||||||
dtype = x.get_dtype()
|
dtype = x.get_dtype()
|
||||||
|
|
||||||
|
window_size = h_kernel_max * w_kernel_max
|
||||||
|
if window_size > 25:
|
||||||
|
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||||
|
return fallback_adaptive_avg_pool2d(x, output_size)
|
||||||
|
|
||||||
def start_index(index, out_dim, inp_dim):
|
def start_index(index, out_dim, inp_dim):
|
||||||
return FloorDiv((index * inp_dim), out_dim)
|
return FloorDiv((index * inp_dim), out_dim)
|
||||||
|
|
||||||
def end_index(index, out_dim, inp_dim):
|
def end_index(index, out_dim, inp_dim):
|
||||||
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
||||||
|
|
||||||
h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
fn_sum = _adaptive_pooling_fn(
|
||||||
h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||||
w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
in_sizes=[h_in, w_in],
|
||||||
|
out_sizes=[h_out, w_out],
|
||||||
window_size = h_kernel_max * w_kernel_max
|
pooling_fn=ops.add,
|
||||||
if window_size > 25:
|
|
||||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
|
||||||
return fallback_adaptive_avg_pool2d(x, output_size)
|
|
||||||
|
|
||||||
fn_sum = _adaptive_pooling_idx_sum(
|
|
||||||
[h_kernel_max, w_kernel_max],
|
|
||||||
[h_start_index, w_start_index],
|
|
||||||
[h_end_index, w_end_index],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ones_loader = pad_adaptive_loader(ones_like(x))
|
ones_loader = pad_adaptive_loader(ones_like(x))
|
||||||
@ -4139,60 +4209,6 @@ def _adaptive_avg_pool2d(x, output_size):
|
|||||||
return rv
|
return rv
|
||||||
|
|
||||||
|
|
||||||
def _adaptive_pooling_idx_max(kernel_maxes, in_sizes, out_sizes, return_index, loader):
|
|
||||||
# NOTE: There is some duplication between this and addaptive_avg_pool2d and max_pool2d
|
|
||||||
# Look into refactoring/deduplication after #116418 is merged.
|
|
||||||
h_in, w_in = in_sizes
|
|
||||||
h_out, w_out = out_sizes
|
|
||||||
|
|
||||||
def start_index(index, out_dim, inp_dim):
|
|
||||||
return FloorDiv((index * inp_dim), out_dim)
|
|
||||||
|
|
||||||
def end_index(index, out_dim, inp_dim):
|
|
||||||
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
|
||||||
|
|
||||||
h_start_index_fn = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
|
|
||||||
h_end_index_fn = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
|
|
||||||
w_start_index_fn = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
|
|
||||||
w_end_index_fn = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
|
|
||||||
|
|
||||||
def fn_max(idx):
|
|
||||||
*prefix, bh, bw = idx
|
|
||||||
|
|
||||||
h_start_index = h_start_index_fn(bh)
|
|
||||||
h_end_index = h_end_index_fn(bh)
|
|
||||||
|
|
||||||
w_start_index = w_start_index_fn(bw)
|
|
||||||
w_end_index = w_end_index_fn(bw)
|
|
||||||
maxval = None
|
|
||||||
maxindex = None
|
|
||||||
for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
|
|
||||||
val = loader(
|
|
||||||
prefix,
|
|
||||||
[ih, iw],
|
|
||||||
[h_start_index, w_start_index],
|
|
||||||
[h_end_index, w_end_index],
|
|
||||||
)
|
|
||||||
index = ops.index_expr(
|
|
||||||
(h_start_index + ih) * w_in + w_start_index + iw, torch.int64
|
|
||||||
)
|
|
||||||
if return_index:
|
|
||||||
if maxindex is None:
|
|
||||||
maxindex = index
|
|
||||||
else:
|
|
||||||
maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
|
|
||||||
if maxval is None:
|
|
||||||
maxval = val
|
|
||||||
else:
|
|
||||||
maxval = ops.maximum(val, maxval)
|
|
||||||
if return_index:
|
|
||||||
return maxindex
|
|
||||||
else:
|
|
||||||
return maxval
|
|
||||||
|
|
||||||
return fn_max
|
|
||||||
|
|
||||||
|
|
||||||
fallback_adaptive_max_pool2d = fallback_handler(
|
fallback_adaptive_max_pool2d = fallback_handler(
|
||||||
aten.adaptive_max_pool2d.default, add_to_fallback_set=False
|
aten.adaptive_max_pool2d.default, add_to_fallback_set=False
|
||||||
)
|
)
|
||||||
@ -4245,32 +4261,46 @@ def adaptive_max_pool2d(x, output_size):
|
|||||||
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
# Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
|
||||||
return fallback_adaptive_max_pool2d(x, output_size)
|
return fallback_adaptive_max_pool2d(x, output_size)
|
||||||
|
|
||||||
inner_func_max_val = _adaptive_pooling_idx_max(
|
def start_index(index, out_dim, inp_dim):
|
||||||
|
return FloorDiv((index * inp_dim), out_dim)
|
||||||
|
|
||||||
|
def end_index(index, out_dim, inp_dim):
|
||||||
|
return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
|
||||||
|
|
||||||
|
inner_func_max_val = _adaptive_pooling_fn(
|
||||||
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||||
in_sizes=[h_in, w_in],
|
in_sizes=[h_in, w_in],
|
||||||
out_sizes=[h_out, w_out],
|
out_sizes=[h_out, w_out],
|
||||||
return_index=False,
|
pooling_fn=ops.maximum,
|
||||||
loader=pad_adaptive_loader(x, float("-inf")),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
inner_func_max_idx = _adaptive_pooling_idx_max(
|
inner_func_max_idx = _adaptive_pooling_fn_with_idx(
|
||||||
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
kernel_maxes=[h_kernel_max, w_kernel_max],
|
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||||
in_sizes=[h_in, w_in],
|
in_sizes=[h_in, w_in],
|
||||||
out_sizes=[h_out, w_out],
|
out_sizes=[h_out, w_out],
|
||||||
return_index=True,
|
pooling_fn=ops.maximum,
|
||||||
loader=pad_adaptive_loader(x, float("-inf")),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def inner_fn_max_val(idx):
|
||||||
|
return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
|
||||||
|
|
||||||
|
def inner_fn_max_idx(idx):
|
||||||
|
return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
|
||||||
|
|
||||||
rv = Pointwise.create(
|
rv = Pointwise.create(
|
||||||
device=x.get_device(),
|
device=x.get_device(),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
inner_fn=inner_func_max_val,
|
inner_fn=inner_fn_max_val,
|
||||||
ranges=new_size,
|
ranges=new_size,
|
||||||
)
|
)
|
||||||
ri = Pointwise.create(
|
ri = Pointwise.create(
|
||||||
device=x.get_device(),
|
device=x.get_device(),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
inner_fn=inner_func_max_idx,
|
inner_fn=inner_fn_max_idx,
|
||||||
ranges=new_size,
|
ranges=new_size,
|
||||||
)
|
)
|
||||||
return rv, ri
|
return rv, ri
|
||||||
@ -4400,16 +4430,13 @@ def upsample_nearest2d_backward(
|
|||||||
def end_index(index, out_dim, inp_dim):
|
def end_index(index, out_dim, inp_dim):
|
||||||
return start_index((index + 1), out_dim, inp_dim)
|
return start_index((index + 1), out_dim, inp_dim)
|
||||||
|
|
||||||
h_start_index = functools.partial(start_index, out_dim=out_h, inp_dim=inp_h)
|
fn_sum = _adaptive_pooling_fn(
|
||||||
h_end_index = functools.partial(end_index, out_dim=out_h, inp_dim=inp_h)
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
w_start_index = functools.partial(start_index, out_dim=out_w, inp_dim=inp_w)
|
kernel_maxes=[h_kernel_max, w_kernel_max],
|
||||||
w_end_index = functools.partial(end_index, out_dim=out_w, inp_dim=inp_w)
|
in_sizes=[inp_h, inp_w],
|
||||||
|
out_sizes=[out_h, out_w],
|
||||||
fn_sum = _adaptive_pooling_idx_sum(
|
pooling_fn=ops.add,
|
||||||
[h_kernel_max, w_kernel_max],
|
|
||||||
[h_start_index, w_start_index],
|
|
||||||
[h_end_index, w_end_index],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def fn(idx):
|
def fn(idx):
|
||||||
@ -4761,6 +4788,207 @@ def avg_pool2d_backward(
|
|||||||
return rv
|
return rv
|
||||||
|
|
||||||
|
|
||||||
|
fallback_avg_pool3d_backward = fallback_handler(
|
||||||
|
aten.avg_pool3d_backward.default, add_to_fallback_set=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
|
||||||
|
def avg_pool3d_backward(
|
||||||
|
grad_output,
|
||||||
|
x,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
ceil_mode,
|
||||||
|
count_include_pad,
|
||||||
|
divisor_override=None,
|
||||||
|
):
|
||||||
|
assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
|
||||||
|
if not stride:
|
||||||
|
stride = kernel_size
|
||||||
|
if not padding:
|
||||||
|
padding = [0, 0, 0]
|
||||||
|
|
||||||
|
assert isinstance(grad_output, TensorBox)
|
||||||
|
assert isinstance(x, TensorBox)
|
||||||
|
assert len(kernel_size) == 3
|
||||||
|
assert len(stride) == 3
|
||||||
|
assert len(padding) == 3
|
||||||
|
assert len(x.get_size()) in (4, 5)
|
||||||
|
|
||||||
|
grad_output.realize_hint()
|
||||||
|
|
||||||
|
*batch, depth, height, width = x.get_size()
|
||||||
|
|
||||||
|
d_out, ceil_mode_d = pooling_size(depth, 0, kernel_size, stride, padding, ceil_mode)
|
||||||
|
h_out, ceil_mode_h = pooling_size(
|
||||||
|
height, 1, kernel_size, stride, padding, ceil_mode
|
||||||
|
)
|
||||||
|
w_out, ceil_mode_w = pooling_size(width, 2, kernel_size, stride, padding, ceil_mode)
|
||||||
|
|
||||||
|
grad_loader = grad_output.make_loader()
|
||||||
|
had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
|
||||||
|
|
||||||
|
*_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
|
||||||
|
new_size = list(x.get_size())
|
||||||
|
dtype = x.get_dtype()
|
||||||
|
|
||||||
|
d_window_size, h_window_size, w_window_size = (
|
||||||
|
max(
|
||||||
|
max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
|
||||||
|
for d in range(kernel_size[i] * 2)
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
)
|
||||||
|
|
||||||
|
window_size = d_window_size * h_window_size * w_window_size
|
||||||
|
if window_size > 125:
|
||||||
|
# Kernel size too big. Results in hard-to-optimize Triton code.
|
||||||
|
return fallback_avg_pool3d_backward(
|
||||||
|
grad_output,
|
||||||
|
x,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
ceil_mode,
|
||||||
|
count_include_pad,
|
||||||
|
divisor_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_pool_size_without_padding(pd, ph, pw):
|
||||||
|
stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
|
||||||
|
pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
|
||||||
|
kernel_d, kernel_h, kernel_w = (
|
||||||
|
ops.constant(k, torch.int32) for k in kernel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
dstart, hstart, wstart = (
|
||||||
|
ops.sub(ops.mul(p, s), pad)
|
||||||
|
for p, s, pad in zip(
|
||||||
|
[pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
dend, hend, wend = (
|
||||||
|
ops.minimum(
|
||||||
|
ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
|
||||||
|
)
|
||||||
|
for start, k, dim, pad in zip(
|
||||||
|
[dstart, hstart, wstart],
|
||||||
|
[kernel_d, kernel_h, kernel_w],
|
||||||
|
[depth, height, width],
|
||||||
|
[pad_d, pad_h, pad_w],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
dstart, hstart, wstart = (
|
||||||
|
ops.maximum(start, ops.constant(0, torch.int32))
|
||||||
|
for start in [dstart, hstart, wstart]
|
||||||
|
)
|
||||||
|
dend, hend, wend = (
|
||||||
|
ops.minimum(end, ops.index_expr(dim, torch.int32))
|
||||||
|
for end, dim in zip([dend, hend, wend], [depth, height, width])
|
||||||
|
)
|
||||||
|
divide_factor = ops.mul(
|
||||||
|
ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
|
||||||
|
)
|
||||||
|
return divide_factor
|
||||||
|
|
||||||
|
def fn(idx):
|
||||||
|
*prefix, d, h, w = idx
|
||||||
|
d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
|
||||||
|
|
||||||
|
pdstart, phstart, pwstart = (
|
||||||
|
ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
|
||||||
|
for v, k, s in zip([d, h, w], kernel_size, stride)
|
||||||
|
)
|
||||||
|
|
||||||
|
pdend, phend, pwend = (
|
||||||
|
ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
|
||||||
|
for v, s in zip([d, h, w], stride)
|
||||||
|
)
|
||||||
|
|
||||||
|
pdstart, phstart, pwstart = (
|
||||||
|
ops.maximum(pstart, ops.constant(0, torch.int32))
|
||||||
|
for pstart in [pdstart, phstart, pwstart]
|
||||||
|
)
|
||||||
|
pdend, phend, pwend = (
|
||||||
|
ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
|
||||||
|
for pend, pooled_dim in zip(
|
||||||
|
[pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
gradient = None
|
||||||
|
# Iterate over the 3D region to accumulate gradients
|
||||||
|
for pd_ in range(d_window_size):
|
||||||
|
for ph_ in range(h_window_size):
|
||||||
|
for pw_ in range(w_window_size):
|
||||||
|
pd, ph, pw = (
|
||||||
|
ops.add(pstart, ops.constant(p_, torch.int32))
|
||||||
|
for pstart, p_ in zip(
|
||||||
|
[pdstart, phstart, pwstart], [pd_, ph_, pw_]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if divisor_override is not None:
|
||||||
|
scale = divisor_override
|
||||||
|
elif count_include_pad or not had_padding:
|
||||||
|
scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
|
||||||
|
else:
|
||||||
|
scale = compute_pool_size_without_padding(pd, ph, pw)
|
||||||
|
|
||||||
|
part = ops.truediv(
|
||||||
|
grad_loader(
|
||||||
|
[
|
||||||
|
*prefix,
|
||||||
|
ops.indirect_indexing(
|
||||||
|
ops.minimum(
|
||||||
|
pd, ops.sub(pdend, ops.constant(1, torch.int32))
|
||||||
|
),
|
||||||
|
pooled_depth,
|
||||||
|
check=False,
|
||||||
|
),
|
||||||
|
ops.indirect_indexing(
|
||||||
|
ops.minimum(
|
||||||
|
ph, ops.sub(phend, ops.constant(1, torch.int32))
|
||||||
|
),
|
||||||
|
pooled_height,
|
||||||
|
check=False,
|
||||||
|
),
|
||||||
|
ops.indirect_indexing(
|
||||||
|
ops.minimum(
|
||||||
|
pw, ops.sub(pwend, ops.constant(1, torch.int32))
|
||||||
|
),
|
||||||
|
pooled_width,
|
||||||
|
check=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = ops.and_(
|
||||||
|
ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
|
||||||
|
ops.lt(pw, pwend),
|
||||||
|
)
|
||||||
|
if gradient is None:
|
||||||
|
gradient = ops.where(
|
||||||
|
mask, part, ops.constant(0.0, torch.float32)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gradient = ops.where(mask, ops.add(gradient, part), gradient)
|
||||||
|
assert gradient is not None
|
||||||
|
return gradient
|
||||||
|
|
||||||
|
rv = Pointwise.create(
|
||||||
|
device=grad_output.get_device(),
|
||||||
|
dtype=dtype,
|
||||||
|
inner_fn=fn,
|
||||||
|
ranges=new_size,
|
||||||
|
)
|
||||||
|
return rv
|
||||||
|
|
||||||
|
|
||||||
def _validate_reduction_axis(x, axis):
|
def _validate_reduction_axis(x, axis):
|
||||||
size = x.get_size()
|
size = x.get_size()
|
||||||
if isinstance(axis, int):
|
if isinstance(axis, int):
|
||||||
|
|||||||
@ -1031,7 +1031,7 @@ def should_use_remote_autotune_cache(inductor_meta):
|
|||||||
if inductor_meta.get("is_hip"):
|
if inductor_meta.get("is_hip"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
from triton.runtime.fb_memcache import MEMCACHE_VERSION
|
from triton.fb.fb_memcache import MEMCACHE_VERSION
|
||||||
|
|
||||||
return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
|
return MEMCACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
|
||||||
"pytorch/remote_cache:autotune_memcache_version"
|
"pytorch/remote_cache:autotune_memcache_version"
|
||||||
@ -1075,8 +1075,12 @@ def cached_autotune(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if inductor_meta.get("is_fbcode"):
|
if inductor_meta.get("is_fbcode"):
|
||||||
remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
|
import triton.fb.fb_memcache
|
||||||
key
|
|
||||||
|
remote_cache = (
|
||||||
|
triton.fb.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
|
||||||
|
key
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from torch._inductor.remote_cache import RedisRemoteCacheBackend
|
from torch._inductor.remote_cache import RedisRemoteCacheBackend
|
||||||
|
|||||||
@ -193,7 +193,16 @@ class SizeVarAllocator:
|
|||||||
"""
|
"""
|
||||||
sizes = list(map(self.simplify, sizes))
|
sizes = list(map(self.simplify, sizes))
|
||||||
|
|
||||||
strides = [self.stride_vars(x, index_vars) for x in index_formulas]
|
strides = [
|
||||||
|
# index_formulas may contain boolean expressions (e.g. s0 < 10),
|
||||||
|
# for which "strides" don't make sense so we ignore them here.
|
||||||
|
# NOTE: These expressions may still block merging dims in the sound
|
||||||
|
# substitution test performed in can_merge_dims.
|
||||||
|
self.stride_vars(x, index_vars)
|
||||||
|
if isinstance(x, sympy.Expr)
|
||||||
|
else [0] * len(index_vars)
|
||||||
|
for x in index_formulas
|
||||||
|
]
|
||||||
assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
|
assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))
|
||||||
|
|
||||||
for i in range(len(sizes)):
|
for i in range(len(sizes)):
|
||||||
|
|||||||
@ -1046,17 +1046,17 @@ def type_to_dtype(typ: type) -> torch.dtype:
|
|||||||
|
|
||||||
assert isinstance(typ, type)
|
assert isinstance(typ, type)
|
||||||
|
|
||||||
if typ is bool:
|
if typ in (bool, torch.SymBool):
|
||||||
return torch.bool
|
return torch.bool
|
||||||
if typ in [int, torch.SymInt]:
|
if typ in (int, torch.SymInt):
|
||||||
return torch.long
|
return torch.long
|
||||||
if typ in [float, torch.SymFloat]:
|
if typ in (float, torch.SymFloat):
|
||||||
return torch.get_default_dtype()
|
return torch.get_default_dtype()
|
||||||
# TODO: sym_complex_float?
|
# TODO: sym_complex_float?
|
||||||
if typ is complex:
|
if typ is complex:
|
||||||
return corresponding_complex_dtype(torch.get_default_dtype())
|
return corresponding_complex_dtype(torch.get_default_dtype())
|
||||||
|
|
||||||
raise ValueError("Invalid type!")
|
raise ValueError(f"Invalid type {typ}!")
|
||||||
|
|
||||||
|
|
||||||
def get_dtype(x: Union[torch.Tensor, NumberType]):
|
def get_dtype(x: Union[torch.Tensor, NumberType]):
|
||||||
@ -1363,8 +1363,12 @@ def number_type(
|
|||||||
return type(x)
|
return type(x)
|
||||||
|
|
||||||
|
|
||||||
def expr_type(x: sympy.Expr) -> Type:
|
def expr_type(x: sympy.Basic) -> Type:
|
||||||
if x.is_integer: # type: ignore[attr-defined]
|
import sympy
|
||||||
|
|
||||||
|
if x.kind is sympy.core.kind.BooleanKind:
|
||||||
|
return bool
|
||||||
|
elif x.is_integer: # type: ignore[attr-defined]
|
||||||
return int
|
return int
|
||||||
else:
|
else:
|
||||||
# NB: Not strictly correct, but we don't support SymPy complex or bool.
|
# NB: Not strictly correct, but we don't support SymPy complex or bool.
|
||||||
@ -1471,13 +1475,13 @@ def elementwise_dtypes(
|
|||||||
import sympy
|
import sympy
|
||||||
|
|
||||||
for x in args:
|
for x in args:
|
||||||
if not isinstance(x, (Number, TensorLike, sympy.Expr)):
|
if not isinstance(x, (Number, TensorLike, sympy.Basic)):
|
||||||
msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
|
msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
if isinstance(x, Number):
|
if isinstance(x, Number):
|
||||||
highest_type = get_higher_type(highest_type, number_type(x))
|
highest_type = get_higher_type(highest_type, number_type(x))
|
||||||
elif isinstance(x, sympy.Expr):
|
elif isinstance(x, sympy.Basic):
|
||||||
highest_type = get_higher_type(highest_type, expr_type(x))
|
highest_type = get_higher_type(highest_type, expr_type(x))
|
||||||
else:
|
else:
|
||||||
# x is a TensorLike
|
# x is a TensorLike
|
||||||
|
|||||||
@ -67,6 +67,7 @@
|
|||||||
#include <torch/csrc/cpu/Module.h>
|
#include <torch/csrc/cpu/Module.h>
|
||||||
#include <torch/csrc/dynamo/init.h>
|
#include <torch/csrc/dynamo/init.h>
|
||||||
#include <torch/csrc/functorch/init.h>
|
#include <torch/csrc/functorch/init.h>
|
||||||
|
#include <torch/csrc/fx/node.h>
|
||||||
#include <torch/csrc/inductor/aoti_runner/pybind.h>
|
#include <torch/csrc/inductor/aoti_runner/pybind.h>
|
||||||
#include <torch/csrc/jit/python/init.h>
|
#include <torch/csrc/jit/python/init.h>
|
||||||
#include <torch/csrc/jit/python/python_ir.h>
|
#include <torch/csrc/jit/python/python_ir.h>
|
||||||
@ -1602,6 +1603,8 @@ PyObject* initModule() {
|
|||||||
THPDevice_init(module);
|
THPDevice_init(module);
|
||||||
THPStream_init(module);
|
THPStream_init(module);
|
||||||
THPEvent_init(module);
|
THPEvent_init(module);
|
||||||
|
NodeBase_init(module);
|
||||||
|
NodeIter_init(module);
|
||||||
ASSERT_TRUE(THPVariable_initModule(module));
|
ASSERT_TRUE(THPVariable_initModule(module));
|
||||||
ASSERT_TRUE(THPFunction_initModule(module));
|
ASSERT_TRUE(THPFunction_initModule(module));
|
||||||
ASSERT_TRUE(THPEngine_initModule(module));
|
ASSERT_TRUE(THPEngine_initModule(module));
|
||||||
|
|||||||
257
torch/csrc/fx/node.cpp
Normal file
257
torch/csrc/fx/node.cpp
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
#include <torch/csrc/fx/node.h>
|
||||||
|
|
||||||
|
#include <structmember.h>
|
||||||
|
#include <torch/csrc/utils/pythoncapi_compat.h>
|
||||||
|
|
||||||
|
////////////////////////////////
|
||||||
|
// NodeBase
|
||||||
|
///////////////////////////////
|
||||||
|
|
||||||
|
struct NodeBase {
|
||||||
|
PyObject_HEAD bool _erased;
|
||||||
|
NodeBase* _prev;
|
||||||
|
NodeBase* _next;
|
||||||
|
};
|
||||||
|
|
||||||
|
static PyObject* NodeBase_new(
|
||||||
|
PyTypeObject* type,
|
||||||
|
PyObject* args,
|
||||||
|
PyObject* kwds) {
|
||||||
|
PyObject* self = type->tp_alloc(type, 0);
|
||||||
|
if (!self)
|
||||||
|
return nullptr;
|
||||||
|
return self;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) {
|
||||||
|
self->_erased = false;
|
||||||
|
Py_INCREF(self);
|
||||||
|
self->_prev = self;
|
||||||
|
Py_INCREF(self);
|
||||||
|
self->_next = self;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||||
|
static struct PyMemberDef NodeBase_members[] = {
|
||||||
|
{"_erased", T_BOOL, offsetof(NodeBase, _erased), 0, nullptr},
|
||||||
|
{"_prev", T_OBJECT_EX, offsetof(NodeBase, _prev), 0, nullptr},
|
||||||
|
{"_next", T_OBJECT_EX, offsetof(NodeBase, _next), 0, nullptr},
|
||||||
|
{nullptr} /* Sentinel */
|
||||||
|
};
|
||||||
|
|
||||||
|
// tp_traverse slot for _NodeBase: report the two owned neighbour links to
// the cyclic garbage collector. The parameter names `visit` and `arg` are
// the ones the Py_VISIT macro expects.
static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) {
  Py_VISIT(self->_prev);
  Py_VISIT(self->_next);
  return 0;
}
|
||||||
|
|
||||||
|
// tp_clear slot for _NodeBase: release both neighbour references and leave
// the fields null. Also called from NodeBase_dealloc. Always returns 0.
static int NodeBase_clear(NodeBase* self) {
  Py_CLEAR(self->_prev);
  Py_CLEAR(self->_next);
  return 0;
}
|
||||||
|
|
||||||
|
// tp_dealloc slot for _NodeBase.
static void NodeBase_dealloc(PyObject* self) {
  // Stop GC tracking first so the collector cannot visit the object while
  // its fields are being torn down.
  PyObject_GC_UnTrack(self);
  // Release the neighbour links; the return value carries no information.
  (void)NodeBase_clear(reinterpret_cast<NodeBase*>(self));
  // Hand the memory back through the (possibly derived) type's tp_free.
  Py_TYPE(self)->tp_free(self);
}
|
||||||
|
|
||||||
|
// Type object for torch._C._NodeBase. It is subclassable (BASETYPE) and
// participates in cyclic GC (HAVE_GC); every slot not listed as a function
// below is left empty. Slots are positional, so the order must not change.
static PyTypeObject NodeBaseType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeBase", /* tp_name */
    sizeof(NodeBase), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)NodeBase_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)NodeBase_traverse, /* tp_traverse */
    (inquiry)NodeBase_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    NodeBase_members, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    (initproc)NodeBase_init_fn, /* tp_init */
    nullptr, /* tp_alloc */
    NodeBase_new, /* tp_new */
};
|
||||||
|
|
||||||
|
// Add the _NodeBase type to `module`. Returns true on success and false when
// PyModule_AddType reports failure (a negative result).
bool NodeBase_init(PyObject* module) {
  return PyModule_AddType(module, &NodeBaseType) >= 0;
}
|
||||||
|
|
||||||
|
////////////////////////////////
|
||||||
|
// NodeIter
|
||||||
|
////////////////////////////////
|
||||||
|
|
||||||
|
struct NodeIter {
|
||||||
|
PyObject_HEAD bool _reversed;
|
||||||
|
NodeBase* _root;
|
||||||
|
NodeBase* _cur;
|
||||||
|
};
|
||||||
|
|
||||||
|
// tp_new slot for _NodeIter: obtain storage from the type's own allocator.
// `args` and `kwds` are not inspected here; they are parsed in
// NodeIter_init_fn. A null result from tp_alloc is handed back unchanged.
static PyObject* NodeIter_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  return type->tp_alloc(type, 0);
}
|
||||||
|
|
||||||
|
// tp_init slot for _NodeIter: `_NodeIter(root, reversed=False)`.
//
//   root     - a _NodeBase (or subclass) instance; iteration starts after it
//              and stops when it is reached again.
//   reversed - optional truth value; when true the iterator follows `_prev`
//              links instead of `_next` links.
//
// Returns 0 on success, or -1 with a Python exception set when the
// arguments do not match.
static int NodeIter_init_fn(NodeIter* self, PyObject* args, PyObject* kwargs) {
  PyObject* root = nullptr;
  // "p" stores an int, so parse into an int rather than directly into a bool.
  int reversed = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  constexpr const char* keywords[] = {"root", "reversed", nullptr};
  // "O!" rejects anything that is not a _NodeBase, because the iterator
  // later dereferences `root` as a NodeBase*. The "|" must precede the
  // second unit so that `reversed` is genuinely optional.
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "O!|p",
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(keywords),
          &NodeBaseType,
          &root,
          &reversed)) {
    return -1;
  }

  // __init__ may run more than once; remember what was stored before so it
  // can be released instead of leaked.
  NodeBase* stale_root = self->_root;
  NodeBase* stale_cur = self->_cur;

  NodeBase* start = reinterpret_cast<NodeBase*>(root);
  self->_reversed = reversed != 0;
  // `_root` and `_cur` each own a separate reference.
  Py_INCREF(root);
  self->_root = start;
  Py_INCREF(root);
  self->_cur = start;

  Py_XDECREF((PyObject*)stale_root);
  Py_XDECREF((PyObject*)stale_cur);
  return 0;
}
|
||||||
|
|
||||||
|
template <bool reversed>
|
||||||
|
PyObject* NodeIter_iternext_helper(NodeIter* self) {
|
||||||
|
// It should be possible to relax the ref counting here
|
||||||
|
// but in practice, we do not have that many _erased Nodes,
|
||||||
|
// so probably not worth it.
|
||||||
|
if constexpr (reversed) {
|
||||||
|
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||||
|
Py_CLEAR(self->_cur);
|
||||||
|
self->_cur = prev;
|
||||||
|
} else {
|
||||||
|
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||||
|
Py_CLEAR(self->_cur);
|
||||||
|
self->_cur = next;
|
||||||
|
}
|
||||||
|
while (self->_cur != self->_root) {
|
||||||
|
if (!self->_cur->_erased) {
|
||||||
|
Py_INCREF(self->_cur);
|
||||||
|
return (PyObject*)self->_cur;
|
||||||
|
}
|
||||||
|
if constexpr (reversed) {
|
||||||
|
NodeBase* prev = (NodeBase*)Py_NewRef(self->_cur->_prev);
|
||||||
|
Py_CLEAR(self->_cur);
|
||||||
|
self->_cur = prev;
|
||||||
|
} else {
|
||||||
|
NodeBase* next = (NodeBase*)Py_NewRef(self->_cur->_next);
|
||||||
|
Py_CLEAR(self->_cur);
|
||||||
|
self->_cur = next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PyErr_SetNone(PyExc_StopIteration);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tp_iternext slot for _NodeIter: pick the helper instantiation that matches
// the direction stored on the iterator and forward its result.
PyObject* NodeIter_iternext(PyObject* _self) {
  NodeIter* iter = (NodeIter*)_self;
  return iter->_reversed ? NodeIter_iternext_helper<true>(iter)
                         : NodeIter_iternext_helper<false>(iter);
}
|
||||||
|
|
||||||
|
// tp_traverse slot for _NodeIter: report the two owned node references to
// the cyclic garbage collector. The parameter names `visit` and `arg` are
// the ones the Py_VISIT macro expects.
static int NodeIter_traverse(NodeIter* self, visitproc visit, void* arg) {
  Py_VISIT(self->_root);
  Py_VISIT(self->_cur);
  return 0;
}
|
||||||
|
|
||||||
|
// tp_clear slot for _NodeIter: release the root and current-position
// references and leave both fields null. Also called from NodeIter_dealloc.
// Always returns 0.
static int NodeIter_clear(NodeIter* self) {
  Py_CLEAR(self->_root);
  Py_CLEAR(self->_cur);
  return 0;
}
|
||||||
|
|
||||||
|
// tp_dealloc slot for _NodeIter.
static void NodeIter_dealloc(PyObject* self) {
  // Stop GC tracking first so the collector cannot visit the object while
  // its fields are being torn down.
  PyObject_GC_UnTrack(self);
  // Release the node references; the return value carries no information.
  (void)NodeIter_clear(reinterpret_cast<NodeIter*>(self));
  // Hand the memory back through the type's tp_free.
  Py_TYPE(self)->tp_free(self);
}
|
||||||
|
|
||||||
|
// Type object for torch._C._NodeIter. It participates in cyclic GC, is not
// subclassable (no BASETYPE flag), returns itself from iter() and produces
// items through NodeIter_iternext. Slots are positional, so the order must
// not change.
static PyTypeObject NodeIterType = {
    PyVarObject_HEAD_INIT(nullptr, 0)
    "torch._C._NodeIter", /* tp_name */
    sizeof(NodeIter), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)NodeIter_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)NodeIter_traverse, /* tp_traverse */
    (inquiry)NodeIter_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    PyObject_SelfIter, /* tp_iter */
    NodeIter_iternext, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    (initproc)NodeIter_init_fn, /* tp_init */
    nullptr, /* tp_alloc */
    NodeIter_new, /* tp_new */
};
|
||||||
|
|
||||||
|
// Add the _NodeIter type to `module`. Returns true on success and false when
// PyModule_AddType reports failure (a negative result).
bool NodeIter_init(PyObject* module) {
  return PyModule_AddType(module, &NodeIterType) >= 0;
}
|
||||||
6
torch/csrc/fx/node.h
Normal file
6
torch/csrc/fx/node.h
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <torch/csrc/python_headers.h>
|
||||||
|
|
||||||
|
// Adds the torch._C._NodeBase type to `module`; returns false on failure.
bool NodeBase_init(PyObject* module);

// Adds the torch._C._NodeIter type to `module`; returns false on failure.
bool NodeIter_init(PyObject* module);
|
||||||
@ -86,6 +86,7 @@ def _insert_copy_for_mutations(
|
|||||||
assert len(outputs) == len(mutated_outputs)
|
assert len(outputs) == len(mutated_outputs)
|
||||||
|
|
||||||
user_output_nodes = []
|
user_output_nodes = []
|
||||||
|
return_nodes_to_copy = {}
|
||||||
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
|
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
|
||||||
if mutated_node_name is None:
|
if mutated_node_name is None:
|
||||||
user_output_nodes.append(return_node)
|
user_output_nodes.append(return_node)
|
||||||
@ -101,13 +102,18 @@ def _insert_copy_for_mutations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with gm.graph.inserting_before(output_node):
|
with gm.graph.inserting_before(output_node):
|
||||||
_ = gm.graph.call_function(
|
copy_node = gm.graph.call_function(
|
||||||
torch.ops.aten.copy_.default, (mutated_node, return_node)
|
torch.ops.aten.copy_.default, (mutated_node, return_node)
|
||||||
)
|
)
|
||||||
|
return_nodes_to_copy[return_node] = copy_node
|
||||||
|
|
||||||
|
output_args = [
|
||||||
|
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
|
||||||
|
for node in user_output_nodes
|
||||||
|
]
|
||||||
with gm.graph.inserting_before(output_node):
|
with gm.graph.inserting_before(output_node):
|
||||||
# Only return user outputs
|
# Only return user outputs
|
||||||
new_output = gm.graph.output(tuple(user_output_nodes))
|
new_output = gm.graph.output(tuple(output_args))
|
||||||
output_node.replace_all_uses_with(new_output)
|
output_node.replace_all_uses_with(new_output)
|
||||||
gm.graph.erase_node(output_node)
|
gm.graph.erase_node(output_node)
|
||||||
|
|
||||||
|
|||||||
@ -453,20 +453,21 @@ def free_unbacked_symbols(x):
|
|||||||
# setup!
|
# setup!
|
||||||
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
|
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
|
||||||
if (
|
if (
|
||||||
node.op == "placeholder" and
|
|
||||||
"val" in node.meta and
|
"val" in node.meta and
|
||||||
isinstance(node.meta["val"], torch.SymInt) and
|
isinstance(node.meta["val"], torch.SymInt) and
|
||||||
isinstance(node.meta["val"].node.expr, sympy.Symbol)
|
isinstance(node.meta["val"].node.expr, sympy.Symbol) and
|
||||||
|
(node.op == "placeholder" or free_unbacked_symbols(node.meta["val"].node.expr))
|
||||||
):
|
):
|
||||||
return node.meta["val"].node.expr
|
return node.meta["val"].node.expr
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_symbol_binding_fx_nodes(graph):
|
def find_symbol_binding_fx_nodes(graph):
|
||||||
return {
|
r = {}
|
||||||
node.meta["val"].node.expr: node
|
# NB: Prefer first occurrence of symbol
|
||||||
for node in graph.nodes
|
for node in graph.nodes:
|
||||||
if is_symbol_binding_fx_node(node)
|
if is_symbol_binding_fx_node(node) and node.meta["val"].node.expr not in r:
|
||||||
}
|
r[node.meta["val"].node.expr] = node
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
# Analogous to ConvertIntSource
|
# Analogous to ConvertIntSource
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_na
|
|||||||
import torch.utils._pytree as pytree
|
import torch.utils._pytree as pytree
|
||||||
from . import _pytree as fx_pytree
|
from . import _pytree as fx_pytree
|
||||||
from ._compatibility import compatibility
|
from ._compatibility import compatibility
|
||||||
|
from torch._C import _NodeIter
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import contextlib
|
import contextlib
|
||||||
@ -271,20 +272,8 @@ class _node_list:
|
|||||||
return self.graph._len
|
return self.graph._len
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
root = self.graph._root
|
assert self.direction == "_prev" or self.direction == "_next"
|
||||||
if self.direction == "_next":
|
yield from _NodeIter(self.graph._root, self.direction == "_prev")
|
||||||
cur = root._next
|
|
||||||
while cur is not root:
|
|
||||||
if not cur._erased:
|
|
||||||
yield cur
|
|
||||||
cur = cur._next
|
|
||||||
else:
|
|
||||||
assert self.direction == "_prev"
|
|
||||||
cur = root._prev
|
|
||||||
while cur is not root:
|
|
||||||
if not cur._erased:
|
|
||||||
yield cur
|
|
||||||
cur = cur._prev
|
|
||||||
|
|
||||||
def __reversed__(self):
|
def __reversed__(self):
|
||||||
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
|
return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import inspect
|
|||||||
import warnings
|
import warnings
|
||||||
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
|
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
|
||||||
from .._ops import ops as _ops
|
from .._ops import ops as _ops
|
||||||
|
from torch._C import _NodeBase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .graph import Graph
|
from .graph import Graph
|
||||||
@ -139,7 +140,7 @@ def _format_arg(arg, max_list_len=float('inf')) -> str:
|
|||||||
return str(arg)
|
return str(arg)
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
class Node:
|
class Node(_NodeBase):
|
||||||
"""
|
"""
|
||||||
``Node`` is the data structure that represents individual operations within
|
``Node`` is the data structure that represents individual operations within
|
||||||
a ``Graph``. For the most part, Nodes represent callsites to various entities,
|
a ``Graph``. For the most part, Nodes represent callsites to various entities,
|
||||||
@ -197,6 +198,7 @@ class Node:
|
|||||||
annotation of values in the generated code or for other types
|
annotation of values in the generated code or for other types
|
||||||
of analyses.
|
of analyses.
|
||||||
"""
|
"""
|
||||||
|
super().__init__()
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.name = name # unique name of value being created
|
self.name = name # unique name of value being created
|
||||||
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
|
assert op in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']
|
||||||
@ -235,9 +237,6 @@ class Node:
|
|||||||
# does not produce a value, it's more of a notation. Thus, this value
|
# does not produce a value, it's more of a notation. Thus, this value
|
||||||
# describes the type of args[0] in the ``return`` node.
|
# describes the type of args[0] in the ``return`` node.
|
||||||
self.type : Optional[Any] = return_type
|
self.type : Optional[Any] = return_type
|
||||||
self._prev = self
|
|
||||||
self._next = self
|
|
||||||
self._erased = False
|
|
||||||
self._sort_key: Any = ()
|
self._sort_key: Any = ()
|
||||||
|
|
||||||
# If set, use this fn to print this node
|
# If set, use this fn to print this node
|
||||||
@ -247,6 +246,22 @@ class Node:
|
|||||||
# transformations. This metadata is preserved across node copies
|
# transformations. This metadata is preserved across node copies
|
||||||
self.meta : Dict[str, Any] = {}
|
self.meta : Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["_erased"] = self._erased
|
||||||
|
state["_prev"] = self._prev
|
||||||
|
state["_next"] = self._next
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
_erased = state.pop("_erased")
|
||||||
|
_prev = state.pop("_prev")
|
||||||
|
_next = state.pop("_next")
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._erased = _erased
|
||||||
|
self._prev = _prev
|
||||||
|
self._next = _next
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next(self) -> 'Node':
|
def next(self) -> 'Node':
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -89,7 +89,10 @@ def sympy_generic_le(lower, upper):
|
|||||||
return lower <= upper
|
return lower <= upper
|
||||||
else:
|
else:
|
||||||
# only negative condition is True > False
|
# only negative condition is True > False
|
||||||
assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean)
|
assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), (
|
||||||
|
lower,
|
||||||
|
upper,
|
||||||
|
)
|
||||||
return not (lower and not upper)
|
return not (lower and not upper)
|
||||||
|
|
||||||
|
|
||||||
@ -945,6 +948,8 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis):
|
|||||||
if dtype == torch.bool:
|
if dtype == torch.bool:
|
||||||
if x.is_singleton():
|
if x.is_singleton():
|
||||||
return ValueRanges.wrap(x.lower != 0)
|
return ValueRanges.wrap(x.lower != 0)
|
||||||
|
elif x.is_bool:
|
||||||
|
return x
|
||||||
elif 0 not in x:
|
elif 0 not in x:
|
||||||
return ValueRanges.wrap(sympy.true)
|
return ValueRanges.wrap(sympy.true)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user