Compare commits

...

2 Commits

Author SHA1 Message Date
72471bb36a Guards debug logging for recursive 2025-09-03 13:46:11 -07:00
09e3781855 [dynamo][guards] Skip guard on _compiled_call_impl
ghstack-source-id: 0c73c70c9b9a8f01e4ad18336abfc7b46e4d4c87
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162094
2025-09-03 13:31:22 -07:00
2 changed files with 62 additions and 14 deletions

View File

@ -1951,6 +1951,18 @@ class GuardBuilder(GuardBuilderBase):
def NONE_MATCH(self, guard: Guard) -> None:
# checks `val is None`
# Skip the guard on torch.nn.Module._compiled_call_impl flag. This flag
# is used when one uses mod.compile.
source = guard.originating_source
if (
isinstance(source, DictGetItemSource)
and source.index == "_compiled_call_impl"
and isinstance(source.base, TypeDictSource)
and self.get(source.base.base.name()) is torch.nn.Module
):
return
ref = self.arg_ref(guard)
val = self.get(guard.name)
assert val is None

View File

@ -2470,7 +2470,7 @@ void stop_recording_dict_pointers(
PyObject* value,
bool result);
bool is_recording_dict_pointers(RootGuardManager* root);
void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer);
void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer, std::string source);
void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer);
GuardManager* clone_guard_manager(
@ -2482,7 +2482,7 @@ void add_relational_guard_resetter_to_cloned_root(
std::shared_ptr<RelationalGuard> guard);
std::shared_ptr<RelationalGuard> get_no_tensor_aliasing_guard(
RootGuardManager* _root);
// std::string get_compile_id(RootGuardManager* root);
std::string get_compile_id(RootGuardManager* root);
struct WeakEntry {
PyObject* wr; // weakref
@ -2750,8 +2750,10 @@ class GuardManager {
// tag safe optimizations
void stash_dict_pointers(
PyObject* value,
std::vector<std::pair<PyObject*, uint64_t>> dict_pointers) {
std::vector<std::pair<PyObject*, uint64_t>> dict_pointers,
std::vector<std::pair<PyObject*, std::string>> dict_pointers_sources) {
_dict_pointers[value] = dict_pointers;
_dict_pointers_sources[value] = dict_pointers_sources;
}
void stash_tensor_pointers(
@ -2871,18 +2873,40 @@ class GuardManager {
}
bool check_dict_pointer_tags(PyObject* value) {
std::cout << "RECURSIVE_DICT_TAG: \nStarting\n";
if (_dict_callback_installed) {
// This means that for 3.12+, there are callbacks watching dict pointers.
return true;
}
int count = 0;
// for (auto& kv : _dict_pointers[value]) {
// PyObject* dict_pointer = kv.first;
// uint64_t old_tag = kv.second;
// uint64_t new_tag = get_dict_version_unchecked(dict_pointer);
// std::string source = _dict_pointers_sources[value][count].second;
// count++;
// if (old_tag != new_tag) {
// std::cout << "RECURSIVE_DICT_TAG_ALL: FAILING " << get_compile_id(_root) << ": " << source << ": " << dict_pointer << ", " << old_tag << " and " << new_tag << "\n";
// }
// }
count = 0;
for (auto& kv : _dict_pointers[value]) {
PyObject* dict_pointer = kv.first;
uint64_t old_tag = kv.second;
uint64_t new_tag = get_dict_version_unchecked(dict_pointer);
std::string source = _dict_pointers_sources[value][count].second;
count++;
if (old_tag != new_tag) {
std::cout << "RECURSIVE_DICT_TAG: FAILING " << get_compile_id(_root) << ": " << source << ": " << dict_pointer << ", " << old_tag << " and " << new_tag << "\n";
std::cout << "RECURSIVE_DICT_TAG: FAILING VALUE " << py::str(dict_pointer) << std::endl;
std::cout << "RECURSIVE_DICT_TAG: Ending\n";
return false;
} else {
std::cout << "RECURSIVE_DICT_TAG: PASSING " << get_compile_id(_root) << ": " << source << ": " << dict_pointer << ", " << old_tag << " and " << new_tag << "\n";
}
}
std::cout << "RECURSIVE_DICT_TAG: Ending\n\n";
return true;
}
@ -2988,7 +3012,7 @@ class GuardManager {
} else if (_is_tag_safe && is_recording_dict_pointers(_root)) {
// This is a tag safe node, record the dict pointer
if (_is_dict) {
record_dict_pointer(_root, value);
record_dict_pointer(_root, value, get_source());
} else if (_has_no_tensor_aliasing_guard) {
record_tensor_pointer(_root, value);
}
@ -3390,6 +3414,8 @@ class GuardManager {
bool _disable_dict_tag_matching = false;
std::unordered_map<PyObject*, std::vector<std::pair<PyObject*, uint64_t>>>
_dict_pointers;
std::unordered_map<PyObject*, std::vector<std::pair<PyObject*, std::string>>>
_dict_pointers_sources;
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
std::vector<WeakEntry> _tag_safe_entries;
@ -3632,9 +3658,9 @@ class RootGuardManager : public GuardManager {
_compile_id = compile_id;
}
// std::string get_compile_id() {
// return _compile_id;
// }
std::string get_compile_id() {
return _compile_id;
}
private:
// Reset the state of all the relational guards on failure.
@ -3655,6 +3681,7 @@ class RootGuardManager : public GuardManager {
_is_recording_dict_pointers = false;
_current_tag_safe_root = nullptr;
_recorded_dict_pointers.clear();
_recorded_dict_pointers_sources.clear();
_recorded_tensor_pointers.clear();
}
@ -3662,7 +3689,7 @@ class RootGuardManager : public GuardManager {
if (result) {
// Stash the pointers only if the guard eval passed
_current_tag_safe_root->stash_dict_pointers(
value, _recorded_dict_pointers);
value, _recorded_dict_pointers, _recorded_dict_pointers_sources);
_current_tag_safe_root->stash_tensor_pointers(
value, _recorded_tensor_pointers);
}
@ -3673,9 +3700,17 @@ class RootGuardManager : public GuardManager {
return _is_recording_dict_pointers;
}
void record_dict_pointer(PyObject* dict_pointer) {
void record_dict_pointer(PyObject* dict_pointer, std::string source) {
_recorded_dict_pointers.push_back(
std::make_pair(dict_pointer, get_dict_version_unchecked(dict_pointer)));
std::string mro = "__mro__";
if (source.find(mro) != std::string::npos) {
std::cout << "RECORDING " << source << " " << py::str(dict_pointer) << "\n";
}
_recorded_dict_pointers_sources.push_back(
std::make_pair(dict_pointer, source)
);
}
void record_tensor_pointer(PyObject* tensor_pointer) {
@ -3737,6 +3772,7 @@ class RootGuardManager : public GuardManager {
bool _is_recording_dict_pointers{false};
GuardManager* _current_tag_safe_root{nullptr};
std::vector<std::pair<PyObject*, uint64_t>> _recorded_dict_pointers;
std::vector<std::pair<PyObject*, std::string>> _recorded_dict_pointers_sources;
std::vector<PyObject*> _recorded_tensor_pointers;
};
@ -4167,8 +4203,8 @@ bool is_recording_dict_pointers(RootGuardManager* root) {
return root->is_recording_dict_pointers();
}
void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer) {
root->record_dict_pointer(dict_pointer);
void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer, std::string source) {
root->record_dict_pointer(dict_pointer, source);
}
void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer) {
@ -4180,9 +4216,9 @@ std::shared_ptr<RelationalGuard> get_no_tensor_aliasing_guard(
return _root->get_no_tensor_aliasing_guard();
}
// std::string get_compile_id(RootGuardManager* root) {
// return root->get_compile_id();
// }
std::string get_compile_id(RootGuardManager* root) {
return root->get_compile_id();
}
class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public: