Add device logic handling for functions which allow scalar inputs as tensors (#86149)

Some functions allow scalars as tensor inputs. Add handling for them in device logic.

Fix for https://github.com/pytorch/torchdynamo/issues/1445
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86149
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
This commit is contained in:
Elias Ellison
2022-10-04 16:15:56 +00:00
committed by PyTorch MergeBot
parent d6b030856b
commit 9da5646cdb
4 changed files with 24 additions and 0 deletions

View File

@ -1408,6 +1408,11 @@ Call this whenever a new thread is created in order to propagate values from
std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
});
py_module.def(
"_should_allow_numbers_as_tensors", [](const std::string& name) {
return torch::should_allow_numbers_as_tensors(name);
});
py_module.def("_is_deploy_enabled", []() {
#if defined(USE_DEPLOY)
return true;