mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/157636 Approved by: https://github.com/yewentao256, https://github.com/mlazos ghstack dependencies: #156311, #156609
		
			
				
	
	
		
			2839 lines
		
	
	
		
			104 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			2839 lines
		
	
	
		
			104 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: vmap"]
 | |
| # ruff: noqa: F841
 | |
| 
 | |
| import functools
 | |
| import itertools
 | |
| import types
 | |
| import warnings
 | |
| 
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| from torch import Tensor
 | |
| from torch._vmap_internals import vmap
 | |
| from torch.testing._internal.common_device_type import instantiate_device_type_tests
 | |
| from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
 | |
| 
 | |
| 
 | |
# Regex searched for in the message of the warning emitted by the vmap
# fallback path when fallback warnings are enabled.
FALLBACK_REGEX: str = "There is a performance drop"
 | |
| 
 | |
| 
 | |
class EnableVmapFallbackWarnings:
    """Context manager that temporarily enables vmap fallback warnings.

    On entry the current setting is remembered and the warnings are switched
    on; on exit (normal or via an exception) the remembered setting is
    restored.
    """

    def __enter__(self):
        saved = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        self.prev_state = saved
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *exc_info):
        # Returning None means any exception raised inside the `with` body
        # still propagates after the setting is restored.
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)
 | |
| 
 | |
| 
 | |
| class TestVmapAPILegacy(TestCase):
 | |
|     def test_non_tensor_output_raises(self):
 | |
|         with self.assertRaisesRegex(
 | |
|             ValueError, "got type <class 'float'> as the return"
 | |
|         ):
 | |
|             output = vmap(lambda x: 3.14)(torch.ones(3))
 | |
| 
 | |
|         def multiple_outputs(x):
 | |
|             return x, 3
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
 | |
|             vmap(multiple_outputs)(torch.ones(3))
 | |
| 
 | |
|     def test_different_map_dim_size_raises(self):
 | |
|         x = torch.randn(2)
 | |
|         y = torch.randn(3)
 | |
|         expected_msg = (
 | |
|             "Expected all tensors to have the same size in the mapped dimension"
 | |
|         )
 | |
|         with self.assertRaisesRegex(ValueError, expected_msg):
 | |
|             vmap(torch.mul)(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, expected_msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
 | |
|         with self.assertRaisesRegex(ValueError, expected_msg):
 | |
|             vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
 | |
|                 {"x": x, "y": y}
 | |
|             )
 | |
| 
 | |
|     def test_func_with_no_inputs(self):
 | |
|         expected_msg = "got no inputs"
 | |
| 
 | |
|         def foo():
 | |
|             return torch.randn(3)
 | |
| 
 | |
|         def bar(x):
 | |
|             return torch.randn(3)
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, expected_msg):
 | |
|             vmap(foo)()
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, expected_msg):
 | |
|             vmap(bar)()
 | |
| 
 | |
|     def test_constant_function(self):
 | |
|         output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
 | |
|         self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
 | |
| 
 | |
|     def test_single_input(self):
 | |
|         x = torch.randn(2, 3)
 | |
| 
 | |
|         def square(x):
 | |
|             return x * x
 | |
| 
 | |
|         output = vmap(square)(x)
 | |
|         self.assertEqual(output, x * x)
 | |
| 
 | |
|     def test_multiple_inputs(self):
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(2, 3)
 | |
|         output = vmap(torch.mul)(x, y)
 | |
|         self.assertEqual(output, x * y)
 | |
| 
 | |
|     def test_multiple_outputs(self):
 | |
|         def foo(x):
 | |
|             return x * x, x * x * x
 | |
| 
 | |
|         x = torch.randn(3)
 | |
|         outputs = vmap(foo)(x)
 | |
|         self.assertEqual(outputs[0], x * x)
 | |
|         self.assertEqual(outputs[1], x * x * x)
 | |
| 
 | |
|     def test_multiple_outputs_error_cases(self):
 | |
|         # This is the same thing as
 | |
|         # def returns_tuple_of_tensors(x):
 | |
|         #     return x, x
 | |
|         def returns_tuple_of_tensors(x):
 | |
|             return (x, x)
 | |
| 
 | |
|         def returns_list_of_two_tensors(x):
 | |
|             return [x, x]
 | |
| 
 | |
|         def returns_list_of_one_tensor(x):
 | |
|             return [x]
 | |
| 
 | |
|         x = torch.randn(3)
 | |
| 
 | |
|         # should not throw
 | |
|         vmap(returns_tuple_of_tensors)(x)
 | |
| 
 | |
|         # jax supports these, but we don't yet
 | |
|         msg = "must only return Tensors, got type <class 'list'>"
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(returns_list_of_two_tensors)(x)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(returns_list_of_one_tensor)(x)
 | |
| 
 | |
|     def test_nested_with_same_map_dim(self):
 | |
|         x = torch.randn(2, 3, 5)
 | |
|         y = torch.randn(2, 3, 5)
 | |
|         output = vmap(vmap(torch.mul))(x, y)
 | |
|         self.assertEqual(output, x * y)
 | |
| 
 | |
|         output = vmap(vmap(vmap(torch.mul)))(x, y)
 | |
|         self.assertEqual(output, x * y)
 | |
| 
 | |
|     def test_nested_with_different_map_dim(self):
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(5, 3)
 | |
|         output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
 | |
|         self.assertEqual(output.shape, (2, 5, 3))
 | |
|         self.assertEqual(output, x.view(2, 1, 3) * y)
 | |
| 
 | |
|         z = torch.randn(7, 3)
 | |
|         output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
 | |
|         self.assertEqual(output.shape, (2, 5, 7, 3))
 | |
|         self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
 | |
| 
 | |
|     def test_noop_in_inner_vmap(self):
 | |
|         x = torch.randn(3)
 | |
|         y = torch.randn(5)
 | |
|         output = vmap(lambda x: vmap(lambda y: x)(y))(x)
 | |
|         self.assertEqual(output, x.view(3, 1).expand(3, 5))
 | |
| 
 | |
|     def test_unsupported_op_err_msg(self):
 | |
|         # Unsupported view op
 | |
|         tensor = torch.randn(2, 3)
 | |
|         msg = (
 | |
|             r"Batching rule not implemented for aten::.+; the "
 | |
|             r"fallback path doesn't work on out= or view ops"
 | |
|         )
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(torch.ravel)(tensor)
 | |
| 
 | |
|         def out_op(x, y):
 | |
|             return torch.abs(x, out=y)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(out_op)(tensor, tensor)
 | |
| 
 | |
|         tensor = torch.randn(2)
 | |
|         # The fallback doesn't support TensorList
 | |
|         with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
 | |
|             vmap(lambda t: torch.atleast_1d([t]))(tensor)
 | |
| 
 | |
|         # Don't support non-tensor returns. This is a limitation of vmap;
 | |
|         # functions that don't return tensors must be special cased
 | |
|         with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
 | |
|             vmap(torch.Tensor.item)(tensor)
 | |
| 
 | |
|     def test_nonzero_out_dims(self):
 | |
|         # Basic test
 | |
|         tensor = torch.randn(2, 3)
 | |
|         result = vmap(lambda x: x, out_dims=1)(tensor)
 | |
|         self.assertEqual(result, tensor.permute(1, 0))
 | |
|         self.assertEqual(result.data_ptr(), tensor.data_ptr())
 | |
| 
 | |
|         # Test that the batch dimension gets permuted to dim 2
 | |
|         tensor = torch.randn(2, 3, 5, 7)
 | |
|         result = vmap(lambda x: x, out_dims=2)(tensor)
 | |
|         self.assertEqual(result, tensor.permute(1, 2, 0, 3))
 | |
|         self.assertEqual(result.data_ptr(), tensor.data_ptr())
 | |
| 
 | |
|         # negative out_dim
 | |
|         tensor = torch.randn(2, 3, 5, 7)
 | |
|         result = vmap(lambda x: x, out_dims=-1)(tensor)
 | |
|         self.assertEqual(result, tensor.permute(1, 2, 3, 0))
 | |
|         self.assertEqual(result.data_ptr(), tensor.data_ptr())
 | |
| 
 | |
|         # check that out_dims works on ALL outputs
 | |
|         tensor = torch.randn(2, 3, 5, 7)
 | |
|         other = torch.randn(2, 3, 5, 7)
 | |
|         result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
 | |
|         self.assertEqual(
 | |
|             result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
 | |
|         )
 | |
| 
 | |
|         # use out_dims with the maximum vmap-able tensor dims (64 dims)
 | |
|         ndims = 64
 | |
|         shape = [2] + [1] * (ndims - 1)
 | |
|         expected_shape = [1, 1, 2] + [1] * (ndims - 3)
 | |
|         tensor = torch.randn(shape)
 | |
|         result = vmap(lambda x: x, out_dims=2)(tensor)
 | |
|         self.assertEqual(result.shape, expected_shape)
 | |
| 
 | |
|         # test something that is not the identity function
 | |
|         def foo(x, y):
 | |
|             return x, x * y, x * y * y
 | |
| 
 | |
|         x = torch.randn(2, 3, 5)
 | |
|         y = torch.randn(2, 3, 5)
 | |
|         result = vmap(foo, out_dims=1)(x, y)
 | |
|         self.assertEqual(
 | |
|             result,
 | |
|             (
 | |
|                 x.permute(1, 0, 2),
 | |
|                 (x * y).permute(1, 0, 2),
 | |
|                 (x * y * y).permute(1, 0, 2),
 | |
|             ),
 | |
|         )
 | |
| 
 | |
|     def test_multiple_out_dims(self):
 | |
|         def foo(x):
 | |
|             return x, x
 | |
| 
 | |
|         def bar(x, y):
 | |
|             return x, x, x, x * y
 | |
| 
 | |
|         x = torch.randn(2, 3, 5)
 | |
|         y = torch.randn(2, 3, 5)
 | |
|         result = vmap(foo, out_dims=(0, 1))(x)
 | |
|         self.assertEqual(result, (x, x.permute(1, 0, 2)))
 | |
| 
 | |
|         result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
 | |
|         expected = (
 | |
|             x.permute(1, 2, 0),
 | |
|             x,
 | |
|             x.permute(1, 0, 2),
 | |
|             (x * y).permute(1, 2, 0),
 | |
|         )
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|     def test_nested_out_dims(self):
 | |
|         y = torch.randn(2, 3, 5, 7)
 | |
| 
 | |
|         # Inner vmap has non-zero out_dim
 | |
|         result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
 | |
|         self.assertEqual(result.shape, (2, 5, 3, 7))
 | |
|         self.assertEqual(result, y.permute(0, 2, 1, 3))
 | |
| 
 | |
|         # all vmaps have non-zero out_dim
 | |
|         result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
 | |
|         self.assertEqual(result.shape, (5, 2, 3, 7))
 | |
|         self.assertEqual(result, y.permute(2, 0, 1, 3))
 | |
| 
 | |
|         # throwing in some negative out_dims
 | |
|         result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
 | |
|         self.assertEqual(result.shape, (5, 7, 3, 2))
 | |
|         self.assertEqual(result, y.permute(2, 3, 1, 0))
 | |
| 
 | |
|         # testing fn that isn't the identity
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(5, 3)
 | |
|         result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
 | |
|         self.assertEqual(result.shape, (3, 2, 5))
 | |
|         self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
 | |
| 
 | |
|     def test_out_dims_edge_case(self):
 | |
|         def foo(x):
 | |
|             return x
 | |
| 
 | |
|         # Test that we accept out_dims=(1,) for a function with one output.
 | |
|         tensor = torch.randn(2, 3)
 | |
|         expected = vmap(foo, out_dims=1)(tensor)
 | |
|         result = vmap(foo, out_dims=(1,))(tensor)
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|     def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
 | |
|         msg = "`out_dims` must be an int or a tuple of int"
 | |
|         tensor = torch.randn(2, 3)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: x, out_dims="lol")(tensor)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: x, out_dims=("lol",))(tensor)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: x, out_dims=None)(tensor)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: x, out_dims=(None,))(tensor)
 | |
| 
 | |
|     def test_out_dims_and_num_outputs_mismatch_err_msg(self):
 | |
|         msg = "`out_dims` must have one dim per output"
 | |
|         x = torch.randn(2, 3, 5)
 | |
| 
 | |
|         # Too many out_dims
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: x, out_dims=(0, 0))(x)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
 | |
| 
 | |
|         # Too few out_dims
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: (x, x), out_dims=(0,))(x)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
 | |
| 
 | |
|     def test_out_dim_out_of_bounds_err_msg(self):
 | |
|         # TODO(rzou): This error message isn't that great. It comes straight
 | |
|         # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
 | |
|         # the error message in the future in C++
 | |
|         msg = "Dimension out of range"
 | |
|         x = torch.randn(2, 3, 5)
 | |
|         with self.assertRaisesRegex(IndexError, msg):
 | |
|             vmap(lambda x: x, out_dims=3)(x)
 | |
|         with self.assertRaisesRegex(IndexError, msg):
 | |
|             vmap(lambda x: x, out_dims=-4)(x)
 | |
| 
 | |
|     def test_non_zero_in_dims(self):
 | |
|         tensor = torch.randn(2, 3, 5)
 | |
| 
 | |
|         # Implicit out_dims = 0; vmap will move the batch dim to the front.
 | |
|         output = vmap(lambda x: x, (1,))(tensor)
 | |
|         self.assertEqual(output, tensor.permute(1, 0, 2))
 | |
|         self.assertEqual(output.data_ptr(), tensor.data_ptr())
 | |
| 
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(3, 2)
 | |
|         output = vmap(torch.mul, (0, 1))(x, y)
 | |
|         self.assertEqual(output, x * y.t())
 | |
|         output = vmap(torch.mul, (1, 0))(x, y)
 | |
|         self.assertEqual(output, x.t() * y)
 | |
| 
 | |
|     def test_none_in_dims(self):
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(2, 3)
 | |
| 
 | |
|         # None in_dim for a Tensor means we don't map over it
 | |
|         output = vmap(torch.mul, (0, None))(x, y)
 | |
|         self.assertEqual(output.shape, (2, 2, 3))
 | |
|         self.assertEqual(output, x.view(2, 1, 3) * y)
 | |
| 
 | |
|         # None in_dim for non-tensor arguments
 | |
|         output = vmap(torch.mul, (0, None))(x, 2)
 | |
|         self.assertEqual(output, x * 2)
 | |
| 
 | |
|     def test_nested_non_default_in_dims(self):
 | |
|         x = torch.rand(5, 2, 3)
 | |
|         y = torch.rand(3, 5, 2)
 | |
|         result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
 | |
|         self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))
 | |
| 
 | |
|     def test_non_default_in_dims_out_dims(self):
 | |
|         x = torch.randn(2, 3, 5)
 | |
| 
 | |
|         # Same in_dim as out_dim, vmap over identity
 | |
|         result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
 | |
|         self.assertEqual(result, x)
 | |
|         self.assertEqual(result.data_ptr(), x.data_ptr())
 | |
| 
 | |
|         # Different in_dim from out_dim, vmap over identity
 | |
|         result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
 | |
|         self.assertEqual(result.shape, (2, 5, 3))
 | |
|         self.assertEqual(result, x.transpose(1, 2))
 | |
|         self.assertEqual(result.data_ptr(), x.data_ptr())
 | |
| 
 | |
|         def foo(x):
 | |
|             return x * 2
 | |
| 
 | |
|         # Same in_dim as out_dim, vmap over operation
 | |
|         result = vmap(foo, in_dims=1, out_dims=1)(x)
 | |
|         self.assertEqual(result, x * 2)
 | |
| 
 | |
|         # Different in_dim as out_dim, vmap over operation
 | |
|         result = vmap(foo, in_dims=2, out_dims=1)(x)
 | |
|         self.assertEqual(result.shape, (2, 5, 3))
 | |
|         self.assertEqual(result, (x * 2).transpose(1, 2))
 | |
| 
 | |
|         # Basic nested test.
 | |
|         result = vmap(vmap(foo, 1, 1), 1, 1)(x)
 | |
|         self.assertEqual(result, x * 2)
 | |
| 
 | |
|     def test_accepts_nested_inputs(self):
 | |
|         B0 = 2
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(2, 3)
 | |
| 
 | |
|         # Single layer of nesting
 | |
|         out = vmap(lambda z: z[0] + z[1])((x, y))
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
 | |
|         self.assertEqual(out, x + y)
 | |
| 
 | |
|         out = vmap(lambda z: z[0] + z[1])([x, y])
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
 | |
|         self.assertEqual(out, x + y)
 | |
| 
 | |
|         out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
 | |
|         self.assertEqual(out, x + y)
 | |
|         out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
 | |
|             {"x": x, "y": y}
 | |
|         )
 | |
|         self.assertEqual(out, x + y)
 | |
| 
 | |
|         # Multiple layers of nesting
 | |
|         out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
 | |
|         out = out_fn({"x": [x, (x,)], "y": [y, y]})
 | |
|         self.assertEqual(out, x + x + y + y)
 | |
| 
 | |
|     def test_in_dims_wrong_type_err_msg(self):
 | |
|         x = torch.randn(3)
 | |
|         y = torch.randn(3)
 | |
|         msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.mul, [0, 0])(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.mul, set({0}))(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.mul, "lol")(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
 | |
|         # The following should not throw
 | |
|         vmap(torch.mul, (0, 0))(x, y)
 | |
| 
 | |
|     def test_not_enough_in_dims_err_msg(self):
 | |
|         x = torch.randn(3)
 | |
|         y = torch.randn(3)
 | |
|         msg = r"in_dims is not compatible with the structure of `inputs`"
 | |
| 
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.mul, (0,))(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.mul, (0, 0, 0))(x, y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
 | |
|         # The following should not throw
 | |
|         vmap(torch.mul, (0, 0))(x, y)
 | |
| 
 | |
|     def test_integer_in_dim_but_not_tensor_input_err_msg(self):
 | |
|         def foo(xy):
 | |
|             return xy[0] * xy[1]
 | |
| 
 | |
|         def bar(x, yz):
 | |
|             return x * yz[0] * yz[1]
 | |
| 
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(2, 3)
 | |
| 
 | |
|         # the following are errors in jax (and will always be errors)
 | |
|         msg = "Got in_dim=0 for an input but the input is of type"
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.sum)(x, 0)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(torch.sum, (0, 0))(x, 0)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
 | |
|         # The following should not throw
 | |
|         vmap(torch.sum, (0, None))(x, 0)
 | |
| 
 | |
|     def test_in_dim_not_in_tensor_err_msg(self):
 | |
|         def foo(x):
 | |
|             return x * x
 | |
| 
 | |
|         x = torch.randn(2, 3)
 | |
|         y = torch.randn(2, 3)
 | |
| 
 | |
|         msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(foo)(torch.randn([]))
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(foo, in_dims=(0,))(torch.randn([]))
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(foo, in_dims=(-1,))(x)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(foo, in_dims=(2,))(y)
 | |
|         with self.assertRaisesRegex(ValueError, msg):
 | |
|             vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
 | |
|         # the following should not throw
 | |
|         vmap(foo, in_dims=(0,))(torch.randn(2, 3))
 | |
|         vmap(foo, in_dims=(1,))(torch.randn(2, 3))
 | |
| 
 | |
    def test_fallback_does_not_warn_by_default(self):
        """Without enabling fallback warnings, the fallback path stays silent.

        Exactly one warning is expected while vmapping torch.atan2: the
        "vmap is experimental" warning, not one from the fallback path.
        """
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        # record=True captures emitted warnings into the list `wa` instead of
        # printing them, so they can be counted below.
        with warnings.catch_warnings(record=True) as wa:
            result = vmap(op)(x, y)
            # The single warning here is the "vmap is experimental"
            # warning, not a warning from the vmap fallback path.
            self.assertEqual(len(wa), 1)
 | |
| 
 | |
|     def test_fallback_warns_when_warnings_are_enabled(self):
 | |
|         # NB: One day we will implement a batching rule for torch.atan2.
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = torch.atan2
 | |
|         x = torch.randn(11)
 | |
|         y = torch.randn(11)
 | |
|         with warnings.catch_warnings(record=True) as wa:
 | |
|             with EnableVmapFallbackWarnings():
 | |
|                 result = vmap(op)(x, y)
 | |
|             self.assertEqual(len(wa), 2)
 | |
|             self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
 | |
| 
 | |
    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        """Assert that ``vmap(*vmap_args)(*inputs)`` goes through the fallback.

        Args:
            vmap_args: positional arguments forwarded to ``vmap``; the
                function to batch, optionally followed by ``in_dims``.
            inputs: positional arguments passed to the vmapped function.

        Fallback warnings are enabled only for the duration of the call.
        Exactly two warnings must be recorded, and the last one must match
        ``FALLBACK_REGEX``.
        """
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                result = vmap(*vmap_args)(*inputs)
            # The final recorded warning must be the fallback's performance
            # warning. The other one is presumably the "vmap is experimental"
            # warning.
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
 | |
| 
 | |
|     def test_fallback_zero_dim(self):
 | |
|         # NB: One day we will implement a batching rule for torch.atan2.
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = torch.atan2
 | |
|         x = torch.randn(11)
 | |
|         y = torch.randn(11)
 | |
|         self._assert_uses_vmap_fallback((op,), (x, y))
 | |
| 
 | |
|         B0, B1 = 0, 3
 | |
|         x = torch.randn(B0, 11)
 | |
|         y = torch.randn(11)
 | |
| 
 | |
|         msg = "The fallback path does not support vmap over dims of size 0"
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, (0, None))(x, y)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, (None, 0))(y, x)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op)(x, x)
 | |
| 
 | |
|         x = torch.randn(B0, B1, 11)
 | |
|         y = torch.randn(B1, 11)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, (0, None))(x, y)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, (None, 0))(y, x)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op)(x, x)
 | |
| 
 | |
|     def test_fallback_atan2(self):
 | |
|         # NB: One day we will implement a batching rule for torch.atan2.
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = torch.atan2
 | |
| 
 | |
|         x = torch.randn(5, 7, 11)
 | |
|         y = torch.randn(5, 7, 11)
 | |
| 
 | |
|         self._assert_uses_vmap_fallback((op,), (x, y))
 | |
| 
 | |
|         # fallback on torch.atan2
 | |
|         x = torch.randn(7, 11, 5)
 | |
|         y = torch.randn(5, 7, 11)
 | |
|         result = vmap(op, (2, 0))(x, y)
 | |
|         self.assertEqual(result, op(x.permute(2, 0, 1), y))
 | |
| 
 | |
|         # fallback on torch.atan2, nested vmap
 | |
|         x = torch.randn(7, 11, 5)
 | |
|         y = torch.randn(5, 7, 11)
 | |
|         result = vmap(vmap(op), (2, 0))(x, y)
 | |
|         self.assertEqual(result, op(x.permute(2, 0, 1), y))
 | |
| 
 | |
|         # big batch size (total 10000)
 | |
|         x = torch.randn(100, 10, 10, 5)
 | |
|         y = torch.randn(100, 10, 10)
 | |
|         result = vmap(vmap(vmap(op)))(x, y)
 | |
|         self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))
 | |
| 
 | |
|     def test_fallback_masked_fill(self):
 | |
|         # NB: One day we will implement a batching rule for masked_fill
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         def run_test(batch_size):
 | |
|             B0 = batch_size
 | |
|             x = torch.randn(B0, 7, 11, 13)
 | |
|             dim = 0
 | |
|             index = torch.tensor([0, 4, 2])
 | |
|             values = torch.randn(B0, 3, 11, 13)
 | |
| 
 | |
|             self._assert_uses_vmap_fallback(
 | |
|                 (torch.index_add, (0, None, None, 0)), (x, dim, index, values)
 | |
|             )
 | |
| 
 | |
|             result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
 | |
|             expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 11, 13))
 | |
|             self.assertEqual(result, expected)
 | |
| 
 | |
|         run_test(batch_size=5)
 | |
|         run_test(batch_size=1237)
 | |
| 
 | |
|     def test_fallback_multiple_returns(self):
 | |
|         # NB: One day we will implement a batching rule for torch.var_mean
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         B0, B1, B2 = 2, 3, 1237
 | |
|         tensor = torch.randn(B0, 10)
 | |
| 
 | |
|         self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
 | |
| 
 | |
|         # fallback correctness on torch.var_mean
 | |
|         result = vmap(torch.var_mean)(tensor)
 | |
|         expected = torch.var_mean(tensor, dim=1)
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|         # nested vmap
 | |
|         tensor = torch.randn(B0, B1, 10)
 | |
|         result = vmap(vmap(torch.var_mean))(tensor)
 | |
|         expected = torch.var_mean(tensor, dim=2)
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|         # big batch size, nested vmap
 | |
|         tensor = torch.randn(B0, B1, B2, 10)
 | |
|         result = vmap(vmap(vmap(torch.var_mean)))(tensor)
 | |
|         expected = torch.var_mean(tensor, dim=3)
 | |
|         self.assertEqual(result, expected)
 | |
| 
 | |
|     def test_inplace_fallback_unary(self):
 | |
|         # Test the in-place fallback on an in-place method that takes no
 | |
|         # additional Tensor arguments. This is the simplest case of the fallback.
 | |
|         # NB: One day we will implement a batching rule for acos_.
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = Tensor.acos_
 | |
|         B0, B1, B2 = 2, 3, 10000
 | |
| 
 | |
|         x = torch.randn(B0, 5)
 | |
|         self._assert_uses_vmap_fallback((op,), (x,))
 | |
| 
 | |
|         # Single vmap
 | |
|         x_orig = torch.rand(B0, 5)
 | |
|         x = x_orig.clone()
 | |
|         result = vmap(op)(x)
 | |
|         self.assertTrue(result is x)
 | |
|         self.assertEqual(result, x_orig.acos())
 | |
| 
 | |
|         # Single vmap + different out_dim produces a view(!)
 | |
|         x_orig = torch.rand(B0, 5)
 | |
|         x = x_orig.clone()
 | |
|         result = vmap(op, out_dims=(1,))(x)
 | |
|         self.assertTrue(result._base is x)
 | |
|         self.assertEqual(result, x_orig.t().acos())
 | |
| 
 | |
|         # Nested vmap
 | |
|         x_orig = torch.randn(B0, B1, 5)
 | |
|         x = x_orig.clone()
 | |
|         result = vmap(vmap(op))(x)
 | |
|         self.assertTrue(result is x)
 | |
|         self.assertEqual(result, x_orig.acos())
 | |
| 
 | |
|         # Nested vmap, large batch size
 | |
|         x_orig = torch.randn(B0, B1, B2, 5)
 | |
|         x = x_orig.clone()
 | |
|         result = vmap(vmap(vmap(op)))(x)
 | |
|         self.assertTrue(result is x)
 | |
|         self.assertEqual(result, x_orig.acos())
 | |
| 
 | |
|     def test_inplace_fallback_nary_same_levels(self):
 | |
|         # NB: One day we will implement a batching rule for atan2_
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = Tensor.atan2_
 | |
|         outplace_op = torch.atan2
 | |
| 
 | |
|         x = torch.randn(5, 7, 11)
 | |
|         y = torch.randn(5, 7, 11)
 | |
|         self._assert_uses_vmap_fallback((op,), (x, y))
 | |
| 
 | |
|         # Single vmap
 | |
|         B0 = 5
 | |
|         x_orig = torch.randn(7, 11, B0)
 | |
|         x = x_orig.clone()
 | |
|         y = torch.randn(B0, 7, 11)
 | |
|         vmap(op, (2, 0))(x, y)
 | |
|         self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))
 | |
| 
 | |
|         # Nested vmap
 | |
|         B0, B1 = 5, 7
 | |
|         x_orig = torch.randn(B1, 11, B0)
 | |
|         x = x_orig.clone()
 | |
|         y = torch.randn(B0, B1, 11)
 | |
|         vmap(vmap(op), (2, 0))(x, y)
 | |
|         self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))
 | |
| 
 | |
|         # big batch size (total 10000)
 | |
|         B0, B1, B2 = 100, 10, 10
 | |
|         x_orig = torch.randn(B0, B1, B2, 5)
 | |
|         x = x_orig.clone()
 | |
|         y = torch.randn(B0, B1, B2)
 | |
|         result = vmap(vmap(vmap(op)))(x, y)
 | |
|         self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))
 | |
| 
 | |
|     def test_inplace_fallback_nary_different_levels(self):
 | |
|         # NB: One day we will implement a batching rule for atan2_
 | |
|         # If/when we do, this test should be replaced to test the fallback
 | |
|         # path on another operator to avoid bitrot.
 | |
|         op = Tensor.atan2_
 | |
|         outplace_op = torch.atan2
 | |
|         B0, B1, B2 = 2, 3, 5
 | |
| 
 | |
|         x = torch.rand(B0, 7)
 | |
|         y = torch.rand(7)
 | |
|         self._assert_uses_vmap_fallback((op, (0, None)), (x, y))
 | |
| 
 | |
|         # op(left, right): All of the levels in right are found in left
 | |
|         x_orig = torch.rand(B0, 7)
 | |
|         x = x_orig.clone()
 | |
|         y = torch.rand(7)
 | |
|         vmap(op, in_dims=(0, None))(x, y)
 | |
|         self.assertEqual(x, outplace_op(x_orig, y))
 | |
| 
 | |
|         x_orig = torch.rand(B0, B1, 7)
 | |
|         x = x_orig.clone()
 | |
|         y = torch.rand(B0, 7)
 | |
|         vmap(vmap(op, in_dims=(0, None)))(x, y)
 | |
|         self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))
 | |
| 
 | |
|         # op(left, right): Some of the levels in right are not found in left
 | |
|         msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible"
 | |
|         x = torch.rand(7)
 | |
|         y = torch.rand(B0, 7)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(None, 0))(x, y)
 | |
| 
 | |
|         x = torch.rand(B1, 7)
 | |
|         y = torch.rand(B0, 7)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)
 | |
| 
 | |
|         x = torch.rand(B1, 7)
 | |
|         y = torch.rand(7, B0)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)
 | |
| 
 | |
|         x = torch.rand(B0, 7)
 | |
|         y = torch.rand(B0, B1, 7)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(vmap(op, in_dims=(None, 0)))(x, y)
 | |
| 
 | |
    def test_backward_unsupported_interaction(self):
        # Tensor.backward() inside vmap must raise, whether the vmapped value
        # is the tensor being differentiated, the gradient passed in, or not
        # involved in the backward call at all.
        x = torch.randn(3, requires_grad=True)
        y = torch.randn(5)
        grad = torch.randn_like(x)
        err_msg = r"backward\(\) called inside torch.vmap"

        def backward_on_vmapped_tensor(x):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_on_vmapped_tensor)(x)

        def backward_with_vmapped_grad(x, grad):
            x.backward(grad)

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(backward_with_vmapped_grad)(x, grad)

        # Here `x` is the captured outer tensor, not the vmapped argument `y`;
        # backward() is still disallowed inside vmap.
        def completely_unrelated_backward(y):
            x.sum().backward()

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)
 | |
| 
 | |
    def test_grad_unsupported_interaction(self):
        # torch.autograd.grad inside vmap must raise when either the outputs
        # or the inputs handed to it involve the vmapped argument.
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = "autograd.grad.* called inside torch.vmap"

        captured = torch.randn(3, requires_grad=True)

        # Case 1: the output is computed from the vmapped argument.
        def output_to_grad_is_vmapped(input_tensor):
            output = (captured * input_tensor).sum()
            return torch.autograd.grad([output], [captured])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(output_to_grad_is_vmapped)(input_tensor)

        # Case 2: the output is computed outside vmap, but the `inputs` passed
        # to autograd.grad is the vmapped parameter (it shadows the outer name).
        output = (input_tensor**2).sum()

        def input_to_grad_is_vmapped(input_tensor):
            return torch.autograd.grad([output], [input_tensor])[0]

        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(input_to_grad_is_vmapped)(input_tensor)
 | |
| 
 | |
|     def test_batched_gradient_basic(self):
 | |
|         N = 3
 | |
|         x = torch.randn(N, requires_grad=True)
 | |
|         y = torch.randn(N)
 | |
| 
 | |
|         def vjp_mul(v):
 | |
|             return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
 | |
| 
 | |
|         batched_v = torch.eye(N)
 | |
|         jacobian = vmap(vjp_mul)(batched_v)
 | |
|         self.assertEqual(jacobian, torch.diagflat(y))
 | |
| 
 | |
|     def test_functools_partial(self):
 | |
|         x = torch.randn(3)
 | |
|         y = torch.randn(2, 3)
 | |
|         result = vmap(functools.partial(torch.mul, x))(y)
 | |
|         self.assertEqual(result, x * y)
 | |
| 
 | |
|     def test_nn_module(self):
 | |
|         tensor = torch.randn(2, 3)
 | |
|         model = torch.nn.Linear(3, 3, bias=False)
 | |
|         result = vmap(model)(tensor)
 | |
|         self.assertEqual(result, model(tensor))
 | |
| 
 | |
    def test_fallback_with_undefined_grad(self):
        # A vmapped vjp through a bias-free conv2d must work through the
        # fallback even though the backward produces an undefined grad_bias.
        B0 = 7
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        # 1x1 kernel, 3 -> 3 channels: the conv output has x's shape, so each
        # v[i] below is a valid grad_output.
        weight = torch.randn(3, 3, 1, 1)
        v = torch.randn(B0, 2, 3, 4, 5)

        def get_vjp(v):
            result = torch.nn.functional.conv2d(x, weight)
            (grad_x,) = torch.autograd.grad(result, x, v)
            return grad_x

        # Runs vmap(get_vjp)(v), which should not error out.
        # The backward formula for convolution returns an undefined
        # Tensor for grad_bias because the original bias does not exist.
        #
        # In the future we'll probably add a batching rule for convolution
        # backward. When this happens, we should modify this test to use a
        # different op (and/or create and use a dummy operator) to avoid bitrot.
        self._assert_uses_vmap_fallback([get_vjp], [v])
 | |
| 
 | |
| 
 | |
def slice_inputs(inputs, bdims, i):
    """Return the ``i``-th example of every input, as a tuple.

    ``inputs`` and ``bdims`` are paired positionally; as with ``zip``, extra
    entries in the longer sequence are ignored. An input whose batch dim is
    ``None`` is passed through unchanged. Every other input is indexed with
    ``Tensor.select(bdim, i)``.
    """
    return tuple(
        inp if bdim is None else inp.select(bdim, i)
        for inp, bdim in zip(inputs, bdims)
    )
 | |
| 
 | |
| 
 | |
def reference_vmap(op, inputs, in_dims=0, out_dims=0):
    """Slow sequential stand-in for ``vmap(op, in_dims, out_dims)(*inputs)``.

    Calls ``op`` once per example, slicing every batched input with
    ``slice_inputs``, and stacks the per-example results.

    Args:
        op: callable returning a Tensor or a tuple of Tensors.
        inputs: positional arguments for ``op``.
        in_dims: an int used for every input, or a sequence with one entry
            per input where ``None`` marks an unbatched input.
        out_dims: an int used for every output, or a sequence with one entry
            per output giving the dim to stack that output along.

    Returns:
        A Tensor when ``op`` returns a single Tensor, else a tuple of Tensors.

    Raises:
        AssertionError: if no input is batched, batch sizes disagree, the
            batch is empty, or the per-example return types are inconsistent.
    """
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)
    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
    # Without any batched input the batch size is undefined; report it
    # clearly instead of failing with an IndexError on bdim_sizes[0].
    assert bdim_sizes, "reference_vmap: at least one input must be batched"
    bdim_size = bdim_sizes[0]
    assert all(size == bdim_size for size in bdim_sizes), (
        f"reference_vmap: inconsistent batch sizes {bdim_sizes}"
    )
    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))

    assert len(results) > 0
    op_has_single_return = not isinstance(results[0], tuple)
    if op_has_single_return:
        assert all(isinstance(result, torch.Tensor) for result in results)
        out_dim = out_dims if isinstance(out_dims, int) else out_dims[0]
        return torch.stack(results, dim=out_dim)

    assert all(isinstance(result, tuple) for result in results)
    num_returns = len(results[0])
    assert all(len(result) == num_returns for result in results)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_returns
    # zip(*results) regroups per-example tuples into per-output shards.
    return tuple(
        torch.stack(result_shards, out_dim)
        for result_shards, out_dim in zip(zip(*results), out_dims)
    )
 | |
| 
 | |
| 
 | |
class TensorFactory:
    """Random-tensor constructors sharing the signature ``(size, device, dtype)``.

    Tests pass these as input "getters" so each op can be fed a suitable
    distribution (e.g. ``randp1``, whose values are >= 1, for log/rsqrt/
    reciprocal).
    """

    @staticmethod
    def rand(size, device="cpu", dtype=torch.float):
        # Uniform samples on [0, 1).
        return torch.rand(size, dtype=dtype, device=device)

    @staticmethod
    def randn(size, device="cpu", dtype=torch.float):
        # Standard-normal samples.
        return torch.randn(size, dtype=dtype, device=device)

    @staticmethod
    def randp1(size, device="cpu", dtype=torch.float):
        # Uniform [0, 1) samples shifted up by one, so every value is >= 1.
        samples = torch.rand(size, dtype=dtype, device=device)
        return samples.add(1)
 | |
| 
 | |
| 
 | |
# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
# (slow) sequential map+stack fallback (reference_vmap).
#
# check_view: Test if every returned output is a view of the first input
#   (i.e. shares the first input's `_base`).
# check_propagates_grad: Test if the operation propagates requires_grad from
#   the first input to the first output.
def _vmap_test(
    self,
    op,
    inputs,
    in_dims=0,
    out_dims=0,
    check_view=False,
    check_propagates_grad=True,
):
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        # If the first input is itself a view, outputs are expected to share
        # its `_base` rather than the input object itself.
        input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
        for output in result_as_tuple:
            self.assertTrue(
                output._base is input0_base,
                msg="result was not a view of the first input!",
            )

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    # Check the zeroth *output*. Indexing `result` directly would, for a
    # single-Tensor return, select the first slice of that Tensor instead
    # (and fail outright if it is empty along dim 0).
    self.assertTrue(result_as_tuple[0].requires_grad)
 | |
| 
 | |
| 
 | |
def should_allow_vmap_fallback_usage(fn):
    """Return the ``_allow_vmap_fallback_usage`` marker set on ``fn``.

    Returns False when ``fn`` was never marked (see ``allowVmapFallbackUsage``).
    Bound methods forward the attribute lookup to their underlying function.
    """
    try:
        return fn._allow_vmap_fallback_usage
    except AttributeError:
        return False
 | |
| 
 | |
| 
 | |
def allowVmapFallbackUsage(fn):
    """Decorator that exempts a test from the "no vmap fallback" check.

    Sets ``fn._allow_vmap_fallback_usage = True`` in place and returns the
    same object, so it composes with other decorators.
    """
    setattr(fn, "_allow_vmap_fallback_usage", True)
    return fn
 | |
| 
 | |
| 
 | |
| # All tests of TestVmapBaseLegacy check that the slow vmap fallback is never invoked.
 | |
| # This is so that we can incrementally add batching rules for operators to
 | |
| # replace the slow vmap fallback path for said operators. To skip this check,
 | |
| # please use the allowVmapFallbackUsage decorator.
 | |
| #
 | |
| # NB: Don't add tests to TestVmapBaseLegacy directly, unless you want them to run
 | |
| # on every subclass of TestVmapBaseLegacy. Add them to e.g. TestVmapOperators.
 | |
| #
 | |
| # NB: TestVmapBaseLegacy is a nested class. This prevents test runners from picking
 | |
| # it up and running it.
 | |
class Namespace:
    # Holder class only: nesting TestVmapBaseLegacy here keeps test runners
    # from collecting and running the base class by itself.
    class TestVmapBaseLegacy(TestCase):
        def __init__(self, method_name="runTest"):
            """Wrap the selected test so it fails if it hits the vmap fallback.

            Methods marked with ``allowVmapFallbackUsage`` are left unwrapped.

            Args:
                method_name: name of the test method this instance will run.
            """
            super().__init__(method_name)

            test_method = getattr(self, method_name, None)
            if test_method is None:
                return

            if not should_allow_vmap_fallback_usage(test_method):
                # Shadow the class attribute with a wrapped bound method on
                # this instance only.
                setattr(
                    self,
                    method_name,
                    self._wrap_method_with_vmap_fallback_check(test_method),
                )

        def _wrap_method_with_vmap_fallback_check(self, method):
            """Wrap ``method`` so fallback warnings emitted while it runs fail the test.

            Args:
                method: callable to wrap. It receives the wrapper's arguments
                    except the bound ``self``, so pass either a bound method
                    (as ``__init__`` does) or a function that is then called
                    with the instance explicitly (as test_vmap_fallback_check
                    does).

            Returns:
                The wrapper, bound to ``self`` via ``types.MethodType``.
            """
            msg = (
                "Expected the test to not invoke the vmap fallback path, i.e., "
                "all of the operators being tested in this test should have batching "
                "rules implemented. If you are intentionally testing something to "
                "do with the fallback path, use allowVmapFallbackUsage. Otherwise, "
                "please make sure that batching rules are implemented for the "
                "operator(s) being tested."
            )

            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                with warnings.catch_warnings(record=True) as wa:
                    # "always" so repeated warnings from the same location are
                    # recorded rather than deduplicated.
                    warnings.simplefilter("always")
                    # Turn on fallback warnings only while the test body runs;
                    # the previous setting is restored on exit.
                    with EnableVmapFallbackWarnings():
                        method(*args, **kwargs)
                    for captured_warning in wa:
                        self.assertNotRegex(
                            str(captured_warning.message), FALLBACK_REGEX, msg
                        )

            return types.MethodType(wrapper, self)

        @allowVmapFallbackUsage
        def test_vmap_fallback_check_ok(self):
            # Exempted via allowVmapFallbackUsage: using the fallback here
            # must not fail the test.
            #
            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean
            vmap(op_using_fallback)(torch.rand(3))

        def test_vmap_fallback_check(self):
            # Exercise the wrapper directly: a no-op body passes, while a body
            # that hits the fallback raises AssertionError.
            @self._wrap_method_with_vmap_fallback_check
            def no_fallback(self):
                pass

            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean

            @self._wrap_method_with_vmap_fallback_check
            def uses_fallback(self):
                vmap(op_using_fallback)(torch.rand(3))

            no_fallback(self)

            with self.assertRaises(AssertionError):
                uses_fallback(self)
 | |
| 
 | |
| 
 | |
| class TestVmapOperatorsLegacy(Namespace.TestVmapBaseLegacy):
 | |
    def _vmap_test(self, *args, **kwargs):
        # Method form of the module-level _vmap_test helper, which takes the
        # TestCase explicitly as its first argument.
        return _vmap_test(self, *args, **kwargs)
 | |
| 
 | |
    def _vmap_view_test(self, *args, **kwargs):
        # Same as self._vmap_test with check_view=True, i.e. additionally
        # require the outputs to be views of the first input.
        self._vmap_test(*args, **kwargs, check_view=True)
 | |
| 
 | |
    def _test_unary(self, op, getter, device, *args, **kwargs):
        # Run self._vmap_test on the unary `op` over a fixed set of batch
        # layouts: single vmap with the batch dim at the front or in the
        # middle (optionally out_dims=2), and doubly-nested vmap.
        #
        # getter(shape, device) builds each input (e.g. a TensorFactory
        # method). Extra *args/**kwargs are bound into self._vmap_test with
        # functools.partial; note that positional *args land before `op`.
        test = functools.partial(self._vmap_test, *args, **kwargs)
        B0, B1 = 7, 11

        # Single vmap, various in_dims / out_dims
        test(op, [getter([B0, 3], device)])
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(op), [getter([B0, B1], device)])
        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
        test(
            vmap(op, in_dims=2),
            [getter([2, 5, B0, B1, 3], device)],
            in_dims=2,
            out_dims=2,
        )
 | |
| 
 | |
    def test_unary_pointwise_ops(self):
        # Each op is paired with an input factory chosen to keep inputs in a
        # sensible range for it, e.g. randp1 (values >= 1) for the log family,
        # reciprocal and rsqrt.
        cases = [
            (torch.abs, TensorFactory.randn),
            (torch.acos, TensorFactory.rand),
            (torch.asin, TensorFactory.rand),
            (torch.atan, TensorFactory.rand),
            (torch.ceil, TensorFactory.randn),
            (torch.cos, TensorFactory.rand),
            (torch.cosh, TensorFactory.rand),
            (torch.digamma, TensorFactory.rand),
            (torch.exp, TensorFactory.randn),
            (torch.expm1, TensorFactory.randn),
            (torch.floor, TensorFactory.randn),
            (torch.frac, TensorFactory.randn),
            (torch.lgamma, TensorFactory.rand),
            (torch.log, TensorFactory.randp1),
            (torch.log10, TensorFactory.randp1),
            (torch.log1p, TensorFactory.randp1),
            (torch.log2, TensorFactory.randp1),
            (torch.neg, TensorFactory.randn),
            (torch.reciprocal, TensorFactory.randp1),
            (torch.relu, TensorFactory.randn),
            (torch.round, TensorFactory.randn),
            (torch.rsqrt, TensorFactory.randp1),
            (torch.sigmoid, TensorFactory.randn),
            (torch.sign, TensorFactory.randn),
            (torch.sin, TensorFactory.rand),
            (torch.sinh, TensorFactory.rand),
            (torch.sqrt, TensorFactory.rand),
            (torch.tan, TensorFactory.rand),
            (torch.tanh, TensorFactory.rand),
            (torch.trunc, TensorFactory.randn),
        ]
        for op, getter in cases:
            self._test_unary(op, getter, "cpu")
 | |
| 
 | |
    def test_clone(self):
        # Some basic tests
        self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu")
        self._test_unary(
            lambda x: x.clone(memory_format=torch.preserve_format),
            TensorFactory.randn,
            "cpu",
        )
        self._test_unary(
            lambda x: x.clone(memory_format=torch.contiguous_format),
            TensorFactory.randn,
            "cpu",
        )

        # Test that the per-examples are contiguous when using torch.contiguous_format
        def clone_contiguous(x):
            return x.clone(memory_format=torch.contiguous_format)

        B0, B1 = 3, 5
        x = torch.randn(2, B0, 7)
        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
        # The output's batch dim sits at dim 1 (out_dims=1); moving it back to
        # the front must yield a contiguous tensor, and each example too.
        self.assertTrue(y.movedim(1, 0).is_contiguous())
        self.assertTrue(y[:, 0, :].is_contiguous())

        # Nested: outer vmap over dim 1 (B0), inner over dim 2 of each
        # (2, 7, B1) example (B1); the output is (B0, B1, 2, 7).
        x = torch.randn(2, B0, 7, B1)
        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
        self.assertTrue(y.is_contiguous())
        self.assertTrue(y[0][0].is_contiguous())

        # Other memory formats are rejected inside vmap.
        msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(
                torch.randn(B0)
            )
 | |
| 
 | |
|     def test_binary_pointwise_ops(self):
 | |
|         def get_number(getter):
 | |
|             return getter([]).item()
 | |
| 
 | |
|         def make_case(op, input_getter=TensorFactory.randn):
 | |
|             return (op, input_getter)
 | |
| 
 | |
|         cases = [
 | |
|             # Basic arithmetic
 | |
|             make_case(torch.add),
 | |
|             make_case(lambda x, y: x + y),
 | |
|             make_case(torch.sub),
 | |
|             make_case(lambda x, y: x - y),
 | |
|             make_case(torch.mul),
 | |
|             make_case(lambda x, y: x * y),
 | |
|             make_case(torch.div, input_getter=TensorFactory.randp1),
 | |
|             make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
 | |
|             make_case(torch.pow, input_getter=TensorFactory.randp1),
 | |
|             make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1),
 | |
|         ]
 | |
|         test = self._vmap_test
 | |
| 
 | |
|         for op, getter in cases:
 | |
|             device = "cpu"
 | |
|             B0, B1 = 7, 11
 | |
| 
 | |
|             # Single vmap: op(Tensor, Tensor)
 | |
|             test(op, (getter([B0, 3], device), getter([B0, 3], device)))
 | |
|             test(op, (getter([B0], device), getter([B0, 2, 3], device)))
 | |
|             test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
 | |
|             test(
 | |
|                 op,
 | |
|                 (getter([B0], device), getter([2, B0, 3], device)),
 | |
|                 in_dims=(0, 1),
 | |
|                 out_dims=1,
 | |
|             )
 | |
|             test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
 | |
|             test(
 | |
|                 op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None)
 | |
|             )
 | |
| 
 | |
|             # Nested vmap: op(Tensor, Tensor)
 | |
|             test(
 | |
|                 vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))
 | |
|             )
 | |
|             test(
 | |
|                 vmap(op, in_dims=(None, 0)),
 | |
|                 (getter([B0, 2, 3], device), getter([B1, 3], device)),
 | |
|                 in_dims=(0, None),
 | |
|             )
 | |
| 
 | |
|             # Python number overload: op(Tensor, Number) (and vice-versa)
 | |
|             number = get_number(getter)
 | |
|             self._test_unary(lambda t: op(t, number), getter, device)
 | |
|             number = get_number(getter)
 | |
|             self._test_unary(lambda t: op(number, t), getter, device)
 | |
| 
 | |
|             # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
 | |
|             test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
 | |
|             test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
 | |
|             test(op, (getter([B0], device), getter([B0], device)))
 | |
| 
 | |
|             # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
 | |
|             test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
 | |
|             test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
 | |
| 
 | |
|             if not torch.cuda.is_available():
 | |
|                 continue
 | |
| 
 | |
|             # TODO(rzou): fix the following
 | |
|             # # Test cross-device scalars
 | |
|             # number = get_number(getter)
 | |
|             # self._test_unary(lambda t: op(t, number), getter, device='cuda')
 | |
|             # self._test_unary(lambda t: op(number, t), getter, device='cuda')
 | |
|             # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
 | |
| 
 | |
    def test_as_strided(self):
        # as_strided under vmap must match the equivalent view op and share
        # its base, and must reject size/stride/offset combinations the
        # batching rule cannot support.
        def _test(sizes, strides, offset, tensor, lambd):
            # `lambd` is a regular view op expected to produce the same view
            # as the given as_strided call.
            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
            expected = vmap(lambd)(tensor)
            self.assertTrue(result._base is expected._base)
            self.assertEqual(result, expected)

        # single vmap test
        B0 = 5
        tensors = [
            # contiguous
            torch.randn(B0, 2, 3),
            # non-contiguous
            torch.randn(B0, 3, 2).transpose(1, 2),
            # non-zero storage offset
            torch.randn(2, B0, 2, 3)[1],
            # non-contiguous strides, zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
            # non-contiguous strides, non-zero storage offset
            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
        ]

        for x in tensors:
            # Per-example strides (skip the batch dim) and the batched
            # tensor's storage offset, which the offsets below build on.
            S0, S1 = x.stride()[1:]
            offset = x.storage_offset()

            # Broadcast
            _test(
                [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)
            )
            # transpose
            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
            # select
            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])

        # Nested vmap test
        B1 = 7
        x = torch.randn(B1, B0, 2, 3)
        S0, S1 = x.stride()[2:]
        result = vmap(
            vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1
        )(x)
        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
        self.assertTrue(result._base is expected._base)
        self.assertEqual(result, expected)

        # Check that mal-formatted size/strides raise an error instead of crashing
        with self.assertRaisesRegex(
            RuntimeError, "size and stride must have the same length"
        ):
            x = torch.randn(B0, 2, 3).transpose(0, 1)
            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)

        # Sanity check #1: we require the batch dims to be at the front of the
        # tensor (in memory layout).
        msg = "batch dims being vmapped over are at the front of the tensor"
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(2, B0, 3).transpose(0, 1)
            vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 2, 3, B1).movedim(3, 1)
            vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x)

        # All the Sanity check #2{a,b,c} cases check that
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # doesn't index memory that is out of bounds of xs[i]. This condition
        # is important to the correctness of the as_strided batching rule
        # (see NOTE: [When will the as_strided_batching_rule fail?])

        # Sanity check #2a: The maximum indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is less than or equal to the maximum indexable location of xs[i].
        msg = "This is not supported inside of vmap"
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3)
            vmap(lambda x: x.as_strided([3], [1], 1))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 3, 5)
            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, B1, 3, 5)
            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)

        # Sanity check #2b: The min indexable location of
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # is greater than or equal to the min indexable location of xs[i].
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(2, B0, 3)[1]
            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)

        # Sanity check #2c:
        # xs[i] has zero elements (here shape (0, 3)), but
        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
        # does not
        with self.assertRaisesRegex(RuntimeError, msg):
            x = torch.randn(B0, 0, 3)
            vmap(lambda x: x.as_strided([3], [1]))(x)
 | |
| 
 | |
    def test_bmm(self):
        # torch.bmm under vmap with the left, right, or both operands batched
        # (including nested vmap), plus rejection of wrongly-shaped inputs.
        op = torch.bmm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch: in each case one per-example operand is not 3-D.
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
            in_dims=(1, None),
        )

        # right arg is vmapped
        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(
            vmap(op, in_dims=(None, 0)),
            (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
            in_dims=(None, 1),
        )

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
        test(
            vmap(op),
            (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)),
            in_dims=(1, 0),
        )
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)),
            in_dims=(None, 0),
        )
 | |
| 
 | |
    def test_cat(self):
        # torch.cat under vmap with mixed batched/unbatched inputs, batch dims
        # at different positions, negative cat dims and nested vmap.
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.cat(tensors, dim=dim)

            return op

        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
        test(
            vmap(get_op(0), in_dims=(0, None)),
            (torch.rand(B1, 2), torch.rand(B0, 3)),
            in_dims=(None, 0),
        )
        test(
            vmap(get_op(0), in_dims=(0, 0)),
            (torch.rand(B1, 2), torch.rand(B0, B1, 3)),
            in_dims=(None, 0),
        )
 | |
| 
 | |
    def test_conj(self):
        op = torch.conj

        def run_test(dtype):
            # Same batch layouts as _test_unary, but with inputs of `dtype`.
            def get(shape):
                return torch.randn(shape, dtype=dtype)

            B0, B1 = 7, 11
            test = self._vmap_test

            # Single vmap, various in_dims / out_dims
            test(op, [get([B0, 3])])
            test(op, [get([2, 5, B0, 3])], in_dims=2)
            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)

            # Doubly nested vmap
            test(vmap(op), [get([B0, B1])])
            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)

        # correctness tests
        run_test(torch.float)
        run_test(torch.cfloat)

        # check that torch.conj on a non-complex tensor returns a result that
        # shares the input's memory (same data_ptr)
        real_tensor = torch.randn(3)
        result = vmap(op)(real_tensor)
        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
 | |
| 
 | |
    def test_contiguous(self):
        op = Tensor.contiguous

        self._test_unary(op, TensorFactory.randn, "cpu")

        # check that contiguous returns the original tensor if the per-examples
        # are already contiguous
        B0 = 3
        x = torch.randn(B0, 2, 5, 7)
        # Only the batch dim is moved, so each (2, 5, 7) example keeps
        # contiguous strides.
        x = x.movedim(0, 2)
        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
        self.assertTrue(result is x)

        # Non-default memory formats are not yet supported inside vmap.
        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
        tensor = torch.randn(B0, 3)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
 | |
| 
 | |
|     def test_stride(self):
 | |
|         B0 = 3
 | |
| 
 | |
|         x = torch.randn(B0, 2, 5, 7)
 | |
| 
 | |
|         def foo(x):
 | |
|             assert x.stride() == (7 * 5, 7, 1)
 | |
|             return x
 | |
| 
 | |
|         vmap(foo)(x)
 | |
| 
 | |
|         x = torch.randn(2, B0, 5, 7).movedim(1, 0)
 | |
| 
 | |
|         def bar(x):
 | |
|             assert x.stride() == (7 * 5 * B0, 7, 1)
 | |
|             return x
 | |
| 
 | |
|         vmap(bar)(x)
 | |
| 
 | |
|     def test_chunk(self):
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.chunk
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         # tests for torch.split(self, split_size: int, dim)
 | |
|         test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), 4, 0),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 64, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_clamp(self):
 | |
|         clamp_cases = (
 | |
|             (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
 | |
|             (lambda t: t.clamp(max=0.5), TensorFactory.randn),
 | |
|             (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
 | |
|             (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
 | |
|             (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
 | |
|         )
 | |
|         for op, getter in clamp_cases:
 | |
|             self._test_unary(op, getter, "cpu")
 | |
| 
 | |
|     def test_comparison_ops(self):
 | |
|         test = functools.partial(self._vmap_test, check_propagates_grad=False)
 | |
| 
 | |
|         getter = TensorFactory.randn
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         ops = (
 | |
|             torch.eq,
 | |
|             lambda x, y: x == y,
 | |
|             torch.gt,
 | |
|             lambda x, y: x > y,
 | |
|             torch.ge,
 | |
|             lambda x, y: x >= y,
 | |
|             torch.le,
 | |
|             lambda x, y: x <= y,
 | |
|             torch.lt,
 | |
|             lambda x, y: x < y,
 | |
|             torch.ne,
 | |
|             lambda x, y: x != y,
 | |
|         )
 | |
| 
 | |
|         for op in ops:
 | |
|             # Single vmap: op(Tensor, Tensor)
 | |
|             test(op, (getter([B0, 3]), getter([B0, 3])))
 | |
|             test(op, (getter([B0]), getter([B0, 2, 3])))
 | |
|             test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
 | |
|             test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
 | |
|             test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
 | |
|             test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
 | |
| 
 | |
|             # Nested vmap: op(Tensor, Tensor)
 | |
|             test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
 | |
|             test(
 | |
|                 vmap(op, in_dims=(None, 0)),
 | |
|                 (getter([B0, 2, 3]), getter([B1, 3])),
 | |
|                 in_dims=(0, None),
 | |
|             )
 | |
| 
 | |
|             # test number as inputs
 | |
|             number = getter([]).item()
 | |
|             self._test_unary(
 | |
|                 lambda t: op(t, number), getter, "cpu", check_propagates_grad=False
 | |
|             )
 | |
| 
 | |
|     def test_diagonal(self):
 | |
|         tensor = torch.randn(3, 5, 7, 11, 13)
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.diagonal
 | |
|         test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
 | |
|         test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
 | |
|         test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
 | |
|         test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
 | |
|         test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
 | |
|             (tensor,),
 | |
|             in_dims=1,
 | |
|             out_dims=1,
 | |
|         )
 | |
| 
 | |
|     def test_dot(self):
 | |
|         op = torch.dot
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         # shape mismatch
 | |
|         msg = "Shape mismatch"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
 | |
| 
 | |
|         # left arg is vmapped
 | |
|         test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, B0, 5), torch.rand(5)),
 | |
|             in_dims=(1, None),
 | |
|         )
 | |
| 
 | |
|         # right arg is vmapped
 | |
|         test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
 | |
|         test(
 | |
|             vmap(op, in_dims=(None, 0)),
 | |
|             (torch.rand(5), torch.rand(B1, B0, 5)),
 | |
|             in_dims=(None, 1),
 | |
|         )
 | |
| 
 | |
|         # both args are vmapped
 | |
|         test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
 | |
|         test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, 5), torch.rand(B0, 5)),
 | |
|             in_dims=(None, 0),
 | |
|         )
 | |
| 
 | |
|     def test_expand_as(self):
 | |
|         op = torch.Tensor.expand_as
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
 | |
|         test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
 | |
|         test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
 | |
|         test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
 | |
|         test(
 | |
|             vmap(op),
 | |
|             (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)),
 | |
|             in_dims=(0, 1),
 | |
|         )
 | |
|         test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
 | |
|         test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
 | |
| 
 | |
|     def test_fill_and_zero_inplace(self):
 | |
|         test = functools.partial(self._vmap_test, check_propagates_grad=False)
 | |
|         B0, B1 = 7, 11
 | |
|         ops = (
 | |
|             lambda t: t.fill_(0.1),
 | |
|             lambda t: t.fill_(torch.tensor(0.2)),
 | |
|             lambda t: t.zero_(),
 | |
|         )
 | |
| 
 | |
|         for op in ops:
 | |
|             # Single vmap, various in_dims / out_dims
 | |
|             test(op, [TensorFactory.randn([B0, 3])])
 | |
|             test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
 | |
|             test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
 | |
| 
 | |
|             # Doubly nested vmap
 | |
|             test(vmap(op), [TensorFactory.randn([B0, B1])])
 | |
|             test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
 | |
|             test(
 | |
|                 vmap(op, in_dims=2),
 | |
|                 [TensorFactory.randn([2, 5, B0, B1, 3])],
 | |
|                 in_dims=2,
 | |
|                 out_dims=2,
 | |
|             )
 | |
| 
 | |
|         # test when value is a batched tensor for fill_ operator
 | |
|         B0, B1 = 3, 5
 | |
|         test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
 | |
| 
 | |
|         with self.assertRaisesRegex(
 | |
|             RuntimeError, r"output with shape .+ doesn't match the broadcast shape"
 | |
|         ):
 | |
|             # Runtime Error is thrown when the tensor being written to isn't being vmapped over
 | |
|             vmap(Tensor.fill_, (None, 0))(
 | |
|                 TensorFactory.randn([B0, B1]), TensorFactory.randn([B0])
 | |
|             )
 | |
| 
 | |
|     def _test_complex_views(self, op, dtypes):
 | |
|         test = self._vmap_view_test
 | |
| 
 | |
|         def run_test(op, dtype):
 | |
|             def get(shape):
 | |
|                 return torch.randn(shape, dtype=dtype)
 | |
| 
 | |
|             B0, B1 = 7, 11
 | |
| 
 | |
|             # Single vmap, various in_dims / out_dims
 | |
|             test(op, [get([B0, 3])])
 | |
|             test(op, [get([3, B0])], in_dims=1)
 | |
|             test(op, [get([2, 5, B0, 3])], in_dims=2)
 | |
|             test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
 | |
| 
 | |
|             # Doubly nested vmap
 | |
|             test(vmap(op), [get([B0, B1])])
 | |
|             test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
 | |
|             test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
 | |
| 
 | |
|         for dtype in dtypes:
 | |
|             run_test(op, dtype)
 | |
| 
 | |
|     def test_real(self):
 | |
|         self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
 | |
| 
 | |
|     def test_imag(self):
 | |
|         self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
 | |
| 
 | |
|     def test_view_as_real(self):
 | |
|         self._test_complex_views(
 | |
|             torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]
 | |
|         )
 | |
| 
 | |
|     def test_view_as_complex(self):
 | |
|         def run_test(dtype):
 | |
|             def get(shape):
 | |
|                 return torch.randn(shape, dtype=dtype)
 | |
| 
 | |
|             op = torch.view_as_complex
 | |
|             test = self._vmap_view_test
 | |
|             B0, B1 = 7, 11
 | |
| 
 | |
|             # Single vmap, various in_dims / out_dims
 | |
|             test(op, [get([B0, 3, 2])])
 | |
|             test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
 | |
|             test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
 | |
| 
 | |
|             # Doubly nested vmap
 | |
|             test(vmap(op), [get([B0, B1, 2])])
 | |
|             test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
 | |
|             test(
 | |
|                 vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2
 | |
|             )
 | |
| 
 | |
|             # Interesting case #1: Batch dim directly before dim of size 2
 | |
|             test(op, [get([3, B0, 2])], in_dims=1)
 | |
|             test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
 | |
| 
 | |
|             # Interesting case #2: Batch dim at end of tensor, success cases
 | |
|             # view_as_complex requires that the dim with size 2 have stride 1
 | |
|             # in order for the view to function properly
 | |
|             test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
 | |
|             test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
 | |
|             test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
 | |
| 
 | |
|             # Interesting case #3: Batch dim at end of tensor, failure cases
 | |
|             msg = "Tensor must have a last dimension with stride 1"
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 vmap(op, in_dims=1)(get([2, B0]))
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
 | |
| 
 | |
|             # Invalid input: no dimension of size 2
 | |
|             msg = "Input tensor must have one or more dimensions"
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 vmap(op)(get([B0]))
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 vmap(vmap(op))(get([B0, B1]))
 | |
| 
 | |
|             # Invalid input: Batch dim has size 2, but the logical last dim does
 | |
|             # not have size 2
 | |
|             msg = "Tensor must have a last dimension of size 2"
 | |
|             with self.assertRaisesRegex(RuntimeError, msg):
 | |
|                 vmap(op, in_dims=1)(get([3, 2]))
 | |
| 
 | |
|         for dtype in [torch.float, torch.double]:
 | |
|             run_test(dtype)
 | |
| 
 | |
|     def test_is_complex(self):
 | |
|         ctensor = torch.randn(3, dtype=torch.cfloat)
 | |
|         tensor = torch.randn(3)
 | |
| 
 | |
|         def foo(x):
 | |
|             if x.is_complex():
 | |
|                 return torch.tensor(1)
 | |
|             else:
 | |
|                 return torch.tensor(0)
 | |
| 
 | |
|         self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
 | |
|         self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
 | |
| 
 | |
|     def test_is_floating_point(self):
 | |
|         float_tensor = torch.tensor([1.0, 2.0, 3.0])
 | |
|         long_tensor = torch.tensor([1, 2, 3])
 | |
| 
 | |
|         def foo(x):
 | |
|             if x.is_floating_point():
 | |
|                 return torch.tensor(1)
 | |
|             else:
 | |
|                 return torch.tensor(0)
 | |
| 
 | |
|         self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
 | |
|         self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
 | |
| 
 | |
|     def test_is_contiguous(self):
 | |
|         def foo(x):
 | |
|             if x.is_contiguous():
 | |
|                 return torch.tensor(1.0)
 | |
|             else:
 | |
|                 return torch.tensor(0.0)
 | |
| 
 | |
|         B0, B1 = 3, 5
 | |
| 
 | |
|         # Single batch dim
 | |
|         contig = torch.randn(B0, 2, 7)
 | |
|         self.assertEqual(vmap(foo)(contig), torch.ones(B0))
 | |
| 
 | |
|         noncontig = torch.randn(2, B0, 7)
 | |
|         self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
 | |
| 
 | |
|         noncontig = torch.randn(2, B0, 7).movedim(1, 0)
 | |
|         self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
 | |
| 
 | |
|         noncontig = torch.randn(2, 7, B0)
 | |
|         self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
 | |
| 
 | |
|         # Multiple batch dims
 | |
|         contig = torch.randn(B0, B1, 3)
 | |
|         self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
 | |
| 
 | |
|         contig = torch.randn(B1, B0, 3)
 | |
|         self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
 | |
| 
 | |
|         contig = torch.randn(B1, B0, 3).movedim(0, 1)
 | |
|         self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
 | |
| 
 | |
|         noncontig = torch.randn(B0, 3, B1)
 | |
|         self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
 | |
| 
 | |
|         # is_contiguous on empty tensor is True
 | |
|         def bar(x):
 | |
|             assert x.is_contiguous()
 | |
|             return x
 | |
| 
 | |
|         vmap(bar)(torch.randn(B0, 0, 3))
 | |
|         vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
 | |
|         vmap(bar)(torch.randn(B0, 0, 3).mT)
 | |
| 
 | |
|         # is_contiguous with other memory formats
 | |
|         def baz(x, memory_format):
 | |
|             x.is_contiguous(memory_format=memory_format)
 | |
|             return x
 | |
| 
 | |
|         msg = "NYI: querying is_contiguous inside of vmap for memory_format"
 | |
|         tensor = torch.randn(B0, 2, 7, 3)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
 | |
| 
 | |
|     def test_movedim(self):
 | |
|         op = torch.movedim
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         # movedim(tensor, int, int) variant
 | |
|         test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 2, B0, 5), 0, 1),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 2, B0, 5, B2), 0, 1),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
| 
 | |
|         # movedim(tensor, intlist, intlist) variant
 | |
|         test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
| 
 | |
|     def test_mm(self):
 | |
|         op = torch.mm
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         # shape mismatch
 | |
|         msg = "Shape mismatch"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
 | |
| 
 | |
|         # left arg is vmapped
 | |
|         test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
 | |
|             in_dims=(1, None),
 | |
|         )
 | |
| 
 | |
|         # right arg is vmapped
 | |
|         test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
 | |
|         test(
 | |
|             vmap(op, in_dims=(None, 0)),
 | |
|             (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
 | |
|             in_dims=(None, 1),
 | |
|         )
 | |
| 
 | |
|         # both args are vmapped
 | |
|         test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
 | |
|         test(
 | |
|             vmap(op),
 | |
|             (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)),
 | |
|             in_dims=(1, 0),
 | |
|         )
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)),
 | |
|             in_dims=(None, 0),
 | |
|         )
 | |
| 
 | |
|     def test_mv(self):
 | |
|         op = torch.mv
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         # shape mismatch
 | |
|         msg = "Shape mismatch"
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
 | |
|         with self.assertRaisesRegex(RuntimeError, msg):
 | |
|             vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
 | |
| 
 | |
|         # left arg is vmapped
 | |
|         test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, B0, 2, 5), torch.rand(5)),
 | |
|             in_dims=(1, None),
 | |
|         )
 | |
| 
 | |
|         # right arg is vmapped
 | |
|         test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
 | |
|         test(
 | |
|             vmap(op, in_dims=(None, 0)),
 | |
|             (torch.rand(2, 5), torch.rand(B1, B0, 5)),
 | |
|             in_dims=(None, 1),
 | |
|         )
 | |
| 
 | |
|         # both args are vmapped
 | |
|         test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
 | |
|         test(
 | |
|             vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)
 | |
|         )
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, 2, 5), torch.rand(B0, 5)),
 | |
|             in_dims=(None, 0),
 | |
|         )
 | |
| 
 | |
|     def test_narrow(self):
 | |
|         op = torch.narrow
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
 | |
|         test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None, None)),
 | |
|             (torch.rand(B1, 2, B0, 5), 1, 0, 0),
 | |
|             in_dims=(2, None, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(
 | |
|                 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
 | |
|             ),
 | |
|             (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
 | |
|             in_dims=(2, None, None, None),
 | |
|         )
 | |
| 
 | |
|     def test_new_empty(self):
 | |
|         # Empty is non-deterministic so we just check that the shape of the
 | |
|         # output tensor is what we expect and that the vmap fallback isn't used.
 | |
|         op = Tensor.new_empty
 | |
| 
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
 | |
|         self.assertEqual(result.shape, [B0, 2, 3])
 | |
| 
 | |
|         result = vmap(lambda x: op(x, []))(torch.randn(B0))
 | |
|         self.assertEqual(result.shape, [B0])
 | |
| 
 | |
|         result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
 | |
|         self.assertEqual(result.shape, [B0, B1, 2, 3])
 | |
| 
 | |
|     def test_new_empty_strided(self):
 | |
|         # Empty is non-deterministic so we just check that the size and shape
 | |
|         # of the output are what we expect and that the vmap fallback isn't used
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         def _test_single_vmap(size, stride, B0):
 | |
|             x = torch.randn(B0)
 | |
|             result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
 | |
|             S = torch.empty_strided(size, stride).storage().size()
 | |
|             self.assertEqual(result.shape, [B0] + size)
 | |
|             self.assertEqual(result.stride(), [S] + stride)
 | |
| 
 | |
|         def _test_double_vmap(size, stride, B0, B1):
 | |
|             x = torch.randn(B0, B1)
 | |
|             result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
 | |
|             S = torch.empty_strided(size, stride).storage().size()
 | |
|             self.assertEqual(result.shape, [B0, B1] + size)
 | |
|             self.assertEqual(result.stride(), [B1 * S, S] + stride)
 | |
| 
 | |
|             x = torch.randn(B1, B0)
 | |
|             result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(
 | |
|                 x
 | |
|             )
 | |
|             S = x.new_empty_strided(size, stride).storage().size()
 | |
|             self.assertEqual(result.shape, [B0, B1] + size)
 | |
|             self.assertEqual(result.stride(), [B1 * S, S] + stride)
 | |
| 
 | |
|         # contiguous case
 | |
|         _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
 | |
|         _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
 | |
| 
 | |
|         # expanded
 | |
|         _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
 | |
|         _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
 | |
| 
 | |
|         # some of these cases are pretty strange, just verifying that if
 | |
|         # empty_strided allows them then BatchedTensor.new_empty_strided
 | |
|         # can as well
 | |
|         for shape in [[2, 3, 4], [0, 2, 0]]:
 | |
|             for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
 | |
|                 _test_single_vmap(shape, strides, B0)
 | |
|                 _test_double_vmap(shape, strides, B0, B1)
 | |
| 
 | |
|     def test_new_zeros(self):
 | |
|         op = Tensor.new_zeros
 | |
|         test = functools.partial(self._vmap_test, check_propagates_grad=False)
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         test(lambda x: op(x, 2, 3), (torch.rand(B0),))
 | |
|         test(lambda x: op(x, []), (torch.rand(B0),))
 | |
|         test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
 | |
| 
 | |
|     def test_select(self):
 | |
|         op = torch.select
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
 | |
|         test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)),
 | |
|             (torch.rand(B1, 2, B0, B2, 5),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_stack(self):
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 5, 7
 | |
| 
 | |
|         # Quick hack b/c vmap can't accept a list of tensors as an argument
 | |
|         def get_op(dim):
 | |
|             def op(*tensors):
 | |
|                 return torch.stack(tensors, dim=dim)
 | |
| 
 | |
|             return op
 | |
| 
 | |
|         test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
 | |
|         test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
 | |
|         test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
 | |
|         test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
 | |
|         test(
 | |
|             vmap(get_op(0), in_dims=(0, None)),
 | |
|             (torch.rand(B1, 2), torch.rand(B0, 2)),
 | |
|             in_dims=(None, 0),
 | |
|         )
 | |
|         test(
 | |
|             vmap(get_op(0), in_dims=(0, 0)),
 | |
|             (torch.rand(B1, 2), torch.rand(B0, B1, 2)),
 | |
|             in_dims=(None, 0),
 | |
|         )
 | |
| 
 | |
|     def test_slice(self):
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
 | |
|         test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
 | |
|         test(
 | |
|             vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
 | |
|             (torch.rand(3, 5, B0, B1, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_squeeze(self):
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.squeeze
 | |
|         B0, B1 = 1, 11
 | |
|         test(op, (torch.rand(B0),))
 | |
|         test(op, (torch.rand(B0, 3, 5),))
 | |
|         test(op, (torch.rand(1, B0, 5),), in_dims=1)
 | |
|         test(op, (torch.rand(B0, 0, 1, 5, 1),))
 | |
|         test(op, (torch.rand(B0, 1, 1, 1, 1),))
 | |
|         test(vmap(op), (torch.rand(B0, B1, 1),))
 | |
|         test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
 | |
| 
 | |
|     def test_sum_dim(self):
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 5, 7
 | |
| 
 | |
|         # Single vmap, various in_dims / out_dims
 | |
|         test(lambda x: x.sum(()), [torch.randn([B0])])
 | |
|         test(lambda x: x.sum(()), [torch.randn([B0, 2])])
 | |
|         test(lambda x: x.sum(0), [torch.randn([B0])])
 | |
|         test(lambda x: x.sum(-1), [torch.randn([B0])])
 | |
|         test(lambda x: x.sum(0), [torch.randn([B0, 3])])
 | |
|         test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
 | |
|         test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
 | |
| 
 | |
|         # Doubly nested vmap
 | |
|         test(vmap(lambda x: x.sum(())), [torch.randn([B0, B1])])
 | |
|         test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
 | |
|         test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
 | |
|         test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
 | |
|         test(
 | |
|             vmap(lambda x: x.sum(2), in_dims=2),
 | |
|             [torch.randn([2, 5, B0, B1, 3])],
 | |
|             in_dims=2,
 | |
|             out_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_reshape(self):
 | |
|         test = self._vmap_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         op = torch.reshape
 | |
|         test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
 | |
|         test(
 | |
|             op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False
 | |
|         )
 | |
|         test(
 | |
|             vmap(lambda t: t.reshape([-1])),
 | |
|             (torch.rand(B0, B1, 2, 5),),
 | |
|             check_view=True,
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
 | |
|             (torch.rand(3, B1, 2, B2, 5, B0),),
 | |
|             in_dims=5,
 | |
|             check_view=False,
 | |
|         )
 | |
| 
 | |
|     def test_reshape_as(self):
 | |
|         test = self._vmap_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         op = torch.Tensor.reshape_as
 | |
|         test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
 | |
|         test(
 | |
|             op,
 | |
|             (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
 | |
|             in_dims=(None, 0),
 | |
|             check_view=True,
 | |
|         )
 | |
|         test(
 | |
|             op,
 | |
|             (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
 | |
|             in_dims=(0, None),
 | |
|             check_view=True,
 | |
|         )
 | |
| 
 | |
|         test(
 | |
|             op,
 | |
|             (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
 | |
|             in_dims=(1, None),
 | |
|             check_view=False,
 | |
|         )
 | |
| 
 | |
|         test(
 | |
|             vmap(op),
 | |
|             (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
 | |
|             check_view=True,
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
 | |
|             (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
 | |
|             in_dims=(5, 0),
 | |
|             check_view=False,
 | |
|         )
 | |
| 
 | |
|     def test_result_type(self):
 | |
|         def scalar_tensor_with_dtype(op):
 | |
|             def wrapped(*args, **kwargs):
 | |
|                 dtype = op(*args, **kwargs)
 | |
|                 return torch.ones([], dtype=dtype)
 | |
| 
 | |
|             return wrapped
 | |
| 
 | |
|         test = self._vmap_test
 | |
|         op = scalar_tensor_with_dtype(torch.result_type)
 | |
| 
 | |
|         B0 = 2
 | |
| 
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
| 
 | |
|         test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
 | |
|         test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
 | |
| 
 | |
|         test(
 | |
|             lambda x: op(x, torch.tensor(1)),
 | |
|             (torch.randn(B0),),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
|         test(
 | |
|             lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
 | |
|             (torch.randn(B0),),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
| 
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
| 
 | |
|         test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
 | |
|         test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
 | |
| 
 | |
|         test(
 | |
|             lambda x: op(x, torch.tensor(1)),
 | |
|             (torch.randn(B0, 2),),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
|         test(
 | |
|             lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
 | |
|             (torch.randn(B0, 2),),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
| 
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
|         test(
 | |
|             op,
 | |
|             (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
 | |
|             check_propagates_grad=False,
 | |
|         )
 | |
| 
 | |
|     @skipIfTorchDynamo("too slow")
 | |
|     def test_tensor_split(self):
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.tensor_split
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         # tests for torch.tensor_split(self, indices_or_sections: int, dim)
 | |
|         test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), 256, 0),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 64, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|         # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
 | |
|         test(
 | |
|             op,
 | |
|             (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1),
 | |
|             in_dims=(0, None, None),
 | |
|         )
 | |
|         test(
 | |
|             op,
 | |
|             (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1),
 | |
|             in_dims=(1, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 64, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_split(self):
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.split
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         # tests for torch.split(self, split_size: int, dim)
 | |
|         test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
 | |
|         test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), 256, 0),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 64, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|         # tests for torch.split(self, split_size: List[int], dim)
 | |
|         test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
 | |
|         test(
 | |
|             op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)
 | |
|         )
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
 | |
|             in_dims=(2, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 64, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_trace(self):
 | |
|         op = torch.trace
 | |
|         test = self._vmap_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         test(op, (torch.rand(B0, 2, 5),))
 | |
|         test(op, (torch.rand(2, B0, 5),), in_dims=1)
 | |
|         test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
 | |
|         test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
 | |
| 
 | |
|     def test_transpose(self):
 | |
|         op = torch.transpose
 | |
|         test = self._vmap_view_test
 | |
| 
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
 | |
|         test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
 | |
|         test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
 | |
|         test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
 | |
|         test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
 | |
|         test(
 | |
|             vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 5, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|         # Special case: scalar tensor
 | |
|         for dim1, dim2 in itertools.product([0, -1], [0, -1]):
 | |
|             x = torch.rand(B0)
 | |
|             result = vmap(lambda x: op(x, dim1, dim2))(x)
 | |
|             self.assertTrue(result is x)
 | |
| 
 | |
|     def test_t(self):
 | |
|         op = torch.t
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(op, (torch.rand(B0, 2, 5),))
 | |
|         test(op, (torch.rand(2, B0, 5),), in_dims=1)
 | |
|         test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
 | |
|         test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
 | |
| 
 | |
|     def test_T_numpy(self):
 | |
|         def op(t):
 | |
|             return t.T
 | |
| 
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         test(op, (torch.rand(B0, 2, 3, 5),))
 | |
|         test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
 | |
|         test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
 | |
|         test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
 | |
|         test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
 | |
| 
 | |
|     def test_to(self):
 | |
|         test = self._vmap_test
 | |
|         B0, B1 = 7, 11
 | |
| 
 | |
|         test(lambda t: t.to("cpu"), (torch.rand(B0),))
 | |
|         test(lambda t: t.to(torch.double), (torch.rand(B0),))
 | |
|         test(
 | |
|             lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))
 | |
|         )
 | |
|         test(
 | |
|             lambda t, o: t.to(o),
 | |
|             (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
 | |
|             in_dims=(0, None),
 | |
|         )
 | |
|         test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
 | |
| 
 | |
|         # also test some casting methods
 | |
|         test(lambda t: t.double(), (torch.rand(B0),))
 | |
|         test(lambda t: t.float(), (torch.rand(B0),))
 | |
|         test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
 | |
|         test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
 | |
| 
 | |
|     def test_unfold(self):
 | |
|         op = torch.Tensor.unfold
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 3, 2, 5
 | |
| 
 | |
|         test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
 | |
|         test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None, None, None)),
 | |
|             (torch.rand(B1, 7, B0, 11), 1, 5, 1),
 | |
|             in_dims=(2, None, None, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(
 | |
|                 vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
 | |
|             ),
 | |
|             (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4),
 | |
|             in_dims=(2, None, None, None),
 | |
|         )
 | |
| 
 | |
|     def test_unbind(self):
 | |
|         test = self._vmap_view_test
 | |
|         op = torch.unbind
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
| 
 | |
|         test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
 | |
|         test(op, (torch.rand(B0, 2, 0),))
 | |
|         test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
 | |
|         test(
 | |
|             vmap(op, in_dims=(0, None)),
 | |
|             (torch.rand(B1, 1023, B0, 5), 1),
 | |
|             in_dims=(2, None),
 | |
|         )
 | |
|         test(
 | |
|             vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
 | |
|             (torch.rand(B1, 2, B0, 32, B2),),
 | |
|             in_dims=2,
 | |
|         )
 | |
| 
 | |
|     def test_view(self):
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         op = torch.Tensor.view
 | |
| 
 | |
|         # We should error out if the view would produce an incorrect result
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
 | |
| 
 | |
|         test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
 | |
|         test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
 | |
|         test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
 | |
|         test(
 | |
|             vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
 | |
|             (torch.rand(B2, B0, B1, 3, 2, 5),),
 | |
|             in_dims=1,
 | |
|         )
 | |
| 
 | |
|     def test_view_as(self):
 | |
|         test = self._vmap_view_test
 | |
|         B0, B1, B2 = 7, 11, 13
 | |
|         op = torch.Tensor.view_as
 | |
| 
 | |
|         # We should error out if the view would produce an incorrect result
 | |
|         with self.assertRaises(RuntimeError):
 | |
|             vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
 | |
| 
 | |
|         test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
 | |
|         test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
 | |
|         test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
 | |
| 
 | |
|         test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
 | |
| 
 | |
|         test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
 | |
|         test(
 | |
|             vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
 | |
|             (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
 | |
|             in_dims=(2, 0),
 | |
|         )
 | |
| 
 | |
|     def test_no_random_op_support(self):
 | |
|         B0 = 2
 | |
| 
 | |
|         captured = torch.rand(3)
 | |
| 
 | |
|         random_ops = [
 | |
|             # out-of-place on BatchedTensor
 | |
|             (torch.bernoulli, (torch.rand(B0, 1),)),
 | |
|             (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)),
 | |
|             (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)),
 | |
|             (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))),
 | |
|             (lambda t: torch.normal(t, 1.0), (torch.randn(B0, 1),)),
 | |
|             (lambda t: torch.normal(0.0, t), (torch.randn(B0, 1),)),
 | |
|             (torch.poisson, (torch.rand(B0, 1),)),
 | |
|             (torch.rand_like, (torch.rand(B0, 1),)),
 | |
|             (torch.randn_like, (torch.rand(B0, 1),)),
 | |
|             (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)),
 | |
|             (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)),
 | |
|             # out-of-place on captured tensor
 | |
|             (lambda t: torch.bernoulli(captured), (torch.rand(B0),)),
 | |
|             (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)),
 | |
|             (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)),
 | |
|             (lambda t: torch.normal(captured, captured), (torch.randn(B0),)),
 | |
|             (lambda t: torch.normal(captured, 1.0), (torch.randn(B0),)),
 | |
|             (lambda t: torch.normal(0.0, captured), (torch.randn(B0),)),
 | |
|             (lambda t: torch.poisson(captured), (torch.rand(B0),)),
 | |
|             (lambda t: torch.rand_like(captured), (torch.rand(B0),)),
 | |
|             (lambda t: torch.randn_like(captured), (torch.rand(B0),)),
 | |
|             (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)),
 | |
|             (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)),
 | |
|             # in-place on BatchedTensor
 | |
|             (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.cauchy_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.exponential_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.log_normal_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.normal_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.random_(), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.random_(2), (torch.randn(B0, 1),)),
 | |
|             (lambda t: t.uniform_(), (torch.randn(B0, 1),)),
 | |
|             # in-place on captured tensor
 | |
|             (lambda t: captured.bernoulli_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.cauchy_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.exponential_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.geometric_(0.5), (torch.randn(B0),)),
 | |
|             (lambda t: captured.log_normal_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.normal_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.random_(), (torch.randn(B0),)),
 | |
|             (lambda t: captured.random_(0, 2), (torch.randn(B0),)),
 | |
|             (lambda t: captured.random_(2), (torch.randn(B0),)),
 | |
|             (lambda t: captured.uniform_(), (torch.randn(B0),)),
 | |
|             # factory functions
 | |
|             (lambda t: torch.rand(1), (torch.randn(B0),)),
 | |
|             (lambda t: torch.randn(1), (torch.randn(B0),)),
 | |
|             (lambda t: torch.randint(5, [1]), (torch.randn(B0),)),
 | |
|             (lambda t: torch.randperm(5), (torch.randn(B0),)),
 | |
|         ]
 | |
|         for op, args in random_ops:
 | |
|             with self.assertRaisesRegex(
 | |
|                 RuntimeError, "vmap: We do not yet support calling random operations"
 | |
|             ):
 | |
|                 vmap(op)(*args)
 | |
| 
 | |
| 
 | |
def construct_v(output, batch_size):
    """Sample a batch of standard-normal vectors shaped like ``output``.

    The result has shape ``(batch_size, *output.shape)`` and the same dtype
    and device as ``output``.
    """
    batched_shape = (batch_size, *output.shape)
    return torch.randn(batched_shape, dtype=output.dtype, device=output.device)
 | |
| 
 | |
| 
 | |
def as_tuple(x):
    """Normalize ``x`` to a tuple.

    Tuples are returned unchanged, lists are converted, and any other object
    is wrapped in a 1-tuple.
    """
    if isinstance(x, list):
        return tuple(x)
    if not isinstance(x, tuple):
        return (x,)
    return x
 | |
| 
 | |
| 
 | |
def differentiable(args):
    """Return, as a tuple, the tensors in ``args`` that require grad.

    ``args`` may be a tuple, a list, or a single object (normalized with
    ``as_tuple``). Non-tensor entries and tensors with
    ``requires_grad=False`` are dropped; the relative order is preserved.
    """
    selected = []
    for candidate in as_tuple(args):
        if not isinstance(candidate, torch.Tensor):
            continue
        if candidate.requires_grad:
            selected.append(candidate)
    return tuple(selected)
 | |
| 
 | |
| 
 | |
def _get_rand_no_zeros(*args, **kwargs):
    """Like ``torch.rand`` but with every value clamped to at least 0.1.

    The clamp is applied in place before ``requires_grad`` is enabled, so the
    returned tensor is a leaf even when ``requires_grad=True`` is passed. The
    caller's ``kwargs`` dict is not modified.
    """
    wants_grad = kwargs.get("requires_grad", False)
    sample = torch.rand(*args, **{**kwargs, "requires_grad": False})
    return sample.clamp_min_(0.1).requires_grad_(wants_grad)
 | |
| 
 | |
| 
 | |
class TestVmapBatchedGradientLegacy(Namespace.TestVmapBaseLegacy):
    """Batched (vmapped) gradient tests for the legacy ``vmap``.

    Each test builds a differentiable computation and uses
    ``_batched_grad_test`` / ``_batched_grad_grad_test`` to check that
    vector-Jacobian or vector-Hessian products computed under vmap match a
    sequential map+stack fallback.

    ``instantiate_device_type_tests`` instantiates the class once per device.
    Every tensor a test creates is placed on the ``device`` argument, so the
    non-CPU variants actually exercise that device.
    """

    def _vmap_test(self, *args, **kwargs):
        # Delegate to the module-level ``_vmap_test`` helper, passing this
        # TestCase as its first argument.
        return _vmap_test(self, *args, **kwargs)

    def _batched_grad_test(
        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
    ):
        """Test batched gradient computation of ``outputs = op(*args, **kwargs)``.

        The vmapped vector-Jacobian product is compared to a sequential
        map+stack fallback.

        Args:
            op: function under test.
            args: positional inputs. The tensors among them that require grad
                are the ones differentiated with respect to.
            kwargs: keyword inputs for ``op`` (none by default).
            output_process_fn: maps the outputs of ``op`` to the part that
                should be differentiated.
            batch_size: size of the batch dim of the batched grad vectors.
        """
        if kwargs is None:
            kwargs = {}
        outputs = op(*args, **kwargs)
        outputs = differentiable(output_process_fn(outputs))
        # One random batch of cotangent vectors per differentiable output.
        batched_vectors = tuple(construct_v(out, batch_size) for out in outputs)

        def vector_jacobian_product(*vectors):
            # retain_graph: the same graph is differentiated repeatedly
            # (presumably under vmap and per example by the fallback).
            return torch.autograd.grad(
                outputs, differentiable(args), vectors, retain_graph=True
            )

        self._vmap_test(
            vector_jacobian_product, batched_vectors, check_propagates_grad=False
        )

    def _batched_grad_grad_test(
        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
    ):
        """Test batched second-gradient computation of ``outputs = op(*args, **kwargs)``.

        The vmapped vector-Hessian product is compared to a sequential
        map+stack fallback. Arguments are as for ``_batched_grad_test``.

        NB: only the second gradient computation is batched. One use case that
        does this is computing the Hessian of a scalar-valued function, which
        is useful in Bayesian Logistic Regression. A test that computes batched
        first gradients and then uses them to compute batched second gradients
        might be useful in the future.
        """
        if kwargs is None:
            kwargs = {}
        outputs = op(*args, **kwargs)
        outputs = differentiable(output_process_fn(outputs))
        ones = tuple(torch.ones_like(out) for out in outputs)
        # Same thing as summing together all of the outputs and calling .backward()
        first_grads = torch.autograd.grad(
            outputs, differentiable(args), ones, create_graph=True
        )
        first_grads = differentiable(first_grads)
        self.assertNotEqual(
            len(first_grads), 0, "None of the first grads depend on the input!"
        )

        batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads)

        def vector_hessian_product(*vectors):
            outputs = torch.autograd.grad(
                first_grads,
                differentiable(args),
                vectors,
                retain_graph=True,
                allow_unused=True,
            )
            # Inputs the first grads do not depend on come back as None.
            outputs = tuple(out for out in outputs if out is not None)
            assert len(outputs) > 0
            return outputs

        self._vmap_test(
            vector_hessian_product, batched_vectors, check_propagates_grad=False
        )

    def _test_arithmetic(self, op, device, test_grad_grad=True):
        # Binary op with tensor/tensor, scalar/tensor and tensor/scalar
        # operands. The right operand is kept away from zero (e.g. for div).
        x = torch.randn(2, 3, requires_grad=True, device=device)
        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        scalar = 3.14
        self._batched_grad_test(op, (x, y))
        self._batched_grad_test(op, (scalar, y))
        self._batched_grad_test(op, (x, scalar))

        if test_grad_grad:
            self._batched_grad_grad_test(op, (x, y))

    def test_add(self, device):
        # The first grads of add do not depend on the inputs, so grad-grad is skipped.
        self._test_arithmetic(torch.add, device, test_grad_grad=False)
        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)

    def test_sub(self, device):
        # The first grads of sub do not depend on the inputs, so grad-grad is skipped.
        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)

    def test_mul(self, device):
        self._test_arithmetic(torch.mul, device)
        self._test_arithmetic(lambda x, y: x * y, device)

    def test_div(self, device):
        self._test_arithmetic(torch.div, device)
        self._test_arithmetic(lambda x, y: x / y, device)

    @allowVmapFallbackUsage
    def test_binary_cross_entropy(self, device):
        x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
        target = torch.rand(3, 2, device=device)

        op = functools.partial(F.binary_cross_entropy, target=target)

        self._batched_grad_test(op, (x,), {})
        self._batched_grad_grad_test(op, (x,), {})

    def test_expand(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)

        def op(x):
            return x.expand(5, 5, 2, 3)

        self._batched_grad_test(op, (x,))

    @allowVmapFallbackUsage
    def test_index(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        index = torch.tensor([[0, 0], [1, 1]], device=device)

        def op(x):
            y = x * x
            return y[index]

        self._batched_grad_test(op, (x,))
        self._batched_grad_grad_test(op, (x,))

    def test_lgamma(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(Tensor.lgamma, (x,))
        self._batched_grad_grad_test(Tensor.lgamma, (x,))

    def test_log(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(torch.log, (x,))
        self._batched_grad_grad_test(torch.log, (x,))

    def test_logsumexp(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)

        def op(x):
            return torch.logsumexp(x, -1)

        self._batched_grad_test(op, (x,))
        self._batched_grad_grad_test(op, (x,))

    def test_log1p(self, device):
        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(torch.log1p, (x,))
        self._batched_grad_grad_test(torch.log1p, (x,))

    @allowVmapFallbackUsage
    def test_max(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.max, (x,))

    @allowVmapFallbackUsage
    def test_median(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.median, (x,))

    @allowVmapFallbackUsage
    def test_min(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(torch.min, (x,))

    def test_permute(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)

        def op(x):
            return x.permute(2, 0, 1)

        self._batched_grad_test(op, (x,))

    def test_reshape(self, device):
        x = torch.randn(2, 3, 5, requires_grad=True, device=device)

        def op(x):
            return x.reshape([2 * 3, 5])

        self._batched_grad_test(op, (x,))

    def test_sigmoid(self, device):
        x = torch.randn(2, 3, requires_grad=True, device=device)
        self._batched_grad_test(Tensor.sigmoid, (x,))
        self._batched_grad_grad_test(Tensor.sigmoid, (x,))

    def test_stack(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        y = torch.randn(2, 3, device=device, requires_grad=True)

        def op(x, y):
            return torch.stack([x, y])

        self._batched_grad_test(op, (x, y))

    def test_select(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x[1], (x,))
        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))

    def test_slice(self, device):
        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x[0:1], (x,))
        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
        self._batched_grad_test(lambda x: x[..., 1:3], (x,))

    def test_trace(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(Tensor.trace, (x,))

    def test_threshold(self, device):
        x = torch.randn(2, 3, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))

    @allowVmapFallbackUsage
    def test_inplace_on_view(self, device):
        # Allocated on `device`. Previously this used the default device, so
        # non-CPU instantiations of this test never ran on their device.
        leaf = torch.randn(4, 5, requires_grad=True, device=device)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base[0]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

    @allowVmapFallbackUsage
    def test_inplace_manyview(self, device):
        # Allocated on `device` (previously the default device was used).
        leaf = torch.randn(4, 4, 5, requires_grad=True, device=device)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base.transpose(0, 2)
            view = view[1]
            view = view.diagonal()
            view = view[::2]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

    def test_diagonal(self, device):
        x = torch.randn(4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))

        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))

    @allowVmapFallbackUsage
    def test_unrelated_output(self, device):
        # `y` does not depend on `x`, so the vjp with allow_unused=True yields
        # None; the batched result must then be all zeros.
        B0 = 3
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)
        gy = torch.randn(B0, requires_grad=True, device=device)

        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))

    @allowVmapFallbackUsage
    def test_unrelated_output_multiple_grad(self, device):
        # Same as test_unrelated_output, but an unbatched vjp call is made
        # first, before the vmapped one.
        B0 = 3
        x = torch.randn([], requires_grad=True, device=device)
        y = torch.randn([], requires_grad=True, device=device)
        gy = torch.randn(B0, requires_grad=True, device=device)

        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        _ = vjp(gy[0])
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
 | |
| 
 | |
| 
 | |
# Generate device-specific copies of TestVmapBatchedGradientLegacy into this
# module's globals so the test runner can discover them. The third positional
# argument (None) presumably means no devices are excluded — TODO confirm
# against instantiate_device_type_tests' signature.
instantiate_device_type_tests(TestVmapBatchedGradientLegacy, globals(), None)

if __name__ == "__main__":
    # Entry point when the file is executed directly.
    run_tests()
 |