Compare commits

...

64 Commits

SHA1 Message Date
49413b9764 linting 2025-11-04 14:55:01 +00:00
596f830a88 Added support for casting bool and complex zeros to Python numbers. Added static type to fix failing builds. 2025-11-04 14:54:26 +00:00
28b75c1510 linting. 2025-11-03 21:01:28 +00:00
b4f673cd16 Merge branch 'main' into bugfix/dtype_foward_agrad 2025-11-03 20:58:12 +00:00
2e1b5745ad Needed stricter constraints on casting zerotensors to python numbers to avoid semantic errors with FakeTensors. 2025-11-03 16:38:59 +00:00
1b115bc2cd Added separate conditional to handle converting ZeroTensors to PyObjects. Only use zerotensor_meta if there is a wrapped number and it is a zerotensor. 2025-11-01 02:33:15 +00:00
37d7c7c68d Don't need to set is_wrapped_number. Just need to see if there are zerotensors. 2025-10-31 18:30:22 +00:00
72d786def3 Merge branch 'main' into bugfix/dtype_foward_agrad 2025-10-31 15:50:19 +00:00
73041dbe77 Leverage zerotensor metas to get the proper dtype and cast Python numbers for zerotensors. 2025-10-31 14:59:35 +00:00
748d8256eb Removed unneeded uses of is_wrapped_number. 2025-10-31 14:58:52 +00:00
a6e52c5358 Moved the ZeroTensor check to see if it has storage initialized. If it doesn't, it is a meta tensor. 2025-10-29 19:36:30 +00:00
bdf7bc0d6e Removed redundant code. 2025-10-29 12:58:42 +00:00
d8b30f0a5f No torchscript implementation of is_wrapped_number. 2025-10-28 21:34:36 +00:00
714868110b linting. 2025-10-28 21:26:13 +00:00
f955c577c7 Used ternary operator. 2025-10-28 21:24:54 +00:00
583a4172f9 Merge branch 'bugfix/dtype_foward_agrad' of github.com:skpark-rh/pytorch into bugfix/dtype_foward_agrad 2025-10-28 15:02:49 +00:00
1c3af6aed5 When a kernel invokes a Python func with args and kwargs, relaxed the internal assert to allow 0-dim meta device tensors through. 2025-10-28 15:02:33 +00:00
46386bf9b4 Added helper function to set is_wrapped_number. 2025-10-28 15:01:31 +00:00
aea4ecf761 Removed was_wrapped_number and used already existing is_wrapped_number. 2025-10-28 15:00:50 +00:00
eb40d7fb92 Merge branch 'pytorch:main' into bugfix/dtype_foward_agrad 2025-10-17 15:52:27 -04:00
51da7c9aa3 Merge branch 'bugfix/dtype_foward_agrad' of github.com:skpark-rh/pytorch into bugfix/dtype_foward_agrad 2025-10-17 19:20:09 +00:00
3473ae9b1d was_wrapped_number property doesn't have a TorchScript implementation. 2025-10-17 19:19:55 +00:00
e7d59e8d98 Clean up with lintrunner. 2025-10-17 19:19:22 +00:00
112c7a634c Wrote new test for the forward autograd bug where basic arithmetic operations caused dtypes to be different. 2025-10-17 19:19:22 +00:00
63656fbadd Added wrapped_num template for div. 2025-10-17 19:19:21 +00:00
141ce81c19 Initialize boolean variable to false. 2025-10-17 19:19:21 +00:00
3501580923 Added passthrough for new property. 2025-10-17 19:19:21 +00:00
d39815de2a Reverted exposing the is_wrapped_number method. 2025-10-17 19:19:21 +00:00
787e895190 Added is_wrapped_number method to determine on the Python side if the tensor is from a wrapped number. 2025-10-17 19:16:31 +00:00
d24b01d505 Need to update add. 2025-10-17 19:16:31 +00:00
4d35e3ab70 Python Tensor init doc update for was_wrapped_number. 2025-10-17 19:16:31 +00:00
ee590d73f1 Setting the was_wrapped_number for zero tensors derived from wrapped_numbers. Then using the correct dtype promotions on the python side. 2025-10-17 19:16:31 +00:00
729f1e3150 Added a new property called was_wrapped_number and exposed it to the python side to handle dtype promotions. 2025-10-17 19:15:58 +00:00
04a23a4867 Reverted exposing the is_wrapped_number method. 2025-10-17 19:15:58 +00:00
746eb7d353 Added fw derivative template for mul_Tensor to set a zero tensor's "is_wrapped_number" to true if the derived derivative is also a wrapped number. 2025-10-17 19:15:58 +00:00
0dc5cbfdf4 Added is_wrapped_number method to determine on the Python side if the tensor is from a wrapped number. 2025-10-17 19:15:58 +00:00
03fd03293e Merge branch 'main' into bugfix/dtype_foward_agrad 2025-10-09 16:20:10 -04:00
6b2308f845 was_wrapped_number property doesn't have a TorchScript implementation. 2025-10-07 17:41:14 +00:00
36a4c6a8c6 Clean up with lintrunner. 2025-10-06 22:17:36 +00:00
b1fd13f52d Merge branch 'pytorch:main' into bugfix/dtype_foward_agrad 2025-10-06 16:27:05 -04:00
f8efa8fdce Merge branch 'main' into bugfix/dtype_foward_agrad 2025-10-06 20:26:10 +00:00
c3fb426d0a Wrote new test for the forward autograd bug where basic arithmetic operations caused dtypes to be different. 2025-10-06 20:04:35 +00:00
c3885f8cbc Added wrapped_num template for div. 2025-10-06 20:04:01 +00:00
4eca0bda27 Initialize boolean variable to false. 2025-10-06 13:49:53 +00:00
b08d7e00ba Merge branch 'main' into bugfix/dtype_foward_agrad 2025-10-02 16:40:27 +00:00
d2d2bc826d Merge branch 'pytorch:main' into bugfix/dtype_foward_agrad 2025-10-02 11:05:10 -04:00
167cb03030 Merge branch 'main' into bugfix/dtype_foward_agrad 2025-10-01 18:32:42 +00:00
f8a0d63d30 Merge branch 'pytorch:main' into bugfix/dtype_foward_agrad 2025-10-01 14:05:29 -04:00
8c2d2d07c2 Added passthrough for new property. 2025-10-01 16:50:19 +00:00
dc6cd52b31 Merge branch 'bugfix/dtype_foward_agrad' of github.com:skpark-rh/pytorch into bugfix/dtype_foward_agrad 2025-09-30 20:20:01 +00:00
7bc1899837 Need to update add. 2025-09-30 20:19:33 +00:00
2ccd5d0cb0 Python Tensor init doc update for was_wrapped_number. 2025-09-30 20:19:33 +00:00
7a21c49dde Setting the was_wrapped_number for zero tensors derived from wrapped_numbers. Then using the correct dtype promotions on the python side. 2025-09-30 20:19:33 +00:00
d1f7c1135c Added a new property called was_wrapped_number and exposed it to the python side to handle dtype promotions. 2025-09-30 20:19:33 +00:00
18faa1e3a5 Reverted exposing the is_wrapped_number method. 2025-09-30 20:19:33 +00:00
11e9d13aa5 Added fw derivative template for mul_Tensor to set a zero tensor's "is_wrapped_number" to true if the derived derivative is also a wrapped number. 2025-09-30 20:19:33 +00:00
145d95a6e5 Added is_wrapped_number method to determine on the Python side if the tensor is from a wrapped number. 2025-09-30 20:19:33 +00:00
1ed00e4122 Need to update add. 2025-09-30 20:17:06 +00:00
76ac42b221 Python Tensor init doc update for was_wrapped_number. 2025-09-30 19:56:42 +00:00
1a10b92d80 Setting the was_wrapped_number for zero tensors derived from wrapped_numbers. Then using the correct dtype promotions on the python side. 2025-09-30 19:55:57 +00:00
ad40ba40cc Added a new property called was_wrapped_number and exposed it to the python side to handle dtype promotions. 2025-09-30 19:52:10 +00:00
c327cfc15e Reverted exposing the is_wrapped_number method. 2025-09-30 19:50:11 +00:00
c29c212ab3 Added fw derivative template for mul_Tensor to set a zero tensor's "is_wrapped_number" to true if the derived derivative is also a wrapped number. 2025-09-29 22:28:33 +00:00
ddf3e41a90 Added is_wrapped_number method to determine on the Python side if the tensor is from a wrapped number. 2025-09-29 22:26:39 +00:00
6 changed files with 97 additions and 11 deletions
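
The underlying bug is easy to reproduce independent of the diff below. A minimal sketch, assuming only the public torch.autograd.forward_ad API, of the dtype mismatch the new test guards against:

import torch
import torch.autograd.forward_ad as fwAD

# 0-dim float32 primal and tangent, as in the new test.
primal = torch.ones((), dtype=torch.float32)
tangent = torch.ones((), dtype=torch.float32)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2.0  # the Python scalar is wrapped into a 0-dim tensor
    p, t = fwAD.unpack_dual(out)
    # Before this fix the tangent could come back promoted (e.g. to float64),
    # because the wrapped-number flag was lost on the ZeroTensor path.
    assert p.dtype == t.dtype, (p.dtype, t.dtype)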

View File

@@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
   }
 }
 
+static Tensor send_to_meta(const Tensor& self, const Device& device) {
+  Tensor out_meta;
+  if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
+    out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
+    out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
+  } else {
+    out_meta = self.to(device);
+  }
+  return out_meta;
+}
+
 Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
   auto out_device = correct_out_device(self, other);
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
+  auto self_meta = send_to_meta(self, device_);
+  auto other_meta = send_to_meta(other, device_);
+  auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
   return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
 }
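
For context on what send_to_meta preserves: a ZeroTensor is a storage-less stand-in for an all-zero tensor, and forward AD uses one when a tangent is absent. A small probe, assuming the private helpers torch._efficientzerotensor and Tensor._is_zerotensor keep their current behavior:

import torch

# A 0-dim ZeroTensor: no storage is allocated, but arithmetic still works
# through the ZeroTensor kernels (mul_zerotensor above, etc.).
z = torch._efficientzerotensor((), dtype=torch.float32)
print(z._is_zerotensor())          # True
print((z * torch.ones(())).dtype)  # torch.float32, via the meta redispatch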
@@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
+  auto self_meta = send_to_meta(self, device_);
+  auto other_meta = send_to_meta(other, device_);
+  auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
 
   if (self._is_zerotensor()) {
     if (other._is_zerotensor()) {
@@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto meta_out = at::_ops::add_Tensor::redispatch(
-      meta_dks, self.to(device_), other.to(device_), alpha);
+  auto self_meta = send_to_meta(self, device_);
+  auto other_meta = send_to_meta(other, device_);
+  auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
 
   auto get_out_like = [&] (const Tensor& tensor)
   {

View File

@@ -2992,12 +2992,50 @@ class TestFakeTensor(TestCase):
         self.assertEqual(strided_result.layout, torch.strided)
 
 
+class TestForwardADWithScalars(TestCase):
+    @ops(
+        [op for op in op_db if op.name in ["mul", "add", "div"]],
+        allowed_dtypes=(torch.float32,),
+    )
+    def test_0d_tensor_with_python_scalar(self, device, dtype, op):
+        """Test that forward AD preserves dtype when combining 0D tensors with Python scalars."""
+        if torch.float not in op.supported_backward_dtypes(device):
+            raise unittest.SkipTest("Does not support autograd")
+        # skip if operator doesn't support forward AD
+        if not op.supports_forward_ad:
+            raise unittest.SkipTest("Does not support forward_ad")
+        # create 0D tensors
+        primal0d = torch.ones((), device=device, dtype=dtype)
+        tangent0d = torch.ones((), device=device, dtype=dtype)
+
+        with torch.autograd.forward_ad.dual_level():
+            dual0d = torch.autograd.forward_ad.make_dual(primal0d, tangent0d)
+            # Test with scalar on RHS
+            if op.supports_rhs_python_scalar:
+                result = op(dual0d, 2.0)
+                p, t = torch.autograd.forward_ad.unpack_dual(result)
+                self.assertEqual(
+                    p.dtype, t.dtype, f"{op.name} and scalar on RHS - dtype mismatch"
+                )
+            # Test with scalar on LHS
+            if op.supports_one_python_scalar:
+                result = op(2.0, dual0d)
+                p, t = torch.autograd.forward_ad.unpack_dual(result)
+                self.assertEqual(
+                    p.dtype, t.dtype, f"{op.name} and scalar on LHS - dtype mismatch"
+                )
+
+
 instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
 instantiate_device_type_tests(TestFakeTensor, globals())
 instantiate_device_type_tests(TestTags, globals())
+instantiate_device_type_tests(TestForwardADWithScalars, globals())
 
 if __name__ == "__main__":
     TestCase._default_dtype_check_enabled = True

View File

@@ -763,6 +763,12 @@ auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined(
 """
 )
 
+FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate(
+    """\
+update_wrapped_number(${inp_name}_tensor, ${inp_name}_t);
+"""
+)
+
 FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
     """\
 auto ${inp_name}_p = toNonOptPrimal(${inp});
@@ -1911,6 +1917,13 @@ def emit_body(
                     zeros_fn=zeros_fn,
                 )
             )
+            if zeros_fn == "_efficientzerotensor_symint":
+                unpacked_arguments += (
+                    FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute(
+                        inp_name=inp.name
+                    )
+                )
+
         if inp.name in (derivative.required_inputs_primal or []):
             unpacked_arguments += (
                 FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
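
To see what the new template emits, here is an illustrative substitution, assuming CodeTemplate from torchgen.code_template and a stand-in input name of "self":

from torchgen.code_template import CodeTemplate

FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate(
    """\
update_wrapped_number(${inp_name}_tensor, ${inp_name}_t);
"""
)

# Substituting a sample input name shows the generated C++ call:
print(FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute(inp_name="self"))
# -> update_wrapped_number(self_tensor, self_t);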

View File

@@ -79,6 +79,12 @@ Tensor toNonOptPrimal(const std::optional<Tensor>& t) {
   return Tensor();
 }
 
+void update_wrapped_number(Tensor& input, Tensor& output) {
+  if (input.unsafeGetTensorImpl()->is_wrapped_number()) {
+    output.unsafeGetTensorImpl()->set_wrapped_number(true);
+  }
+}
+
 void copy_range(variable_list& out, IndexRange range, const Tensor& t) {
   TORCH_CHECK(range.second <= out.size());
   TORCH_CHECK(

View File

@@ -43,6 +43,7 @@ inline std::optional<Tensor> wrap_opt_if(const Tensor& t, const bool cond) {
 TORCH_API Tensor
 apply_loss_reduction(const Tensor& unreduced, int64_t reduction);
 TORCH_API bool any_variable_defined(const variable_list& variables);
+TORCH_API void update_wrapped_number(Tensor& input, Tensor& output);
 TORCH_API void copy_range(
     variable_list& out,
     IndexRange range,

View File

@@ -587,7 +587,9 @@ py::object toPyObject(IValue ivalue) {
   } else if (ivalue.isTensor()) {
     auto tensor = std::move(ivalue).toTensor();
     if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
-      TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
+      TORCH_INTERNAL_ASSERT(
+          tensor.device().is_cpu() ||
+          (tensor._is_zerotensor() && tensor.dim() == 0));
       auto py_tensor = py::cast(tensor);
       if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) {
         return py_tensor.attr("_wrapped_number");
@@ -595,17 +597,27 @@
       auto scalar_type = tensor.scalar_type();
       switch (scalar_type) {
         case at::ScalarType::Bool:
-          return py::cast(*tensor.const_data_ptr<bool>());
+          return (tensor._is_zerotensor())
+              ? py::cast(false)
+              : py::cast(*tensor.const_data_ptr<bool>());
         case at::ScalarType::Long:
-          return py::cast(*tensor.const_data_ptr<int64_t>());
+          return (tensor._is_zerotensor())
+              ? py::cast(int64_t(0))
+              : py::cast(*tensor.const_data_ptr<int64_t>());
         case at::ScalarType::UInt64:
-          return py::cast(*tensor.const_data_ptr<uint64_t>());
+          return (tensor._is_zerotensor())
+              ? py::cast(uint64_t(0))
+              : py::cast(*tensor.const_data_ptr<uint64_t>());
         case at::ScalarType::Double:
-          return py::cast(*tensor.const_data_ptr<double>());
+          return (tensor._is_zerotensor())
+              ? py::cast(0.0)
+              : py::cast(*tensor.const_data_ptr<double>());
         case at::ScalarType::ComplexDouble:
           // TODO: https://github.com/pytorch/pytorch/issues/77134
-          return py::cast(static_cast<std::complex<double>>(
-              *tensor.const_data_ptr<c10::complex<double>>()));
+          return (tensor._is_zerotensor())
+              ? py::cast(std::complex<double>(0.0, 0.0))
+              : py::cast(static_cast<std::complex<double>>(
+                    *tensor.const_data_ptr<c10::complex<double>>()));
         default:
           TORCH_CHECK(
               false,
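
The switch above amounts to the following Python-level sketch; zerotensor_to_py is a hypothetical helper, not part of the diff. A 0-dim ZeroTensor converts to the zero of its dtype directly, since it has no storage to read:

import torch

def zerotensor_to_py(t: torch.Tensor):
    # Hypothetical mirror of the C++ branches: ZeroTensors carry no storage,
    # so return the zero of the matching Python type instead of reading data.
    zeros = {
        torch.bool: False,
        torch.int64: 0,
        torch.uint64: 0,
        torch.float64: 0.0,
        torch.complex128: complex(0.0, 0.0),
    }
    if t._is_zerotensor():
        return zeros[t.dtype]
    return t.item()  # regular wrapped numbers read their single element

print(zerotensor_to_py(torch._efficientzerotensor((), dtype=torch.float64)))  # 0.0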