#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace torch::jit::tracer { //////////////////////////////////////////////////////////////////////////////// // Recording the traces //////////////////////////////////////////////////////////////////////////////// namespace detail { template static void genericAddInput(Node* n, T value) { Value* v = n->owningGraph()->insertConstant(value); recordSourceLocation(v->node()); n->addInput(v); } template static void genericAddOptionalInput( Node* n, const char* name, const std::optional& value) { if (value) { jit::tracer::addInputs(n, name, *value); } else { Graph* g = n->owningGraph(); Value* none = g->insertNode(g->createNone())->output(); n->addInput(none); } } template static void badArgType(const T& v) { TORCH_CHECK( false, "Found an unsupported argument type in the JIT tracer: ", c10::demangle_type(), ". File a bug report."); } static thread_local std::shared_ptr tracing_state; } // namespace detail static std::atomic tracer_state_warn_mode{true}; std::atomic& getTracerStateWarnMode() { return tracer_state_warn_mode; } std::function pauseTracing() { std::shared_ptr state = getTracingState(); tracer::setTracingState(nullptr); return [state]() { tracer::setTracingState(state); }; } void delValueTrace(const IValue& var) { getTracingState()->delValue(var); } void TracingState::delValue(const IValue& var) { for (const auto i : c10::irange(env_stack.size())) { auto& value_map = env_stack.at(env_stack.size() - 1 - i); auto it = value_map.find(var); if (it == value_map.end()) { continue; } value_map.erase(it); } } // Given a IValue 'var', return the 'node' which represents the instruction // which computes the value of this variable in the IR. // Here, we interpret untraced variables as constants that are just embedded // in the graph. This is useful to handle code which does things like this // (from torch.autograd.variable, now moved to C++): // // def mm(self, matrix): // output = Variable(self.data.new(self.data.size(0), matrix.data.size(1))) // return Addmm.apply(output, self, matrix, 0, 1, True) // // Here, mm fakes up a dummy variable with uninitialized data to do an inplace // update on, but subsequently ignores it because the alpha scaling factor is // zero. This is one of the cases where a Variable can be created inside of a // trace, and if we treat it as a constant, everything will work out. Value* getValueTrace(const IValue& var) { return getTracingState()->getValue(var); } static Value* getOptTensorValueTrace(const std::optional& var) { return getValueTrace(IValue(var)); } Value* TracingState::getValue(const IValue& var) { // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...] // arguments if (var.isTensorList()) { return graph ->insertNode(graph->createList( TensorType::get(), fmap( var.toTensorVector(), [&](const IValue& val) { return getValue(val); }))) ->output(); } else if (var.isTuple()) { return graph ->insertNode(graph->createTuple(fmap( var.toTupleRef().elements(), [&](const IValue& val) { return getValue(val); }))) ->output(); } else if (var.isGenericDict()) { auto dict = var.toGenericDict(); TypePtr key_type = dict.keyType(); TypePtr value_type = dict.valueType(); std::vector keys; std::vector values; for (const auto& entry : dict) { keys.emplace_back(getValue(entry.key())); values.emplace_back(getValue(entry.value())); } auto dict_node = graph->createDict(key_type, value_type, keys, values); return graph->insertNode(dict_node)->output(); } if (var.isTensor()) { auto& ten = var.toTensor(); if (!ten.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); } for (const auto i : c10::irange(env_stack.size())) { auto& value_map = env_stack.at(env_stack.size() - 1 - i); auto it = value_map.find(var); if (it == value_map.end()) { continue; } if (!it->second->hasDebugName()) { auto unique_name = getTracingState()->lookup_var_name_fn(ten); if (!unique_name.empty()) { it->second->setDebugName(unique_name); } } return it->second; } // Didn't find it. Bake in a constant if (ten.requires_grad()) { pauseTracing(); std::ostringstream oss; oss << "Cannot insert a Tensor that requires grad as a constant. " << "Consider making it a parameter or input, or detaching the gradient\n" << "Tensor:\n" << ten; throw std::runtime_error(oss.str()); } Value* constant = graph->insertConstant(ten); recordSourceLocation(constant->node()); constant->inferTypeFrom(ten); auto it = env_stack.back().emplace(var, constant); return it.first->second; } else if (var.isFuture() || var.isObject()) { for (const auto i : c10::irange(env_stack.size())) { auto& future_map = env_stack.at(env_stack.size() - 1 - i); auto it = future_map.find(var); if (it == future_map.end()) { continue; } return it->second; } // Find torchbind classes if (isCustomClass(var)) { auto obj = Object(var.toObject()); auto qualname = obj.type()->name(); auto custom_class_type = getCustomClass(qualname->qualifiedName()); if (custom_class_type) { auto capsule = var.toObject()->getAttr("capsule"); for (const auto i : c10::irange(env_stack.size())) { auto& value_map = env_stack.at(env_stack.size() - 1 - i); auto it = value_map.find(capsule); if (it == value_map.end()) { continue; } return it->second; } } } std::ostringstream oss; if (var.isFuture()) { oss << "Tried to trace Future or Object that the tracer was not aware of."; } else { oss << "Tried to trace " << var << " but it is not part of the active trace. Modules that are called during a trace" << " must be registered as submodules of the thing being traced."; } throw std::runtime_error(oss.str()); } else { // If the values are non-tensors, we try to create constants // and bake those constants into the traced graph auto constant = tryInsertConstant(*graph, var); if (constant) { recordSourceLocation(constant.value()->node()); return *constant; } std::ostringstream os; os << "Tracer cannot get value trace for type " << var.tagKind() << ". " << "The below value could not be materialized as a constant:\n" << var; throw std::runtime_error(os.str()); } } bool TracingState::hasValue(const IValue& var) const { for (const auto& frame : env_stack) { if (frame.count(var)) { return true; } } return false; } Value* TracingState::getOutput(const IValue& iv, size_t i) { bool tracing_mode_strict = getTracingState()->strict; if (iv.isTensor()) { const at::Tensor& var = iv.toTensor(); if (!var.defined()) { Node* n = graph->createNone(); return graph->insertNode(n)->output(); } auto& value_map = getTracingState()->env_stack.back(); auto it = value_map.find(iv); if (it == value_map.end()) { std::ostringstream os; os << "output " << i << " (" << var << ") of traced region did not have observable " << "data dependence with trace inputs; this probably indicates your " "program " << "cannot be understood by the tracer."; throw std::runtime_error(os.str()); } return it->second; } else if (iv.isTensorList()) { if (tracing_mode_strict) { tracer::warn( "Encountering a list at the output of the tracer", STRICT_TRACER_MSG); } return graph ->insertNode(graph->createList( TensorType::get(), fmap( iv.toTensorVector(), [&](const IValue& ival) { return getOutput(ival, i); }))) ->output(); } else if (iv.isTuple()) { const auto& tuple = iv.toTupleRef().elements(); auto tuple_node = graph->createTuple( fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); })); graph->insertNode(tuple_node); return tuple_node->output(); } else if (iv.isGenericDict()) { if (tracing_mode_strict) { throw std::runtime_error( "Encountering a dict at the output of the tracer" + std::string(STRICT_TRACER_MSG)); } auto dict = iv.toGenericDict(); TypePtr key_type = dict.keyType(); TypePtr value_type = dict.valueType(); bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) || key_type->isSubtypeOf(*TensorType::get()); bool value_type_valid = value_type->isSubtypeOf(*TensorType::get()); // Support tuple values that contain only tensors if (value_type->isSubtypeOf(*AnyTupleType::get())) { value_type_valid = true; for (const auto& type : value_type->containedTypes()) { if (!type->isSubtypeOf(*TensorType::get())) { value_type_valid = false; break; } } } if (!key_type_valid || !value_type_valid) { std::ostringstream os; os << "output " << i << " (" << dict << ") of traced region " << "cannot be understood by the tracer, only outputs matching" << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] " << "can be a dictionary output of a traced function"; throw std::runtime_error(os.str()); } std::vector keys; std::vector values; for (const auto& entry : dict) { keys.emplace_back(getValue(entry.key())); values.emplace_back(getOutput(entry.value(), i)); } auto dict_node = graph->createDict(key_type, value_type, keys, values); graph->insertNode(dict_node); return dict_node->output(); } else { TORCH_CHECK( false, "Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions"); } } Node* TracingState::createNode(c10::Symbol op_name, size_t num_outputs) { return graph->create(op_name, num_outputs); } void TracingState::insertNode(Node* node) { graph->insertNode(node); } // XXX: this function mutates input static IValue addInput( const std::shared_ptr& state, const IValue& input, const TypePtr& type, Value* value) { value->setType(type); if (type->isSubtypeOf(*TensorType::get())) { auto input_tensor = input.toTensor(); auto const& name = input_tensor.name(); if (state->hasValue(input)) { input_tensor = input_tensor.view(input_tensor.sizes()); } if (!value->hasDebugName()) { value->setDebugName(name); } state->setValue(input_tensor, value); return input_tensor; } else if (auto tuple_type = type->cast()) { auto unpack_node = state->graph->insertNode(state->graph->createTupleUnpack(value)); auto elem_values = unpack_node->outputs(); auto elem_types = tuple_type->elements(); auto tuple = input.toTuple(); const auto& elems = tuple->elements(); size_t num_elems = elems.size(); AT_ASSERT( elem_values.size() == num_elems && elem_types.size() == num_elems); for (const auto i : c10::irange(num_elems)) { tuple->unsafeSetElement( i, addInput(state, elems.at(i), elem_types[i], elem_values[i])); } return tuple; } else if (auto dict_type = type->cast()) { auto dict = input.toGenericDict(); // Unpack the list values statically for (const auto& entry : dict) { const IValue& key = entry.key(); auto static_key = state->graph->insertConstant(key); auto static_value = state->graph->insert(aten::__getitem__, {value, static_key}); recordSourceLocation(static_value->node()); dict.insert_or_assign( entry.key(), addInput( state, entry.value(), dict_type->getValueType(), static_value)); } return dict; } else if (auto list_type = type->cast()) { size_t num_elems = input.isList() ? input.toListRef().size() : input.toTensorVector().size(); auto list_unpack = state->graph->insertNode( state->graph->createListUnpack(value, num_elems)); auto unpack_outputs = list_unpack->outputs(); if (input.isTensorList()) { auto elems = input.toTensorList(); for (const auto i : c10::irange(num_elems)) { elems[i] = addInput( state, elems.get(i), list_type->getElementType(), unpack_outputs[i]) .toTensor(); } return elems; } else { auto elems = input.toList(); for (const auto i : c10::irange(num_elems)) { elems[i] = addInput( state, elems.get(i), list_type->getElementType(), unpack_outputs[i]); } return elems; } } else { TORCH_CHECK( false, "Only tensors or (possibly nested) dict or tuples of tensors can be " "inputs to traced functions. Got ", type->repr_str()); } } static void gatherParametersAndBuffers( const std::shared_ptr& state, Value* self_value, const Module& self, const std::string& prefix) { Graph& g = *self_value->owningGraph(); state->setValue(self._ivalue(), self_value); auto self_ty = self.type(); for (const NameValue& s : self.named_attributes(/*recurse=*/false)) { auto qualname = prefix + "." + s.name; Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr)) ->s_(attr::scope, qualname) ->output() ->setType(s.value.type()); if (s.value.type()->isSubtypeOf(*TensorType::get())) { addInput(state, s.value, s.value.type(), trace_get_attr); } if (isCustomClass(s.value)) { tracer::setValueTrace(s.value, trace_get_attr); } auto attr_type = self_ty->getAttribute(s.name); // Skipping Parameters and Buffers that are behind an `InterfaceType` // because it is illegal for InterfaceType to expose any attribute. // And these attributes should never be used/exposed outside of // InterfaceType'd module anyway. if (attr_type->is_module() && attr_type->kind() != TypeKind::InterfaceType) { gatherParametersAndBuffers( state, trace_get_attr, Module(s.value.toObject()), qualname); } } } std::pair, Stack> trace( Stack inputs, const std::function& traced_fn, std::function var_name_lookup_fn, bool strict, bool force_outplace, Module* self, const std::vector& argument_names) { try { // Start tracing, treating 'inputs' as inputs to the trace, which can be // varied on subsequent invocations of the trace. Any other variables // will be treated as constants. if (isTracing()) { TORCH_CHECK(false, "Tracing can't be nested"); } auto state = std::make_shared(); setTracingState(state); // if we are a module, then make sure the modules parameters are in the map // and mapped to accesses to the self object if (self) { Value* self_value = state->graph->insertInput(0, "self")->setType( self->_ivalue()->type()); gatherParametersAndBuffers(state, self_value, *self, {"__module"}); } // When enough argument name hints are provided, use them as debug names // for traced function/modules. // Here argument_names is allowed to have more names than needed because // some arguments may have valid default values, therefore they don't need // example inputs. if (argument_names.size() >= inputs.size()) { for (size_t i = 0, e = inputs.size(); i < e; ++i) { IValue& input = inputs[i]; input = addInput( state, input, input.type(), state->graph->addInput(argument_names[i])); } } else { for (IValue& input : inputs) { input = addInput(state, input, input.type(), state->graph->addInput()); } } auto graph = state->graph; getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn); getTracingState()->strict = strict; getTracingState()->force_outplace = force_outplace; // Invoke the traced function auto out_stack = traced_fn(inputs); // Exit a trace, treating 'out_stack' as the outputs of the trace. These // are the variables whose values will be computed upon subsequent // invocations of the trace. size_t i = 0; for (auto& output : out_stack) { // NB: The stack is in "reverse" order, so when we pass the diagnostic // number we need to flip it based on size. state->graph->registerOutput( state->getOutput(output, out_stack.size() - i)); i++; } setTracingState(nullptr); if (getInlineEverythingMode()) { Inline(*graph); } FixupTraceScopeBlocks(graph, self); NormalizeOps(graph); return {state, out_stack}; } catch (...) { tracer::abandon(); throw; } } // Abort tracing. Used to reset the state in case of errors. void abandon() { setTracingState(nullptr); } void setValueTrace(const IValue& v, Value* value) { return getTracingState()->setValue(v, value); } void TracingState::setValue(const IValue& v, Value* value) { if (v.isTensor()) { auto& var = v.toTensor(); AT_ASSERT(var.defined()); env_stack.back()[v] = value; // If the value comes from a CallFunction or CallMethod, it may not have // shape information attached. For debuggability, we enhance the type // information by assigning the concrete value's type to the jit::Value. if (auto tensor_type = value->type()->cast()) { if (!tensor_type->isComplete()) { value->inferTypeFrom(var); } } } else if (v.isTensorList()) { auto outputs = v.toTensorList(); Node* unpack_node = graph->insertNode(graph->createListUnpack(value, outputs.size())); for (const auto i : c10::irange(outputs.size())) { setValue(outputs.get(i), unpack_node->outputs()[i]); } } else if (v.isTuple()) { const auto& outputs = v.toTupleRef().elements(); Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value)); for (const auto i : c10::irange(outputs.size())) { setValue(outputs[i], unpack_node->outputs()[i]); } } else if (v.isList()) { auto elements = v.toListRef(); Node* unpack_node = graph->insertNode(graph->createListUnpack(value, elements.size())); for (const auto i : c10::irange(elements.size())) { setValue(elements[i], unpack_node->outputs()[i]); } } else if (isCustomClass(v)) { auto capsule = v.toObject()->getAttr("capsule"); env_stack.back()[capsule] = value; } else if (v.isFuture() || v.isObject()) { env_stack.back()[v] = value; } else if (v.isGenericDict()) { auto dict = v.toGenericDict(); TypePtr key_type = dict.keyType(); TypePtr value_type = dict.valueType(); for (const auto& entry : dict) { auto static_key = graph->insertConstant(entry.key()); auto static_value = graph->insert(aten::__getitem__, {value, static_key}); setValue(entry.value(), static_value); } } else { std::ostringstream os; os << "Tracer cannot set value trace for type " << v.tagKind() << ". " << "Supported types are tensor, tensor list, and tuple of tensors."; throw std::runtime_error(os.str()); } } void addInputs(Node* n, const char* name, int64_t value) { using ArgumentStash = jit::tracer::ArgumentStash; if (ArgumentStash::hasValue(name)) { Value* v = ArgumentStash::popValue(name); n->addInput(v); } else { detail::genericAddInput(n, value); } } void addInputs(Node* n, const char* name, const c10::SymInt& value) { addInputs(n, name, value.guard_int(__FILE__, __LINE__)); } void addInputs(Node* n, const char* name, std::optional value) { using ArgumentStash = jit::tracer::ArgumentStash; if (ArgumentStash::hasValue(name)) { Value* v = ArgumentStash::popValue(name); n->addInput(v); } else if (value) { detail::genericAddInput(n, *value); } else { Graph* g = n->owningGraph(); Value* none = g->insertNode(g->createNone())->output(); n->addInput(none); } } void addInputs(Node* n, const char* name, bool value) { detail::genericAddInput(n, value); } void addInputs(Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs(Node* n, const char* name, double value) { detail::genericAddInput(n, value); } void addInputs(Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs(Node* n, const char* name, const at::Scalar& value) { using ArgumentStash = jit::tracer::ArgumentStash; if (ArgumentStash::hasValue(name)) { Value* v = ArgumentStash::popValue(name); n->addInput(v); } else { detail::genericAddInput(n, value); } } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs(Node* n, const char* name, const std::string_view value) { detail::genericAddInput(n, std::string(value)); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs(Node* n, const char* name, const at::Tensor& value) { n->addInput(getValueTrace(value)); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs( Node* n, const char* name, const std::optional& value) { Graph* g = n->owningGraph(); if (value.has_value() && value->defined()) { detail::genericAddInput(n, *value); } else { Value* undef_gen = g->insertNode(g->createNone())->output(); n->addInput(undef_gen); } } void addInputs(Node* n, const char* name, at::Device value) { detail::genericAddInput(n, value); } void addInputs(Node* n, const char* name, c10::Stream stream) { detail::genericAddInput(n, c10::IValue(stream)); } void addInputs(Node* n, const char* name, at::Layout value) { detail::genericAddInput(n, static_cast(value)); } void addInputs(Node* n, const char* name, at::ScalarType value) { detail::genericAddInput(n, static_cast(value)); } void addInputs(Node* n, const char* name, at::MemoryFormat value) { detail::genericAddInput(n, static_cast(value)); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs( Node* n, const char* name, std::optional value) { TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer"); } void addInputs( Node* n, const char* name, const std::optional& value) { detail::genericAddOptionalInput(n, name, value); } void addInputs( Node* n, const char* name, at::ArrayRef value, bool allow_undefined) { addInputs(n, name, at::ITensorListRef(value), allow_undefined); } void addInputs( Node* n, const char* name, const std::vector& value, bool allow_undefined) { addInputs(n, name, at::ITensorListRef(value), allow_undefined); } void addInputs( Node* n, const char* name, at::ITensorListRef value, bool allow_undefined) { Graph* g = n->owningGraph(); Node* list_node = nullptr; if (allow_undefined) { // if allow undefined, we create a list of optional tensors list_node = g->insertNode( g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace))); } else { list_node = g->insertNode( g->createList(TensorType::get(), fmap(value, getValueTrace))); } n->addInput(list_node->output()); } TORCH_API void addInputs( Node* n, const char* name, const List>& value) { Graph* g = n->owningGraph(); Node* list_node = nullptr; list_node = g->insertNode(g->createList( OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace))); n->addInput(list_node->output()); } void addInputs( Node* n, const char* name, ArrayRef> value, const ClassTypePtr& class_type) { Graph* g = n->owningGraph(); Node* list_node = g->insertNode(g->createList(class_type, fmap(value, getValueTrace))); n->addInput(list_node->output()); } void addInputs(Node* n, const char* name, at::IntArrayRef value) { using ArgumentStash = jit::tracer::ArgumentStash; std::vector info = ArgumentStash::hasIntArrayRef(name) ? ArgumentStash::popIntArrayRef(name) : ArgumentStash::IntArrayRefTrace(value.size()); auto& g = getTracingState()->graph; for (const auto i : c10::irange(info.size())) { if (info[i] != nullptr) continue; info[i] = g->insertConstant(value[i]); recordSourceLocation(info[i]->node()); } for (jit::Value* v : info) { if (*v->type() != *jit::IntType::get()) { throw std::runtime_error( "Type mismatch in setposattr for IntArrayRef. Check that your program " "is valid without tracing, and please file a bug report if it is."); } } n->addInput( g->insertNode(g->createList(jit::IntType::get(), info))->output()); } void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) { addInputs(n, name, C10_AS_INTARRAYREF_SLOW(value)); } void addInputs(Node* n, const char* name, std::optional value) { addInputs( n, name, value.has_value() ? std::make_optional(value->guard_int(__FILE__, __LINE__)) : std::nullopt); } void addInputs( Node* n, const char* name, const std::optional& opt_value) { detail::genericAddOptionalInput(n, name, opt_value); } void addInputs( Node* n, const char* name, const at::OptionalIntArrayRef& opt_value) { if (opt_value.has_value()) { jit::tracer::addInputs(n, name, *opt_value); } else { Graph* g = n->owningGraph(); Value* none = g->insertNode(g->createNone())->output(); n->addInput(none); } } void addInputs( Node* n, const char* name, const at::OptionalSymIntArrayRef& opt_value) { if (opt_value.has_value()) { jit::tracer::addInputs(n, name, *opt_value); } else { Graph* g = n->owningGraph(); Value* none = g->insertNode(g->createNone())->output(); n->addInput(none); } } void addInputs(Node* n, const char* name, ArrayRef value) { std::vector info; auto& g = getTracingState()->graph; for (double elt : value) { info.push_back(g->insertConstant(elt)); recordSourceLocation(info.back()->node()); } n->addInput( g->insertNode(g->createList(jit::FloatType::get(), info))->output()); } void addInputs( Node* n, const char* name, const std::optional>& opt_value) { detail::genericAddOptionalInput(n, name, opt_value); } void addInputs( Node* n, const char* name, const c10::intrusive_ptr& obj) { Value* v = getValueTrace(obj); n->addInput(v); } void addOutput(Node* node, const at::Tensor& output) { setOutput(node->addOutput(), output); } void setOutput(Value* value, const at::Tensor& output) { if (output.defined()) { value->inferTypeFrom(output); setValueTrace(output, value); } } void addOutput(Node* node, const std::vector& outputs) { Value* value = node->addOutput()->setType(ListType::ofTensors()); Graph* graph = node->owningGraph(); Node* unpack_node = graph->insertNode( graph->create(prim::ListUnpack, {value}, outputs.size())); for (const auto i : c10::irange(outputs.size())) { Value* output_val = unpack_node->outputs()[i]; output_val->inferTypeFrom(outputs[i]); setValueTrace(outputs[i], output_val); } } void addOutput(Node* node, const c10::List& outputs) { return addOutput(node, outputs.vec()); } void addOutput( Node* node, const c10::intrusive_ptr& output) { Value* output_val = node->addOutput(); output_val->inferTypeFrom(output); setValueTrace(output, output_val); } const std::shared_ptr& getTracingState() { return detail::tracing_state; } void setTracingState(std::shared_ptr state) { at::tracer::impl::set_dispatch_enabled(state != nullptr); detail::tracing_state = std::move(state); } TracingState::TracingState() : graph(new Graph()), env_stack{Frame()} {} TracingState::~TracingState() = default; autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) { auto& tracing_state = getTracingState(); auto& graph = tracing_state->graph; Variable size_var; { // Make sure this scalar to tensor isn't traced! at::AutoDispatchBelowADInplaceOrView guard; size_var = scalar_to_tensor(at::Scalar(var.size(dim))); } auto* value = getValueTrace(var); auto dim_val = graph->insertConstant(dim); recordSourceLocation(dim_val->node()); auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val})); recordSourceLocation(node); node->output()->setType(jit::IntType::get()); auto ten = graph->insertNode(graph->createNumToTensor(node->output()))->output(); setValueTrace(size_var, ten); return size_var; } autograd::Variable getNumelOf(const autograd::Variable& var) { auto& tracing_state = getTracingState(); auto& graph = tracing_state->graph; Variable numel_var; { // Make sure this scalar to tensor isn't traced! at::AutoDispatchBelowADInplaceOrView guard; numel_var = scalar_to_tensor(at::Scalar(var.numel())); } auto* value = getValueTrace(var); auto* node = graph->insertNode(graph->create(Symbol::aten("numel"), {value})); recordSourceLocation(node); node->output()->setType(jit::IntType::get()); auto ten = graph->insertNode(graph->createNumToTensor(node->output()))->output(); setValueTrace(numel_var, ten); return numel_var; } void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) { auto& state = getTracingState(); if (state && state->force_outplace == false) { // If we're not converting in-place ops to out-of-place, this check is // unnecessary return; } auto aliases = tensor.storage().use_count(); if (isTracing() && aliases > 1) { std::stringstream ss; ss << "There are " << aliases << " live references to the data region being modified when tracing in-place operator " << name << ". This might cause the trace to be incorrect, because all other views " << "that also reference this data will not reflect this change in the trace! " << "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. " << "are outputs of torch.split), this might still be safe."; warn(ss.str().c_str()); } } void ensureUniqueIfOutOfPlaced( const char* name, const std::optional& tensor) { ensureUniqueIfOutOfPlaced(name, tensor.has_value() ? *tensor : at::Tensor()); } //////////////////////////////////////////////////////////////////////////////// // Argument stash //////////////////////////////////////////////////////////////////////////////// thread_local ArgumentStash ArgumentStash::stash; void ArgumentStash::stashIntArrayRefElem( const std::string& arg_name, size_t size, size_t idx, const Variable& var) { // TODO: check type? if (!isTracing()) return; IntArrayRefTrace& list_trace = stash.intlists.emplace(arg_name, size).first->second; AT_ASSERT(size == list_trace.size()); AT_ASSERT(idx < list_trace.size()); AT_ASSERT(list_trace[idx] == nullptr); Value* ten = getValueTrace(var); auto& g = *ten->owningGraph(); WithInsertPoint guard(ten->node()->next()); auto prim = g.insert(aten::Int, {ten}); list_trace[idx] = prim; } void ArgumentStash::stashValue( const std::string& arg_name, size_t idx, const Variable& var, const TypePtr& type) { if (!isTracing()) return; Value* ten = getValueTrace(var); WithInsertPoint guard(ten->node()->next()); auto& g = *ten->owningGraph(); if (type == IntType::get()) { ten = g.insert(aten::Int, {ten}); } else if (type == FloatType::get()) { ten = g.insert(aten::Float, {ten}); } else if (type == NumberType::get()) { ten = g.insert(aten::ScalarImplicit, {ten}); } stash.values.emplace(arg_name, ten); } //////////////////////////////////////////////////////////////////////////////// // Stack trace recording //////////////////////////////////////////////////////////////////////////////// // no python present so we just do not record source information static void defaultRecordSourceLocation(Node* n) {} static std::atomic record_source_location(defaultRecordSourceLocation); void recordSourceLocation(Node* n) { return record_source_location.load()(n); } void setRecordSourceLocation(void (*v)(Node*)) { record_source_location.store(v); } static std::vector defaultPythonCallstack() { return std::vector(); } static std::atomic python_callstack_fn( defaultPythonCallstack); std::vector pythonCallstack() { return python_callstack_fn.load()(); } void setPythonCallstack(std::vector (*v)()) { python_callstack_fn.store(v); } static void defaultWarn(const std::string& str) { TORCH_WARN(str); } static std::atomic warn_callback{defaultWarn}; const char* WARN_PYTHON_DATAFLOW = " might cause the trace to be incorrect. We can't record the data flow of " "Python values, so this value will be treated as a constant in the future. " "This means that the trace might not generalize to other inputs!"; const char* WARN_CONSTRUCTOR = " results are registered as constants in the trace. You can safely ignore this " "warning if you use this function to create tensors out of constant variables " "that would be the same every time you call this function. In any other case, " "this might cause the trace to be incorrect."; const char* WARN_RESIZE = " can't be represented in the JIT at the moment, so we won't connect any uses of " "this value with its current trace. If you happen to use it again, it will show " "up as a constant in the graph. Consider using `view` or `reshape` to make " "it traceable."; const char* STRICT_TRACER_MSG = " might cause the trace to be incorrect, this is only valid if the container " "structure does not change based on the module's inputs. Consider using a constant " "container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a " "`NamedTuple` instead). If you absolutely need this and know the side effects, pass " "strict=False to trace() to allow this behavior."; // XXX: _kind can be a nullptr void _do_warn(const char* _reason, const char* _kind) { std::string reason{_reason}; std::string kind{_kind ? _kind : ""}; std::ostringstream s; s << reason << kind; warn_callback.load()(s.str()); } void setWarn(warn_fn_type fn) { warn_callback.store(fn); } } // namespace torch::jit::tracer