s/AutoNonVariableTypeMode/AutoDispatchBelowAutograd/ (#56423)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56423

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D27866606

Pulled By: ailzhang

fbshipit-source-id: e3942356dc3133d1c5722de40ec0d45e6a60f2f1
Ailing Zhang
2021-04-20 17:16:20 -07:00
committed by Facebook GitHub Bot
parent 13ac0019ae
commit 3d904b56ec
30 changed files with 61 additions and 61 deletions
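
For readers of this diff, the change is a pure rename of the guard class; its behavior (excluding the Autograd dispatch keys so calls go straight to the kernels below autograd) is unchanged. A minimal before/after sketch, using a hypothetical helper name that does not appear in this diff:

    #include <ATen/ATen.h>

    // Build a small tensor without routing through the autograd kernels.
    at::Tensor make_constant_scalar(const at::Scalar& s) {
      // Before this commit: at::AutoNonVariableTypeMode guard(true);
      // After this commit:
      at::AutoDispatchBelowAutograd guard(true);
      auto result = at::empty({}, at::TensorOptions().dtype(at::kFloat));
      result.fill_(s);
      return result;
    }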

View File

@@ -31,7 +31,7 @@ Tensor& scalar_fill(Tensor& self, const Scalar& value) {
Tensor scalar_tensor_static(const Scalar& s, c10::optional<ScalarType> dtype_opt, c10::optional<Device> device_opt) {
at::tracer::impl::NoTracerDispatchMode tracer_guard;
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
auto result = at::detail::empty_cpu({}, dtype_opt, c10::nullopt, device_opt, c10::nullopt, c10::nullopt);
scalar_fill(result, s);
return result;

View File

@@ -45,7 +45,7 @@ static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, c
Tensor value;
{
- at::AutoNonVariableTypeMode guard;
+ at::AutoDispatchBelowAutograd guard;
// TODO: This qint special case looks very suspicious...
if (isQIntType(self.scalar_type())) {
value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU));

View File

@@ -8,12 +8,12 @@
//
// Historically, tracing function was controlled by two switches:
//
- // - `AutoNonVariableTypeMode` guard
+ // - `AutoDispatchBelowAutograd` guard
//
// Tracing function used to be script-generated inside `VariableType_*.cpp`
// kernels, sharing the same `Autograd` dispatch key with autograd function.
// Therefore, before tracing function was moved out of VariableType,
- // `AutoNonVariableTypeMode` guard can also disable tracing as a side effect
+ // `AutoDispatchBelowAutograd` guard can also disable tracing as a side effect
// of disabling `Autograd` dispatching.
//
// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
@@ -42,7 +42,7 @@
//
// - `tracer::impl::NoTracerDispatchMode` guard
//
- // It's used to cover the old semantics of `AutoNonVariableTypeMode` after
+ // It's used to cover the old semantics of `AutoDispatchBelowAutograd` after
// tracing was moved out of VariableType.
//
// Before tracing function was moved out of VariableType, tracing was enabled
@@ -51,7 +51,7 @@
// 1) `TracingState` object in TLS != null;
// - Either inside the execution scope of `tracer::trace()`, or
// - Eagerly called `setTracingState()` with non-null object.
- // 2) Not inside `AutoNonVariableTypeMode` scope;
+ // 2) Not inside `AutoDispatchBelowAutograd` scope;
//
// After:
//
@@ -69,10 +69,10 @@
// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
// these two can be unified.
//
- // - `AutoNonVariableTypeMode` v.s. `tracer::impl::NoTracerDispatchMode`
+ // - `AutoDispatchBelowAutograd` v.s. `tracer::impl::NoTracerDispatchMode`
//
// We don't need to always set both guards together to keep semantics
- // unchanged. For the follow use cases of `AutoNonVariableTypeMode` we don't
+ // unchanged. For the follow use cases of `AutoDispatchBelowAutograd` we don't
// need set the new tracer guard:
//
// * Script-generated VariableType kernels. The guard is not necessary as
@@ -99,7 +99,7 @@
// * Some manually maintained functions, e.g.:
// `torch/csrc/autograd/VariableTypeManual.cpp`.
// Set the new guard if it's not obvious whether `setTracingState(null)`
- // has been called before it reaches the `AutoNonVariableTypeMode` guard.
+ // has been called before it reaches the `AutoDispatchBelowAutograd` guard.
//
// We might need tweak the usage of the new guard to optimize/fix things.
// It should only affect the correctness of tracing function, because the
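
The note above boils down to this: after the rename, suppressing autograd dispatch and suppressing tracing are two separate guards, which the manually maintained call sites in this diff set side by side. A minimal sketch of that pattern, assuming a hypothetical helper function (only the two guard classes and `at::scalar_tensor` are taken from this diff):

    #include <ATen/ATen.h>
    #include <torch/csrc/jit/frontend/tracer.h>

    at::Tensor untraced_scalar_tensor(const at::Scalar& s) {
      // Keep an active tracer from recording the op below (the old side
      // effect of AutoNonVariableTypeMode, now a dedicated guard).
      at::tracer::impl::NoTracerDispatchMode tracer_guard;
      // Skip the Autograd dispatch keys so the backend kernel runs directly.
      at::AutoDispatchBelowAutograd non_var_type_mode(true);
      return at::scalar_tensor(s, at::TensorOptions().dtype(at::kFloat));
    }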

View File

@@ -45,7 +45,7 @@ int64_t decrementKernel(const Tensor& tensor, int64_t input) {
}
void expectCallsIncrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -56,7 +56,7 @@ void expectCallsIncrement(DispatchKey dispatch_key) {
}
void expectCallsDecrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -911,7 +911,7 @@ std::string concatKernel(const Tensor& tensor1, std::string a, const std::string
}
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -34,7 +34,7 @@ int64_t decrementKernel(const Tensor& tensor, int64_t input) {
}
void expectCallsIncrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -45,7 +45,7 @@ void expectCallsIncrement(DispatchKey dispatch_key) {
}
void expectCallsDecrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -653,7 +653,7 @@ std::string concatKernel(const Tensor& tensor1, std::string a, const std::string
}
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -663,7 +663,7 @@ void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
}
void expectCannotCallConcatBoxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -31,7 +31,7 @@ using std::unique_ptr;
namespace {
void expectCallsIncrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -839,7 +839,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn
}
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -21,7 +21,7 @@ using std::unique_ptr;
namespace {
void expectCallsIncrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -32,7 +32,7 @@ void expectCallsIncrement(DispatchKey dispatch_key) {
}
void expectCallsDecrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -564,7 +564,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w
}
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -43,7 +43,7 @@ void redispatchingKernel_with_DispatchKeySet(const OperatorHandle& op, c10::Disp
}
void expectCallsIncrement(c10::DispatchKeySet ks) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -58,7 +58,7 @@ void expectCallsIncrement(DispatchKey dispatch_key) {
}
void expectCallsIncrementUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -68,7 +68,7 @@ void expectCallsIncrementUnboxed(DispatchKey dispatch_key) {
}
void expectCallsDecrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -41,7 +41,7 @@ struct DecrementKernel final : OperatorKernel {
};
void expectCallsIncrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -52,7 +52,7 @@ void expectCallsIncrement(DispatchKey dispatch_key) {
}
void expectCallsDecrement(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
@@ -794,7 +794,7 @@ struct ConcatKernel final : OperatorKernel {
};
void expectCallsConcatUnboxed(DispatchKey dispatch_key) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// assert that schema and cpu kernel are present
auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});

View File

@@ -82,7 +82,7 @@ template <typename T>
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]]
Tensor create(std::unique_ptr<T> ptr, TensorOptions options) {
// None of this should trace, so turn off Tracer dispatching
- at::AutoNonVariableTypeMode guard; // TODO: remove
+ at::AutoDispatchBelowAutograd guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
// We store this instance away in a Tensor and register a deleter function

View File

@@ -270,7 +270,7 @@ void slow_conv2d_backward_out_cpu_template(
const int64_t batch_size = input.size(0);
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
NoGradGuard no_grad;
- AutoNonVariableTypeMode non_variable_type_mode;
+ AutoDispatchBelowAutograd non_variable_type_mode;
for (int64_t t = start; t < end; t++) {
Tensor grad_input_t = grad_input[t];
Tensor grad_output_t = grad_output[t];
@@ -448,7 +448,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_forward_out_cpu(const Tensor&
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
NoGradGuard no_grad;
- AutoNonVariableTypeMode non_variable_type_mode;
+ AutoDispatchBelowAutograd non_variable_type_mode;
for (int64_t t = start; t < end; t++) {
Tensor input_t = input[t].unsqueeze(0);
Tensor output_t = output[t];

View File

@@ -370,7 +370,7 @@ void slow_conv3d_backward_out_cpu_template(
const int64_t batch_size = input.size(0);
at::parallel_for(
0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
- AutoNonVariableTypeMode non_variable_type_mode;
+ AutoDispatchBelowAutograd non_variable_type_mode;
for (int64_t t = start; t < end; t++) {
Tensor grad_input_t = grad_input[t];
Tensor grad_output_t = grad_output_contiguous[t];
@@ -596,7 +596,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> slow_conv3d_forward_out_cpu(const Tensor&
at::parallel_for(
0, batch_size, CONV3D_GRAIN_SALT, [&](int64_t start, int64_t end) {
- AutoNonVariableTypeMode non_variable_type_mode;
+ AutoDispatchBelowAutograd non_variable_type_mode;
for (int64_t t = start; t < end; t++) {
Tensor input_t = input[t];
Tensor output_t = output[t];

View File

@@ -619,7 +619,7 @@ Tensor scalar_tensor(const Scalar& s,
// revert this to following:
// auto result = at::empty({}, options);
at::tracer::impl::NoTracerDispatchMode tracer_guard;
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
auto result = empty_cpu({}, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
at::native::fill_(result, s);
return result;

View File

@@ -31,7 +31,7 @@ namespace {
// Note: this is not a native function as Quantizer is not exposed to python yet
QuantizerPtr Tensor::quantizer() const {
// This is a terrible hack to emulate what VariableType is doing
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
return get_qtensorimpl(*this)->quantizer();
}

View File

@@ -71,7 +71,7 @@ void noopDelete(void*) {}
} // namespace detail
Tensor TensorMaker::make_tensor() {
- AutoNonVariableTypeMode guard{}; // TODO: Remove.
+ AutoDispatchBelowAutograd guard{}; // TODO: Remove.
tracer::impl::NoTracerDispatchMode tracer_guard{};
check_size_nonnegative(sizes_);

View File

@@ -59,7 +59,7 @@ struct C10_API LocalDispatchKeySet {
};
// thread_local variables cannot be C10_API on Windows.
- // Inlining this seems to break AutoNonVariableTypeGuard on Android.
+ // Inlining this seems to break AutoDispatchBelowAutograd on Android.
#if defined(_MSC_VER) || defined(C10_ANDROID)
C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
#else // defined(_MSC_VER) || defined(C10_ANDROID)

View File

@@ -165,7 +165,7 @@ IMPLEMENTATION_TEMPLATE = CT("""\
C10_NOINLINE void implementation_${key}() { // ${name}
${initialization}
run_op = [=] {
- at::AutoNonVariableTypeMode guard;
+ at::AutoDispatchBelowAutograd guard;
${statements}
auto the_result = ${invocation};
${assignments}

View File

@@ -81,7 +81,7 @@ void get_autograd_operator_from_registry_and_execute() {
}
void get_autograd_operator_from_registry_and_execute_in_nograd_mode() {
- at::AutoNonVariableTypeMode _var_guard(true);
+ at::AutoDispatchBelowAutograd _var_guard(true);
torch::Tensor x = torch::randn({5,5}, torch::requires_grad());
torch::Tensor y = torch::randn({5,5}, torch::requires_grad());

View File

@@ -231,7 +231,7 @@ CALL_REDISPATCH = CodeTemplate("""\
at::redispatch::${api_name}(${unpacked_args})""")
# If the non-variable operation has return values, we use the `tmp` variable to hold the
# values temporarily and pass the values to the return variables outside of the
- # `at::AutoNonVariableTypeMode` guard block.
+ # `at::AutoDispatchBelowAutograd` guard block.
DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate("""\
auto ${tmp_var} = ([&]() {
${guard}
@@ -672,7 +672,7 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
return call
def emit_call(f: NativeFunction, unpacked_bindings: List[Binding]) -> str:
- # We only care about adding `at::AutoNonVariableTypeMode` guard for non-variable dispatch
+ # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
# (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
# the baseType operations still dispatch to non-Variable type, even if the arguments passed
# in are now Variables.
@@ -681,7 +681,7 @@ def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
base_type_call = emit_dispatch_call(f, 'self_', unpacked_args)
if get_view_info(fn) is not None or modifies_arguments(f):
- guard = 'at::AutoNonVariableTypeMode guard;'
+ guard = 'at::AutoDispatchBelowAutograd guard;'
else:
guard = 'at::AutoDispatchBelowInplaceOrView guard;'
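
For context, the `tmp` pattern described above is the same one visible in the VariableTypeManual.cpp hunks later in this diff. A rough, hand-written sketch of that shape (the wrapper function is hypothetical; only the guard and the lambda-plus-`tmp` structure come from this diff, not actual generated code):

    #include <ATen/ATen.h>

    at::Tensor guarded_alias(const at::Tensor& self_) {
      // Run the base-type call under the guard and capture its result in
      // `tmp`, so the value can be used after the guard goes out of scope.
      auto tmp = ([&]() {
        at::AutoDispatchBelowAutograd guard;
        return self_.alias();
      })();
      return tmp;
    }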

View File

@@ -65,7 +65,7 @@ inline at::Tensor from_blob(
const Deleter& deleter,
const at::TensorOptions& options = at::TensorOptions()) {
at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true); // TODO: remove
+ at::AutoDispatchBelowAutograd non_var_type_mode(true); // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
return at::from_blob(data, sizes, strides, deleter, options.requires_grad(c10::nullopt));
})();
@@ -83,7 +83,7 @@ inline at::Tensor from_blob(
at::IntArrayRef strides,
const at::TensorOptions& options = at::TensorOptions()) {
at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true); // TODO: remove
+ at::AutoDispatchBelowAutograd non_var_type_mode(true); // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
return at::from_blob(data, sizes, strides, options.requires_grad(c10::nullopt));
})();
@@ -102,7 +102,7 @@ inline at::Tensor from_blob(
const Deleter& deleter,
const at::TensorOptions& options = at::TensorOptions()) {
at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true); // TODO: remove
+ at::AutoDispatchBelowAutograd non_var_type_mode(true); // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
return at::from_blob(data, sizes, deleter, options.requires_grad(c10::nullopt));
})();
@@ -118,7 +118,7 @@ inline at::Tensor from_blob(
at::IntArrayRef sizes,
const at::TensorOptions& options = at::TensorOptions()) {
at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true); // TODO: remove
+ at::AutoDispatchBelowAutograd non_var_type_mode(true); // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
return at::from_blob(data, sizes, options.requires_grad(c10::nullopt));
})();

View File

@@ -137,7 +137,7 @@ AT_FORALL_COMPLEX_TYPES(TENSOR)
sizes_({(int64_t)values.size()}), \
scalar_type_(at::k##S), \
type_(TensorDataContainerType::Tensor) { \
- at::AutoNonVariableTypeMode non_var_type_mode(true); \
+ at::AutoDispatchBelowAutograd non_var_type_mode(true); \
if (scalar_type_ == at::kBool) { \
tensor_ = at::tensor(values, at::TensorOptions().device(at::kCPU)); \
} else { \
@@ -212,7 +212,7 @@ AT_FORALL_COMPLEX_TYPES(TENSOR)
}
if (is_scalar()) {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
return at::scalar_tensor(scalar_, options);
} else if (is_init_list()) {
// NOTE: Here we explicitly choose to initialize the tensor on CPU first,
@@ -222,7 +222,7 @@ AT_FORALL_COMPLEX_TYPES(TENSOR)
// filling each element of it (which involves `N` CUDA kernel launches where
// `N` is the number of the elements in the tensor).
at::Tensor tensor = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
return at::empty(sizes_, options.device(at::kCPU));
})();
fill_tensor(tensor);

View File

@@ -90,7 +90,7 @@ Tensor _fw_primal(const Tensor & self, int64_t level) {
grad_fn->set_next_edges(collect_next_edges( self ));
}
auto tmp = ([&]() {
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
return self_.alias();
})();
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
@@ -131,7 +131,7 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
grad_fn->src_device = src.device();
}
{
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
at::redispatch::copy_(ks & c10::after_autograd_keyset, self_, src_, non_blocking);
}
rebase_history(self , std::move(grad_fn));
@@ -166,7 +166,7 @@ Tensor& resize_(
AT_ERROR("cannot resize variables that require grad");
}
{
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
at::redispatch::resize_(ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
}
@@ -188,7 +188,7 @@ Tensor& resize_as_(
AT_ERROR("cannot resize variables that require grad");
}
{
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
at::redispatch::resize_as_(ks & c10::after_autograd_keyset, self_, the_template_, optional_memory_format);
}
View File

@@ -124,7 +124,7 @@ variable_list Gather::apply(variable_list&& inputs) {
// so no need for extra logic here
at::Tensor variable;
{
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// This is special logic for torch::cuda::gather!
const auto destination_index =
destination_device_.is_cpu() ? -1 : destination_device_.index();

View File

@@ -92,7 +92,7 @@ static inline Variable valueToTensor(c10::TensorOptions options, PyObject* value
if (THPVariable_Check(value)) {
return THPVariable_Unpack(value);
}
- at::AutoNonVariableTypeMode guard; // TODO: remove
+ at::AutoDispatchBelowAutograd guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
if (THPUtils_checkLong(value) || PyBool_Check(value)) {
return at::indexing::scalarToTensor(Scalar(THPUtils_unpackLong(value)), options, device);

View File

@@ -186,7 +186,7 @@ at::Tensor inferAndAlloc(
} else {
c10::IntArrayRef isizes(sizes);
// Non Variable type guard for empty_cuda call
- at::AutoNonVariableTypeMode non_variable_type_mode;
+ at::AutoDispatchBelowAutograd non_variable_type_mode;
return at::native::empty_cuda(
isizes, at_type, c10::nullopt, options.device, c10::nullopt);
}
@@ -410,7 +410,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
if (executor_entry && executor_entry->init) {
{
// context manager to disable auto grad for `empty_cuda` calls later;
- at::AutoNonVariableTypeMode non_variable_type_mode;
+ at::AutoDispatchBelowAutograd non_variable_type_mode;
// take the short-cut for launch if we see a recorded input set again;
launch_params = executor_entry->launch_params;
for (size_t i = 0; i < executor_entry->output_sizes.size(); i++) {

View File

@@ -868,7 +868,7 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
Variable size_var;
{
// Make sure this scalar to tensor isn't traced!
- at::AutoNonVariableTypeMode guard;
+ at::AutoDispatchBelowAutograd guard;
size_var = scalar_to_tensor(at::Scalar(var.size(dim)));
}
auto* value = getValueTrace(var);

View File

@@ -441,7 +441,7 @@ jit::RegisterOperators reg_fut_ops({
Tensor weight = pop(stack).toTensor();
Tensor input = pop(stack).toTensor();
- at::AutoNonVariableTypeMode non_var_type_mode(true);
+ at::AutoDispatchBelowAutograd non_var_type_mode(true);
// aten::convolution takes care of 0 dim case before calls into
// backends
if (input.size(0) == 0) {

View File

@@ -1092,7 +1092,7 @@ at::Tensor PythonArgs::tensor_slow(int i) {
throw TypeError("expected Tensor as argument %d, but got %s", i,
Py_TYPE(obj)->tp_name);
}
- at::AutoNonVariableTypeMode guard; // TODO: remove
+ at::AutoDispatchBelowAutograd guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
at::Tensor tensor = scalar_to_tensor(scalar);

View File

@@ -256,7 +256,7 @@ Tensor internal_new_from_data(
// here.
Tensor tensor;
{
- at::AutoNonVariableTypeMode guard; // TODO: remove
+ at::AutoDispatchBelowAutograd guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
tensor = at::empty(sizes, at::initialTensorOptions().dtype(inferred_scalar_type).pinned_memory(pin_memory));
recursive_store(

View File

@@ -1055,7 +1055,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork {
//
// The correct fix is to stop allocating tensors that are not variables,
// but to conveniently do this c10d must depend on torch not ATen
- at::AutoNonVariableTypeMode _no_grad(true);
+ at::AutoDispatchBelowAutograd _no_grad(true);
auto input = tensors[0];
// Perform local reduction if we have multiple inputs.