mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[dynamo] Support builtin complex with constant args (#160799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160799 Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							66166cf1e7
						
					
				
				
					commit
					35e4d97e04
				
			| @ -12848,6 +12848,36 @@ fn | ||||
|         res = opt_f(x) | ||||
|         self.assertEqual(ref, res) | ||||
|  | ||||
|     def test_builtin_complex(self): | ||||
|         def f(x): | ||||
|             c = ( | ||||
|                 complex(), | ||||
|                 complex(1), | ||||
|                 complex(2, 3), | ||||
|                 complex(imag=2), | ||||
|                 complex(real=1), | ||||
|                 complex(imag=1, real=2), | ||||
|                 complex("1+2j"), | ||||
|             ) | ||||
|             return [x + z for z in c] | ||||
|  | ||||
|         x = torch.randn(1) | ||||
|         opt_f = torch.compile(f, backend="eager", fullgraph=True) | ||||
|         res = opt_f(x) | ||||
|         ref = f(x) | ||||
|         self.assertEqual(res, ref) | ||||
|  | ||||
|     def test_builtin_complex_args(self): | ||||
|         @torch.compile(backend="eager", fullgraph=True) | ||||
|         def f(*args, **kwargs): | ||||
|             return torch.tensor(complex(*args, **kwargs)) | ||||
|  | ||||
|         self.assertRaises(Unsupported, f, 1, 1, 1) | ||||
|         self.assertRaises(Unsupported, f, 1, 1, fake_arg=1) | ||||
|         self.assertRaises(Unsupported, f, fake_arg=1) | ||||
|         self.assertRaises(Unsupported, f, []) | ||||
|         self.assertRaises(Unsupported, f, "1 + j") | ||||
|  | ||||
|  | ||||
| class TestTracer(JitTestCase): | ||||
|     def test_jit_save(self): | ||||
|  | ||||
| @ -1478,6 +1478,21 @@ class BuiltinVariable(VariableTracker): | ||||
|     call_int = _call_int_float | ||||
|     call_float = _call_int_float | ||||
|  | ||||
|     def call_complex(self, tx: "InstructionTranslator", *args, **kwargs): | ||||
|         if self.constant_args(*args, **kwargs): | ||||
|             try: | ||||
|                 c = complex( | ||||
|                     *(arg.as_python_constant() for arg in args), | ||||
|                     **{k: kwargs[k].as_python_constant() for k in kwargs}, | ||||
|                 ) | ||||
|             except (TypeError, ValueError) as exc: | ||||
|                 raise_observed_exception( | ||||
|                     type(exc), | ||||
|                     tx, | ||||
|                     args=list(map(ConstantVariable.create, exc.args)), | ||||
|                 ) | ||||
|             return ConstantVariable(c) | ||||
|  | ||||
|     def call_bool(self, tx: "InstructionTranslator", arg): | ||||
|         # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. | ||||
|         # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user