mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Disallow FakeTensor.data_ptr access in eager mode (#137221)
Previously we raised a deprecation warning (beginning PyTorch 2.4). Now that we are on 2.6, we're completing the deprecation and disallowing this behavior. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/137221 Approved by: https://github.com/albanD, https://github.com/eellison
This commit is contained in:
@ -24,22 +24,6 @@ void throwNullDataPtrError() {
|
||||
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
|
||||
}
|
||||
|
||||
// NOTE: [FakeTensor.data_ptr deprecation]
|
||||
// Today:
|
||||
// - FakeTensor.data_ptr errors out in torch.compile.
|
||||
// - FakeTensor.data_ptr raises the following deprecation warning otherwise.
|
||||
// - the following deprecation warning is only for FakeTensor (for now).
|
||||
// In the future we can consider extending to more wrapper Tensor subclasses.
|
||||
void warnDeprecatedDataPtr() {
|
||||
TORCH_WARN_ONCE(
|
||||
"Accessing the data pointer of FakeTensor is deprecated and will error in "
|
||||
"PyTorch 2.5. This is almost definitely a bug in your code and will "
|
||||
"cause undefined behavior with subsystems like torch.compile. "
|
||||
"Please wrap calls to tensor.data_ptr() in an opaque custom op; "
|
||||
"If all else fails, you can guard accesses to tensor.data_ptr() on "
|
||||
"isinstance(tensor, FakeTensor).")
|
||||
}
|
||||
|
||||
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
|
||||
// Allowlist verification.
|
||||
// Only if the devicetype is in the allowlist,
|
||||
|
@ -17,7 +17,6 @@
|
||||
namespace c10 {
|
||||
|
||||
C10_API void throwNullDataPtrError();
|
||||
C10_API void warnDeprecatedDataPtr();
|
||||
|
||||
// A storage represents the underlying backing data buffer for a
|
||||
// tensor. This concept was inherited from the original Torch7
|
||||
@ -131,9 +130,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
if (throw_on_mutable_data_ptr_) {
|
||||
throwNullDataPtrError();
|
||||
}
|
||||
if (warn_deprecated_on_mutable_data_ptr_) {
|
||||
warnDeprecatedDataPtr();
|
||||
}
|
||||
maybe_materialize_cow();
|
||||
}
|
||||
return data_ptr_;
|
||||
@ -166,9 +162,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
if (throw_on_mutable_data_ptr_) {
|
||||
throwNullDataPtrError();
|
||||
}
|
||||
if (warn_deprecated_on_mutable_data_ptr_) {
|
||||
warnDeprecatedDataPtr();
|
||||
}
|
||||
maybe_materialize_cow();
|
||||
}
|
||||
return data_ptr_.mutable_get();
|
||||
@ -253,11 +246,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
refresh_has_data_ptr_check();
|
||||
}
|
||||
|
||||
void set_warn_deprecated_on_mutable_data_ptr() {
|
||||
warn_deprecated_on_mutable_data_ptr_ = true;
|
||||
refresh_has_data_ptr_check();
|
||||
}
|
||||
|
||||
protected:
|
||||
// materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
|
||||
friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
|
||||
@ -273,8 +261,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
|
||||
private:
|
||||
void refresh_has_data_ptr_check() {
|
||||
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
|
||||
warn_deprecated_on_mutable_data_ptr_;
|
||||
has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_;
|
||||
}
|
||||
|
||||
inline bool is_cow() const {
|
||||
@ -301,8 +288,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
bool has_data_ptr_check_ = false;
|
||||
// If we should throw when mutable_data_ptr() or mutable_data() is called.
|
||||
bool throw_on_mutable_data_ptr_ = false;
|
||||
// If we warn when mutable_data_ptr() or mutable_data() is called.
|
||||
bool warn_deprecated_on_mutable_data_ptr_ = false;
|
||||
Allocator* allocator_;
|
||||
impl::PyObjectSlot pyobj_slot_;
|
||||
};
|
||||
|
@ -1112,14 +1112,19 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
|
||||
self.assertEqual(z.grad, z_opt.grad)
|
||||
|
||||
def test_data_ptr_access_copy(self):
|
||||
import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
with FakeTensorMode():
|
||||
x = torch.randn(3)
|
||||
y = copy.copy(x)
|
||||
with FakeTensorMode():
|
||||
x = torch.randn(3)
|
||||
y = copy.copy(x)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
def test_data_ptr_access_in_eager_errors(self):
|
||||
with FakeTensorMode():
|
||||
x = torch.randn(3)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Cannot access data pointer of Tensor"
|
||||
):
|
||||
x.data_ptr()
|
||||
|
||||
def test_data_ptr_access_fails_in_forward(self):
|
||||
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
|
||||
torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
|
||||
|
@ -1588,7 +1588,6 @@ def _functionalization_reapply_views_tls() -> _bool: ...
|
||||
def _only_lift_cpu_tensors() -> _bool: ...
|
||||
def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
|
||||
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
|
||||
def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...
|
||||
|
||||
class DispatchKey(Enum):
|
||||
${dispatch_key_hints}
|
||||
|
@ -304,15 +304,13 @@ class OutputGraph:
|
||||
|
||||
# In export mode, we force the shape_env to strictly disallow any constraining
|
||||
# of the user marked dynamic dims
|
||||
import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=shape_env,
|
||||
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
|
||||
allow_non_fake_inputs=True if self.export else False,
|
||||
export=self.export,
|
||||
)
|
||||
fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=shape_env,
|
||||
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
|
||||
allow_non_fake_inputs=True if self.export else False,
|
||||
export=self.export,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
self.init_ambient_guards()
|
||||
|
||||
@ -1354,13 +1352,10 @@ class OutputGraph:
|
||||
self.call_cleanup_hooks()
|
||||
old_fake_mode = self.tracing_context.fake_mode
|
||||
if not self.export:
|
||||
import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
|
||||
backend_fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=old_fake_mode.shape_env,
|
||||
)
|
||||
# TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
|
||||
backend_fake_mode = torch._subclasses.FakeTensorMode(
|
||||
shape_env=old_fake_mode.shape_env,
|
||||
)
|
||||
# TODO(voz): Ostensibily, this should be scoped and
|
||||
# restore back to old_fake_mode, but doing so currently violates
|
||||
# a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
|
||||
|
@ -150,11 +150,6 @@ visualize_memory_budget_pareto = (
|
||||
# cost of some performance
|
||||
aggressive_recomputation = False
|
||||
|
||||
# If FakeTensor.data_ptr() should error.
|
||||
# This option is independent of AOTAutograd and torch.compile, but our policy
|
||||
# is to turn it off during torch.compile.
|
||||
fake_tensor_allow_unsafe_data_ptr_access = True
|
||||
|
||||
# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
|
||||
# inserts make_token/sink_token calls in the graph to create tokens and then
|
||||
# sink them at the end. Note that this means the graph is no longer functional
|
||||
|
@ -673,10 +673,7 @@ class FakeTensor(Tensor):
|
||||
dispatch_device=True,
|
||||
device_for_backend_keys=device,
|
||||
)
|
||||
if not fake_mode._allow_unsafe_data_ptr_access:
|
||||
torch._C._set_throw_on_mutable_data_ptr(self)
|
||||
else:
|
||||
torch._C._set_warn_deprecated_on_mutable_data_ptr(self)
|
||||
torch._C._set_throw_on_mutable_data_ptr(self)
|
||||
|
||||
assert elem.device.type == "meta", elem.device.type
|
||||
device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
@ -1132,9 +1129,6 @@ class FakeTensorMode(TorchDispatchMode):
|
||||
# places where we unconditionally allow scalar outputs, TO BE REMOVED
|
||||
self.allow_scalar_outputs = False
|
||||
|
||||
self._allow_unsafe_data_ptr_access = (
|
||||
torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
|
||||
)
|
||||
self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
|
||||
self.cache_enabled = (
|
||||
torch._dynamo.config.fake_tensor_cache_enabled
|
||||
|
@ -965,19 +965,6 @@ void initDispatchBindings(PyObject* module) {
|
||||
->set_throw_on_mutable_data_ptr();
|
||||
});
|
||||
|
||||
// Invariant: you must ONLY call this with FakeTensors.
|
||||
m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
|
||||
if (!t.unsafeGetTensorImpl()->has_storage()) {
|
||||
// If the Tensor doesn't have a storage, then accessing .data_ptr()
|
||||
// will already raise an error.
|
||||
return;
|
||||
}
|
||||
t.unsafeGetTensorImpl()
|
||||
->storage()
|
||||
.unsafeGetStorageImpl()
|
||||
->set_warn_deprecated_on_mutable_data_ptr();
|
||||
});
|
||||
|
||||
m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
|
||||
m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
|
||||
|
||||
|
@ -1880,15 +1880,12 @@ class _MakefxTracer:
|
||||
|
||||
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
|
||||
if fake_tensor_mode is None:
|
||||
import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=True,
|
||||
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
||||
shape_env=ShapeEnv(),
|
||||
static_shapes=True,
|
||||
)
|
||||
fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=True,
|
||||
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
||||
shape_env=ShapeEnv(),
|
||||
static_shapes=True,
|
||||
)
|
||||
self.fake_tensor_mode = fake_tensor_mode
|
||||
elif self.tracing_mode == "symbolic":
|
||||
import torch._dynamo
|
||||
@ -1896,14 +1893,12 @@ class _MakefxTracer:
|
||||
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
|
||||
if fake_tensor_mode is None:
|
||||
shape_env = ShapeEnv()
|
||||
import torch._functorch.config as _config
|
||||
|
||||
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
||||
fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=False,
|
||||
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
||||
shape_env=shape_env,
|
||||
)
|
||||
fake_tensor_mode = FakeTensorMode(
|
||||
allow_fallback_kernels=False,
|
||||
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
||||
shape_env=shape_env,
|
||||
)
|
||||
assert (
|
||||
fake_tensor_mode.shape_env is not None
|
||||
), "shape_env should be set if tracing with 'symbolic'"
|
||||
|
Reference in New Issue
Block a user