mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Support for CausalBias to torch compile (#116071)
Fixes #115363 Pull Request resolved: https://github.com/pytorch/pytorch/pull/116071 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
67d8db9252
commit
126c1621ce
@ -30,8 +30,10 @@ from torch.testing._internal.common_utils import (
|
||||
set_default_dtype,
|
||||
gradcheck,
|
||||
make_tensor,
|
||||
NOTEST_CPU
|
||||
NOTEST_CPU,
|
||||
IS_WINDOWS
|
||||
)
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
|
||||
|
||||
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
|
||||
@ -3206,11 +3208,19 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
|
||||
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
|
||||
|
||||
class TestAttnMasks(NNTestCase):
|
||||
class TestAttnBias(NNTestCase):
|
||||
|
||||
def run_test(self, device, compile, make_q, make_kv, attn_bias=None,
|
||||
forw_tolerances: Optional[Tolerances] = None, grad_tolerances: Optional[Tolerances] = None):
|
||||
if compile:
|
||||
def run_test(
|
||||
self,
|
||||
device,
|
||||
make_q,
|
||||
make_kv,
|
||||
attn_bias=None,
|
||||
forw_tolerances: Optional[Tolerances] = None,
|
||||
grad_tolerances: Optional[Tolerances] = None,
|
||||
backend=None,
|
||||
):
|
||||
if backend is not None:
|
||||
torch._dynamo.reset()
|
||||
|
||||
query, key, value = make_q(), make_kv(), make_kv()
|
||||
@ -3222,8 +3232,8 @@ class TestAttnMasks(NNTestCase):
|
||||
)
|
||||
|
||||
sdpa_op = (
|
||||
torch.compile(scaled_dot_product_attention, fullgraph=True)
|
||||
if compile
|
||||
torch.compile(scaled_dot_product_attention, backend=backend)
|
||||
if backend is not None
|
||||
else scaled_dot_product_attention
|
||||
)
|
||||
sdpa_output = sdpa_op(
|
||||
@ -3278,15 +3288,16 @@ class TestAttnMasks(NNTestCase):
|
||||
else:
|
||||
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
|
||||
|
||||
self.run_test(device, False, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
|
||||
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
|
||||
|
||||
@unittest.skip("This test fails on some parameters and on some CI machines")
|
||||
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
|
||||
@parametrize(
|
||||
"shape",
|
||||
[(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)],
|
||||
)
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
|
||||
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
|
||||
cnts = CompileCounterWithBackend("aot_eager")
|
||||
make_tensor = partial(
|
||||
torch.rand, device=device, dtype=torch.float16, requires_grad=True
|
||||
)
|
||||
@ -3306,7 +3317,8 @@ class TestAttnMasks(NNTestCase):
|
||||
else:
|
||||
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
|
||||
|
||||
self.run_test(device, True, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol)
|
||||
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
|
||||
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
|
||||
|
||||
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
|
||||
def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
|
||||
@ -3354,7 +3366,7 @@ instantiate_device_type_tests(TestTransformers, globals(), only_for=device_types
|
||||
instantiate_device_type_tests(TestSDPAFailureModes, globals(), only_for=device_types)
|
||||
instantiate_device_type_tests(TestSDPA, globals(), only_for=device_types)
|
||||
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
|
||||
instantiate_device_type_tests(TestAttnMasks, globals(), only_for=device_types)
|
||||
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
||||
@ -590,7 +590,7 @@ class VariableBuilder:
|
||||
value.device,
|
||||
source=self.source,
|
||||
)
|
||||
elif isinstance(value, torch._C._SDPAParams):
|
||||
elif isinstance(value, (torch._C._SDPAParams)):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return SDPAParamsVariable.create(self.tx, value, self.source)
|
||||
elif isinstance(value, _EventBase):
|
||||
@ -1533,6 +1533,12 @@ def wrap_fx_proxy_cls(
|
||||
|
||||
proxy.node.meta["example_value"] = example_value
|
||||
return SDPAParamsVariable(proxy, **options)
|
||||
elif isinstance(example_value, bool) and proxy.node.target in [
|
||||
torch.backends.cuda.can_use_flash_attention,
|
||||
torch.backends.cuda.can_use_efficient_attention,
|
||||
]:
|
||||
proxy.node.meta["example_value"] = example_value
|
||||
return ConstantVariable.create(example_value, **options)
|
||||
else:
|
||||
unimplemented(
|
||||
"torch.* op returned non-Tensor "
|
||||
|
||||
@ -8,6 +8,7 @@ from .parallel import DataParallel as DataParallel
|
||||
from . import init
|
||||
from . import functional
|
||||
from . import utils
|
||||
from . import attention
|
||||
|
||||
|
||||
def factory_kwargs(kwargs):
|
||||
|
||||
@ -21,6 +21,11 @@ from torch.nn.functional import scaled_dot_product_attention
|
||||
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
|
||||
|
||||
|
||||
torch._dynamo.allow_in_graph(can_use_flash_attention)
|
||||
torch._dynamo.allow_in_graph(can_use_efficient_attention)
|
||||
torch._dynamo.allow_in_graph(SDPAParams)
|
||||
|
||||
|
||||
class CausalVariant(IntEnum):
|
||||
r"""
|
||||
Enum for causal variants used in attention mechanisms.
|
||||
@ -74,7 +79,7 @@ class CausalVariant(IntEnum):
|
||||
LOWER_RIGHT = auto()
|
||||
|
||||
|
||||
class CausalBias(torch.Tensor):
|
||||
class CausalBias:
|
||||
"""
|
||||
A bias representing causal attention patterns. For an overview of the bias structure, see the :class:`CausalVariant` enum.
|
||||
|
||||
@ -207,9 +212,7 @@ class CausalBias(torch.Tensor):
|
||||
scale=scale,
|
||||
)
|
||||
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
|
||||
_validate_sdpa_input(
|
||||
query, key, value, attn_mask, dropout_p, is_causal, scale
|
||||
)
|
||||
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
|
||||
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
|
||||
if can_use_flash_attention(sdpa_params):
|
||||
needs_padding = query.size(-1) % 8 != 0
|
||||
|
||||
@ -2454,6 +2454,11 @@ dynamo_expected_failures = {
|
||||
"LoggingTests.test_dynamo_info", # dynamo/test_logging
|
||||
"LoggingTests.test_graph_breaks", # dynamo/test_logging
|
||||
"LoggingTests.test_aot", # dynamo/test_logging
|
||||
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape2_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape3_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_is_causal_and_mask_fails_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape1_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_is_causal_equals_upper_left_shape0_cpu", # test_transformers.py
|
||||
}
|
||||
|
||||
# see NOTE [dynamo_test_failures.py] for more details
|
||||
@ -2522,23 +2527,40 @@ dynamo_skips = {
|
||||
"TestPruningNN.test_global_pruning_importance_scores", # flaky
|
||||
"TestOpenMP_ParallelFor.test_one_thread", # test_openmp
|
||||
"TestTorchrun.test_multi_threads", # backends/xeon/test_launch
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape0_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape2_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape0_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_cpu", # known py38 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape0_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape2_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape0_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape3_cpu", # known py311 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape3_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape0_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape1_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape2_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape3_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape2_CUDA", # known py38 fail
|
||||
"TestAttnBiasCUDA.test_causal_variants_compile_causal_variant_CausalVariant_UPPER_LEFT_shape1_CUDA", # known py38 fail
|
||||
"TestTransformersCPU.test_decoder_padding_and_src_mask_bool_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape3_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape3_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape2_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_1_shape1_cpu", # known py311 fail
|
||||
"TestAttnMasksCPU.test_causal_variants_causal_variant_2_shape1_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape3_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape3_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape2_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_1_shape1_cpu", # known py311 fail
|
||||
"TestAttnBiasCPU.test_causal_variants_causal_variant_2_shape1_cpu", # known py311 fail
|
||||
"TestFunctionalAutogradBenchmark.test_fast_tasks", # flaky?
|
||||
"TestFrameworkUtils.test_filtering_env_var", # known py38 fail
|
||||
"TestAsArrayCPU.test_default_device_cpu", # known py38 fail
|
||||
@ -7296,6 +7318,13 @@ dynamo_skips = {
|
||||
"TestFrozenOptimizations.test_conv_bn_folding", # test_jit.py
|
||||
"TestArgmax.test_combinations_data58",
|
||||
"TestArgmax.test_combinations_data61",
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape0_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape0_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape3_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape2_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape2_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_2_shape1_cpu", # test_transformers.py
|
||||
"TestAttnBiasCPU.test_causal_variants_compile_causal_variant_1_shape1_cpu", # test_transformers.py
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user