Disallow FakeTensor.data_ptr access in eager mode (#137221)

Previously, accessing FakeTensor.data_ptr in eager mode raised a deprecation
warning (introduced in PyTorch 2.4). Now that we are on 2.6, we complete the
deprecation and disallow the access outright.
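
A minimal sketch of the new eager-mode behavior (it mirrors the new test added below; the error comes from throwNullDataPtrError on the FakeTensor's storage):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(3)  # x is a FakeTensor that carries only metadata
    x.data_ptr()        # raises RuntimeError: "Cannot access data pointer of Tensor ..."
```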

Test Plan:
- existing FakeTensor data_ptr tests, updated; new test_data_ptr_access_in_eager_errors covers the eager-mode error

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137221
Approved by: https://github.com/albanD, https://github.com/eellison
rzou
2024-10-03 08:29:25 -07:00
committed by PyTorch MergeBot
parent cfcd0e1fe9
commit 7e13e7dd7e
9 changed files with 34 additions and 95 deletions


@@ -24,22 +24,6 @@ void throwNullDataPtrError() {
       "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
 }
-// NOTE: [FakeTensor.data_ptr deprecation]
-// Today:
-// - FakeTensor.data_ptr errors out in torch.compile.
-// - FakeTensor.data_ptr raises the following deprecation warning otherwise.
-// - the following deprecation warning is only for FakeTensor (for now).
-//   In the future we can consider extending to more wrapper Tensor subclasses.
-void warnDeprecatedDataPtr() {
-  TORCH_WARN_ONCE(
-      "Accessing the data pointer of FakeTensor is deprecated and will error in "
-      "PyTorch 2.5. This is almost definitely a bug in your code and will "
-      "cause undefined behavior with subsystems like torch.compile. "
-      "Please wrap calls to tensor.data_ptr() in an opaque custom op; "
-      "If all else fails, you can guard accesses to tensor.data_ptr() on "
-      "isinstance(tensor, FakeTensor).")
-}
 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
   // Allowlist verification.
   // Only if the devicetype is in the allowlist,
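
The warning removed above pointed at two workarounds: wrap the pointer access in an opaque custom op, or, as a last resort, guard on the tensor type. A rough sketch of the guard (helper name is illustrative):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

def maybe_data_ptr(t: torch.Tensor):
    # FakeTensors no longer expose a data pointer at all, so skip them.
    if isinstance(t, FakeTensor):
        return None
    return t.data_ptr()
```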


@@ -17,7 +17,6 @@
 namespace c10 {
 C10_API void throwNullDataPtrError();
-C10_API void warnDeprecatedDataPtr();
 // A storage represents the underlying backing data buffer for a
 // tensor. This concept was inherited from the original Torch7
@@ -131,9 +130,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
       if (throw_on_mutable_data_ptr_) {
         throwNullDataPtrError();
       }
-      if (warn_deprecated_on_mutable_data_ptr_) {
-        warnDeprecatedDataPtr();
-      }
       maybe_materialize_cow();
     }
     return data_ptr_;
@@ -166,9 +162,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
       if (throw_on_mutable_data_ptr_) {
         throwNullDataPtrError();
       }
-      if (warn_deprecated_on_mutable_data_ptr_) {
-        warnDeprecatedDataPtr();
-      }
       maybe_materialize_cow();
     }
     return data_ptr_.mutable_get();
@@ -253,11 +246,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
     refresh_has_data_ptr_check();
   }
-  void set_warn_deprecated_on_mutable_data_ptr() {
-    warn_deprecated_on_mutable_data_ptr_ = true;
-    refresh_has_data_ptr_check();
-  }
  protected:
   // materialize_cow_storage needs to call set_data_ptr_no_materlize_cow
   friend void c10::impl::cow::materialize_cow_storage(StorageImpl& storage);
@@ -273,8 +261,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
  private:
   void refresh_has_data_ptr_check() {
-    has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ ||
-        warn_deprecated_on_mutable_data_ptr_;
+    has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_;
   }
   inline bool is_cow() const {
@@ -301,8 +288,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
   bool has_data_ptr_check_ = false;
   // If we should throw when mutable_data_ptr() or mutable_data() is called.
   bool throw_on_mutable_data_ptr_ = false;
-  // If we warn when mutable_data_ptr() or mutable_data() is called.
-  bool warn_deprecated_on_mutable_data_ptr_ = false;
   Allocator* allocator_;
   impl::PyObjectSlot pyobj_slot_;
 };


@@ -1112,14 +1112,19 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
         self.assertEqual(z.grad, z_opt.grad)
     def test_data_ptr_access_copy(self):
-        import torch._functorch.config as _config
-        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-            with FakeTensorMode():
-                x = torch.randn(3)
-                y = copy.copy(x)
+        with FakeTensorMode():
+            x = torch.randn(3)
+            y = copy.copy(x)
         self.assertEqual(y.shape, x.shape)
+    def test_data_ptr_access_in_eager_errors(self):
+        with FakeTensorMode():
+            x = torch.randn(3)
+            with self.assertRaisesRegex(
+                RuntimeError, "Cannot access data pointer of Tensor"
+            ):
+                x.data_ptr()
     def test_data_ptr_access_fails_in_forward(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             torch.library.define("mylib::foo", "(Tensor x) -> Tensor", lib=lib)
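
The preferred fix, per the removed warning and the custom-ops landing page linked above, is to hide raw-pointer access behind an opaque custom op, which test_data_ptr_access_fails_in_forward also exercises. A rough sketch using torch.library.custom_op (op name and body are hypothetical):

```python
import torch

@torch.library.custom_op("mylib::consume_ptr", mutates_args=())
def consume_ptr(x: torch.Tensor) -> torch.Tensor:
    ptr = x.data_ptr()  # only ever reached with real tensors
    # ... hand `ptr` to external code here ...
    return x.clone()

@consume_ptr.register_fake
def _(x):
    # FakeTensor / compile path: never touches the data pointer
    return torch.empty_like(x)
```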


@@ -1588,7 +1588,6 @@ def _functionalization_reapply_views_tls() -> _bool: ...
 def _only_lift_cpu_tensors() -> _bool: ...
 def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
 def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
-def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...
 class DispatchKey(Enum):
     ${dispatch_key_hints}


@@ -304,15 +304,13 @@ class OutputGraph:
         # In export mode, we force the shape_env to strictly disallow any constraining
         # of the user marked dynamic dims
-        import torch._functorch.config as _config
-        with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-            fake_mode = torch._subclasses.FakeTensorMode(
-                shape_env=shape_env,
-                # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
-                allow_non_fake_inputs=True if self.export else False,
-                export=self.export,
-            )
+        fake_mode = torch._subclasses.FakeTensorMode(
+            shape_env=shape_env,
+            # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
+            allow_non_fake_inputs=True if self.export else False,
+            export=self.export,
+        )
         self.tracing_context: TracingContext = TracingContext(fake_mode)
         self.init_ambient_guards()
@@ -1354,13 +1352,10 @@ class OutputGraph:
         self.call_cleanup_hooks()
         old_fake_mode = self.tracing_context.fake_mode
         if not self.export:
-            import torch._functorch.config as _config
-            with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-                # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
-                backend_fake_mode = torch._subclasses.FakeTensorMode(
-                    shape_env=old_fake_mode.shape_env,
-                )
+            # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
+            backend_fake_mode = torch._subclasses.FakeTensorMode(
+                shape_env=old_fake_mode.shape_env,
+            )
             # TODO(voz): Ostensibily, this should be scoped and
             # restore back to old_fake_mode, but doing so currently violates
             # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode


@@ -150,11 +150,6 @@ visualize_memory_budget_pareto = (
 # cost of some performance
 aggressive_recomputation = False
-# If FakeTensor.data_ptr() should error.
-# This option is independent of AOTAutograd and torch.compile, but our policy
-# is to turn it off during torch.compile.
-fake_tensor_allow_unsafe_data_ptr_access = True
 # Unlifts effect tokens from the inputs/outputs in the traced graph and instead
 # inserts make_token/sink_token calls in the graph to create tokens and then
 # sink them at the end. Note that this means the graph is no longer functional


@@ -673,10 +673,7 @@ class FakeTensor(Tensor):
             dispatch_device=True,
             device_for_backend_keys=device,
         )
-        if not fake_mode._allow_unsafe_data_ptr_access:
-            torch._C._set_throw_on_mutable_data_ptr(self)
-        else:
-            torch._C._set_warn_deprecated_on_mutable_data_ptr(self)
+        torch._C._set_throw_on_mutable_data_ptr(self)
         assert elem.device.type == "meta", elem.device.type
         device = device if isinstance(device, torch.device) else torch.device(device)
@@ -1132,9 +1129,6 @@ class FakeTensorMode(TorchDispatchMode):
         # places where we unconditionally allow scalar outputs, TO BE REMOVED
         self.allow_scalar_outputs = False
-        self._allow_unsafe_data_ptr_access = (
-            torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
-        )
         self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
         self.cache_enabled = (
             torch._dynamo.config.fake_tensor_cache_enabled
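
With the config flag gone, every FakeTensor is now constructed with the throw-on-data_ptr bit set; roughly, and with an arbitrary shape for illustration:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
fake = mode.from_tensor(torch.randn(2, 2))  # FakeTensor holding metadata only
print(fake.shape)  # torch.Size([2, 2]) -- metadata queries still work
fake.data_ptr()    # RuntimeError; there is no longer a config opt-out
```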


@@ -965,19 +965,6 @@ void initDispatchBindings(PyObject* module) {
         ->set_throw_on_mutable_data_ptr();
   });
-  // Invariant: you must ONLY call this with FakeTensors.
-  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
-    if (!t.unsafeGetTensorImpl()->has_storage()) {
-      // If the Tensor doesn't have a storage, then accessing .data_ptr()
-      // will already raise an error.
-      return;
-    }
-    t.unsafeGetTensorImpl()
-        ->storage()
-        .unsafeGetStorageImpl()
-        ->set_warn_deprecated_on_mutable_data_ptr();
-  });
   m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
   m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);


@@ -1880,15 +1880,12 @@ class _MakefxTracer:
             fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
             if fake_tensor_mode is None:
-                import torch._functorch.config as _config
-                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-                    fake_tensor_mode = FakeTensorMode(
-                        allow_fallback_kernels=True,
-                        allow_non_fake_inputs=self._allow_non_fake_inputs,
-                        shape_env=ShapeEnv(),
-                        static_shapes=True,
-                    )
+                fake_tensor_mode = FakeTensorMode(
+                    allow_fallback_kernels=True,
+                    allow_non_fake_inputs=self._allow_non_fake_inputs,
+                    shape_env=ShapeEnv(),
+                    static_shapes=True,
+                )
             self.fake_tensor_mode = fake_tensor_mode
         elif self.tracing_mode == "symbolic":
             import torch._dynamo
@@ -1896,14 +1893,12 @@ class _MakefxTracer:
             fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
             if fake_tensor_mode is None:
                 shape_env = ShapeEnv()
-                import torch._functorch.config as _config
-                with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
-                    fake_tensor_mode = FakeTensorMode(
-                        allow_fallback_kernels=False,
-                        allow_non_fake_inputs=self._allow_non_fake_inputs,
-                        shape_env=shape_env,
-                    )
+                fake_tensor_mode = FakeTensorMode(
+                    allow_fallback_kernels=False,
+                    allow_non_fake_inputs=self._allow_non_fake_inputs,
+                    shape_env=shape_env,
+                )
             assert (
                 fake_tensor_mode.shape_env is not None
             ), "shape_env should be set if tracing with 'symbolic'"