Add Support for CausalBias to torch compile (#116071)

Fixes #115363

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071
Approved by: https://github.com/mlazos
drisspg
2024-01-30 02:22:45 +00:00
committed by PyTorch MergeBot
parent 67d8db9252
commit 126c1621ce
5 changed files with 83 additions and 32 deletions
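In practical terms, the change lets torch.compile trace scaled_dot_product_attention when the attn_mask is a CausalBias produced by the torch.nn.attention.bias helpers. A minimal sketch of that workflow, assuming a CUDA-capable build; shapes and dtype are illustrative:

    import torch
    from torch.nn.attention.bias import causal_lower_right
    from torch.nn.functional import scaled_dot_product_attention

    device, dtype = "cuda", torch.float16
    # Query/key lengths differ (23 vs 56), so a lower-right aligned causal bias is used.
    q = torch.rand(1, 1, 23, 16, device=device, dtype=dtype)
    k = torch.rand(1, 1, 56, 16, device=device, dtype=dtype)
    v = torch.rand(1, 1, 56, 16, device=device, dtype=dtype)
    attn_bias = causal_lower_right(23, 56)

    # Before this change, passing a CausalBias here caused a graph break under torch.compile.
    compiled_sdpa = torch.compile(scaled_dot_product_attention)
    out = compiled_sdpa(q, k, v, attn_mask=attn_bias)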

View File

@@ -30,8 +30,10 @@ from torch.testing._internal.common_utils import (
set_default_dtype,
gradcheck,
make_tensor,
NOTEST_CPU
NOTEST_CPU,
IS_WINDOWS
)
from torch._dynamo.testing import CompileCounterWithBackend
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
@@ -3206,11 +3208,19 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
class TestAttnMasks(NNTestCase):
class TestAttnBias(NNTestCase):
def run_test(self, device, compile, make_q, make_kv, attn_bias=None,
forw_tolerances: Optional[Tolerances] = None, grad_tolerances: Optional[Tolerances] = None):
if compile:
def run_test(
self,
device,
make_q,
make_kv,
attn_bias=None,
forw_tolerances: Optional[Tolerances] = None,
grad_tolerances: Optional[Tolerances] = None,
backend=None,
):
if backend is not None:
torch._dynamo.reset()
query, key, value = make_q(), make_kv(), make_kv()
@@ -3222,8 +3232,8 @@ class TestAttnMasks(NNTestCase):
)
sdpa_op = (
torch.compile(scaled_dot_product_attention, fullgraph=True)
if compile
torch.compile(scaled_dot_product_attention, backend=backend)
if backend is not None
else scaled_dot_product_attention
)
sdpa_output = sdpa_op(
@@ -3278,15 +3288,16 @@ class TestAttnMasks(NNTestCase):
else:
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
self.run_test(device, False, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
@unittest.skip("This test fails on some parameters and on some CI machines")
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize(
"shape",
[(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
)
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
cnts = CompileCounterWithBackend("aot_eager")
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
)
@@ -3306,7 +3317,8 @@ class TestAttnMasks(NNTestCase):
else:
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
self.run_test(device, True, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
@@ -3354,7 +3366,7 @@ instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestAttnMasks, globals(), only_for=device_types)
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
if __name__ == '__main__':
run_tests()
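The rewritten test threads a Dynamo backend through run_test instead of a boolean compile flag, which lets it count the frames Dynamo captures. A stand-alone sketch of that pattern (shapes illustrative; CPU float32 so it runs anywhere):

    import torch
    from torch._dynamo.testing import CompileCounterWithBackend
    from torch.nn.attention.bias import causal_upper_left
    from torch.nn.functional import scaled_dot_product_attention

    cnts = CompileCounterWithBackend("aot_eager")  # wraps aot_eager and counts frames
    compiled_sdpa = torch.compile(scaled_dot_product_attention, backend=cnts)

    q = torch.rand(2, 4, 128, 32)
    kv = torch.rand(2, 4, 128, 32)
    out = compiled_sdpa(q, kv, kv, attn_mask=causal_upper_left(128, 128))

    # A single frame means the CausalBias path compiled without graph breaks.
    assert cnts.frame_count == 1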

View File

@@ -590,7 +590,7 @@ class VariableBuilder:
value.device,
source=self.source,
)
elif isinstance(value, torch._C._SDPAParams):
elif isinstance(value, (torch._C._SDPAParams)):
self.install_guards(GuardBuilder.TYPE_MATCH)
return SDPAParamsVariable.create(self.tx, value, self.source)
elif isinstance(value, _EventBase):
@@ -1533,6 +1533,12 @@ def wrap_fx_proxy_cls(
proxy.node.meta["example_value"] = example_value
return SDPAParamsVariable(proxy, **options)
elif isinstance(example_value, bool) and proxy.node.target in [
torch.backends.cuda.can_use_flash_attention,
torch.backends.cuda.can_use_efficient_attention,
]:
proxy.node.meta["example_value"] = example_value
return ConstantVariable.create(example_value, **options)
else:
unimplemented(
"torch.* op returned non-Tensor "

View File

@@ -8,6 +8,7 @@ from .parallel import DataParallel as DataParallel
from . import init
from . import functional
from . import utils
from . import attention
def factory_kwargs(kwargs):
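The one-line addition to torch.nn's __init__ makes the new attention namespace reachable as soon as torch.nn is imported. A quick check, with the bias submodule imported explicitly in case it is not pulled in eagerly:

    import torch
    import torch.nn.attention.bias as attn_bias

    assert hasattr(torch.nn, "attention")
    mask = attn_bias.causal_upper_left(4, 4)
    print(type(mask).__name__)  # CausalBias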

View File

@@ -21,6 +21,11 @@ from torch.nn.functional import scaled_dot_product_attention
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)
class CausalVariant(IntEnum):
r"""
Enum for causal variants used in attention mechanisms.
@@ -74,7 +79,7 @@ class CausalVariant(IntEnum):
LOWER_RIGHT = auto()
class CausalBias(torch.Tensor):
class CausalBias:
"""
A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.
@@ -207,9 +212,7 @@ class CausalBias(torch.Tensor):
scale=scale,
)
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(
query, key, value, attn_mask, dropout_p, is_causal, scale
)
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0
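The lower-right branch now validates the inputs with attn_mask=None (the causal structure is carried by is_causal plus padding rather than an explicit mask), builds SDPAParams without a mask, and pads the head dimension when flash attention is eligible. A sketch of just the padding condition behind needs_padding; the real dispatch lives in the CausalBias code above:

    import torch
    import torch.nn.functional as F

    def pad_head_dim_to_multiple_of_8(t: torch.Tensor) -> torch.Tensor:
        # Flash attention kernels want head_dim % 8 == 0, so the lower-right path
        # pads the last dimension (and slices the padding off the output afterwards).
        pad = (-t.size(-1)) % 8
        return F.pad(t, (0, pad)) if pad else t

    q = torch.rand(1, 1, 23, 15)  # head_dim 15 -> needs_padding is True
    assert pad_head_dim_to_multiple_of_8(q).size(-1) == 16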

View File

@@ -2454,6 +2454,11 @@ dynamo_expected_failures = {
"LoggingTests.test_dynamo_info", # dynamo/test_logging
"LoggingTests.test_graph_breaks", # dynamo/test_logging
"LoggingTests.test_aot", # dynamo/test_logging
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape2_cpu", # test_transformers.py
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape3_cpu", # test_transformers.py
"TestAttnBiasCPU.test_is_causal_and_mask_fails_cpu", # test_transformers.py
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape1_cpu", # test_transformers.py
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape0_cpu", # test_transformers.py
}
# see NOTE [dynamo_test_failures.py] for more details
@@ -2522,23 +2527,40 @@ dynamo_skips = {
"TestPruningNN.test_global_pruning_importance_scores", # flaky
"TestOpenMP_ParallelFor.test_one_thread", # test_openmp
"TestTorchrun.test_multi_threads", # backends/xeon/test_launch
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape0_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape2_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape0_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape0_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape2_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape0_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape3_cpu", # known py311 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_CUDA", # known py38 fail
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_CUDA", # known py38 fail
"TestTransformersCPU.test_decoder_padding_and_src_mask_bool_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape3_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape3_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape2_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape1_cpu", # known py311 fail
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape1_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape3_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape3_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape2_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape1_cpu", # known py311 fail
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape1_cpu", # known py311 fail
"TestFunctionalAutogradBenchmark.test_fast_tasks", # flaky?
"TestFrameworkUtils.test_filtering_env_var", # known py38 fail
"TestAsArrayCPU.test_default_device_cpu", # known py38 fail
@@ -7296,6 +7318,13 @@ dynamo_skips = {
"TestFrozenOptimizations.test_conv_bn_folding", # test_jit.py
"TestArgmax.test_combinations_data58",
"TestArgmax.test_combinations_data61",
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape0_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape0_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape3_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape2_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape2_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape1_cpu", # test_transformers.py
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape1_cpu", # test_transformers.py
}