# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    GRAPH_EXECUTOR,
    num_profiled_runs,
    ProfilingMode,
)


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from typing import List, Optional, Tuple

from torch.testing import FileCheck
from torch.testing._internal.common_utils import raise_on_run_directly
from torch.testing._internal.jit_utils import (
    disable_autodiff_subgraph_inlining,
    JitTestCase,
)


@unittest.skipIf(
    GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
)
class TestAutodiffSubgraphSlicing(JitTestCase):
    # TODO: It would be better if we could test directly on graphs instead of in
    # the current end-to-end fashion.
    def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
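        """Script ``fn``, run it once with profiling/replay on random inputs of the
        given sizes (with autodiff subgraph inlining disabled), and return the
        resulting optimized graph."""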
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode_for_profiling_tests():
                ge = torch.jit.script(fn)
                inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
                ge(*inputs, profile_and_replay=True)
                return ge.graph_for(*inputs)

    def assertGraphSize(self, graph, size):
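        """Assert that ``graph`` contains ``size`` nodes, ignoring the bookkeeping
        nodes (BailOut, BailoutTemplate, TypeCheck, RequiresGradCheck) inserted by
        the profiling executor."""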
        nodes = list(
            filter(
                lambda n: (
                    n.kind() != "prim::BailOut"
                    and n.kind() != "prim::BailoutTemplate"
                    and n.kind() != "prim::TypeCheck"
                    and n.kind() != "prim::RequiresGradCheck"
                ),
                graph.nodes(),
            )
        )
        self.assertEqual(len(list(nodes)), size)

    def test_chunk_constant_script_ad(self):
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        input = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode_for_profiling_tests():
                func(input, profile_and_replay=True)
                FileCheck().check_not("prim::DifferentiableGraph").run(
                    func.graph_for(input)
                )

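    # The profiling executor only creates a DifferentiableGraph once a candidate
    # subgraph reaches the inline threshold (see the graph_executor_impl.h link
    # below): two fusable nodes are enough, a single node is not.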
    @unittest.skipIf(
        GRAPH_EXECUTOR != ProfilingMode.PROFILING,
        "This threshold is only valid for Profiling Executor",
    )
    def test_diff_graph_inline_threshold(self):
        with enable_profiling_mode_for_profiling_tests():
            NUM_RUNS = 1
            with num_profiled_runs(NUM_RUNS):

                @torch.jit.script
                def foo(x):
                    # two nodes should be fused
                    # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
                    return torch.sigmoid(torch.sigmoid(x))

                @torch.jit.script
                def bar(x):
                    # a single node should NOT be fused
                    return torch.sigmoid(x)

                input = torch.rand([4, 4], requires_grad=True)
                foo(input)
                foo(input)

                bar(input)
                bar(input)

                self.assertGraphContainsExactly(
                    foo.graph_for(input), "prim::DifferentiableGraph", 1
                )
                self.assertGraphContainsExactly(
                    bar.graph_for(input), "prim::DifferentiableGraph", 0
                )

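    # check_against_reference (from common_jit) runs the scripted module and the
    # eager reference on the same inputs and compares their outputs and gradients.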
    def test_bias_as_module_attr(self):
        with enable_profiling_mode_for_profiling_tests():

            class M(torch.nn.Module):
                def __init__(self, has_bias):
                    super().__init__()
                    self.ll = torch.nn.Linear(10, 10, has_bias)

                def forward(self, x, y):
                    return self.ll(x + y) * x + y

            x = torch.rand(10, 10, requires_grad=True)
            no_bias = M(False)
            scripted_no_bias = torch.jit.script(no_bias)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            has_bias = M(True)
            check_against_reference(
                self,
                scripted_no_bias,
                no_bias,
                lambda x: x,
                (
                    x,
                    x,
                ),
                check_types=False,
            )
            scripted_has_bias = torch.jit.script(has_bias)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            check_against_reference(
                self,
                scripted_has_bias,
                has_bias,
                lambda x: x,
                (
                    x,
                    x,
                ),
                check_types=False,
            )

    def test_constructed_bias(self):
        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, b1, b2):
                bias = b1 * b2
                return torch.nn.functional.linear(x, weight, bias)

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            b1 = torch.rand(N, N, requires_grad=True)
            b2 = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, b1, b2))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, b1, b2),
                check_types=False,
            )

    def test_bias_as_arg(self):
        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, bias: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, weight, bias).relu() + 2

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            bias = None
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, bias),
                check_types=False,
            )
            bias = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(
                self,
                scripted,
                method1,
                lambda x: x,
                (x, weight, bias),
                check_types=False,
            )

    def test_requires_grad_for_tensor_list(self):
        with enable_profiling_mode_for_profiling_tests():
            # output & var_list[0] should have requires_grad set to True
            def func(
                input0: torch.Tensor, input1: torch.Tensor
            ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
                var_list = [input0, input1]
                var = torch.cat(var_list)
                output = var + 1.0
                return output, var_list

            jit_f = torch.jit.script(func)
            input0 = torch.randn((2,), requires_grad=True)
            input1 = torch.randn((2,))
            output_ref = func(input0, input1)
            for _ in range(2):
                output = jit_f(input0, input1)
                assert output_ref[0].requires_grad == output[0].requires_grad
                assert output_ref[1][0].requires_grad == output[1][0].requires_grad
                assert output_ref[1][1].requires_grad == output[1][1].requires_grad

    @unittest.skip(
        "disabled until we properly handle tensor lists with undefined gradients"
    )
    def test_differentiable_graph_ops_requires_grad(self):
        x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
        y = torch.randn(8, 2, dtype=torch.float)

        def t(x: torch.Tensor, y: torch.Tensor, flag: bool):
            o = x + 1.0
            o1 = torch.relu(o)
            o = y + 1.5
            o2 = torch.relu(o)
            o3 = o1 + o2

            if flag:
                o = o1 + 1.0
                oo1 = torch.relu(o)
                o = o2 + 2.5
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2
            else:
                o = o1 * 1.0
                oo1 = torch.relu(o)
                o = o2 * 2.0
                oo2 = torch.relu(o)
                oo3 = oo1 + oo2

            return o1, o2, o3, oo1, oo2, oo3

        with enable_profiling_mode_for_profiling_tests():
            t_jit = torch.jit.script(t)
            jit_o = t_jit(x, y, False)
            jit_o = t_jit(x, y, False)
            o = t(x, y, False)

            FileCheck().check("prim::DifferentiableGraph").run(
                t_jit.graph_for(x, y, False)
            )
            # validate that the DifferentiableGraph ops mark requires_grad properly
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)
            # one more run to trigger fusion
            jit_o = t_jit(x, y, False)
            for oo, jit_oo in zip(o, jit_o):
                self.assertEqual(oo.dtype, jit_oo.dtype)
                self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
                self.assertEqual(oo, jit_oo)

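    # When an input (here, bias) does not require grad, autodiff should prune the
    # corresponding gradient, leaving the backward graph's output tuple with a
    # single input.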
    @unittest.skipIf(
        GRAPH_EXECUTOR == ProfilingMode.PROFILING,
        "Simple Executor doesn't support gradients",
    )
    def test_prune_grad(self):
        @torch.jit.script
        def t(input, bias):
            return torch.nn.functional.relu(input + bias)

        input = torch.randn(2, 8, requires_grad=True)
        bias = torch.randn(8, requires_grad=False)  # bias does NOT require grad
        NUM_PROFILED_RUNS = 1
        with num_profiled_runs(NUM_PROFILED_RUNS):
            WARMUP = 3  # 2 runs to reach backward + 1 to optimize it
            for _ in range(WARMUP):
                o = t(input, bias)
                o.sum().backward()

            fwd_plan = list(t.get_debug_state().execution_plans.values())[0]
            bwd_graph = list(
                fwd_plan.code.grad_executor_states()[0].execution_plans.values()
            )[0].graph
            tup = next(bwd_graph.outputs())
            self.assertEqual(len(list(tup.node().inputs())), 1)

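    # Subgraph-slicing topology tests. In the ASCII diagrams below, "o" marks an
    # autodiff-supported node, "x" marks an unsupported one, and arrows show data
    # dependencies between them.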
    def test_simple_merge(self):
        # o --> o
        def fn(x, y, z):
            a = x * y
            b = a * z
            return b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_simple_no_merge(self):
        # o: autodiff supported. x: not autodiff supported.
        # o --> x
        def fn(x, y, z):
            a = x * y
            b = torch.zeros([abs(int(y))])
            return a, b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
        g_str = str(graph)
        FileCheck().check("aten::Int").check("aten::zeros").check_not("aten::mul").run(
            g_str[0 : g_str.find("return")]
        )
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_does_not_merge_unrelated(self):
        # o  o
        def fn(w, x, y, z):
            a = x * y
            b = w * z
            return a, b

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        self.assertGraphSize(graph, 3)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

    def test_merges_without_cycles(self):
        # o --> o --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = a * y
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)

        self.assertGraphSize(graph, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_merges_dense(self):
        # o    o
        # |\  /|
        # | \/ |
        # | /\ |
        # vv  vv
        # o    o
        def fn(x, y):
            a, b = x.chunk(2)
            c, d = y.chunk(2)
            return a + c, b + d

        graph = self._perform_ad_subgraph_slicing(fn, 2, 2)

        self.assertGraphSize(graph, 2)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_does_not_create_cycles(self):
        # o --> x --> o
        # |           ^
        #  \_________/
        def fn(w, x, y):
            a = w * x
            b = torch.zeros(abs(int(a)))
            c = a * b
            return c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

    def test_merges_up(self):
        # o --> x     o
        # |           ^
        #  \_________/
        def fn(w, x, y, z):
            a = w * x
            b = torch.zeros(abs(int(y)))
            c = a * z
            return b, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
        g_str = str(graph)
        FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_merges_down(self):
        # o     x --> o
        # |           ^
        #  \_________/
        def fn(v, w, x, y):
            a = v * w
            b = torch.ones(int(y))
            c = b * a
            return a, c

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)

        # add moved down
        g_str = str(graph)
        FileCheck().check_not("aten::add").run(g_str[0 : g_str.find("return")])
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 1)

    def test_respects_lexical_scoping(self):
        def fn(x, k):
            y = x * 1.1
            if bool(k):
                k = k + y
            z = y * k
            return z, k

        graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
        # We should not have combined the two multiplications into
        # the same group; they should each be in a separate DiffGraph
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 3)

    def test_merge_respects_aliasing(self):
        def fn(x, k, cond):
            y = x * 1.1
            y = y * k
            y = y * 2.2
            if bool(cond):
                z1 = y[0]
                z2 = y[1]
                z1.add_(3)
                out = z2 + k + 3.3
                out = out * out
                return out

        graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], 1)
        # z2 did not get merged into the subgraph
        FileCheck().check("prim::If").check("aten::select").check_next(
            "aten::select"
        ).check_next("aten::add_").check("Differentiable").run(graph)
        self.assertGraphContainsExactly(graph, "prim::DifferentiableGraph", 2)

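    # The cases below feed hand-written IR through
    # torch._C._jit_pass_create_autodiff_subgraphs and check how aliasing between
    # output values constrains which nodes may be grouped into a DifferentiableGraph.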
    def test_aliased_outputs(self):
        with enable_profiling_mode_for_profiling_tests():
            # Case 1: the alias between relu and t is contained within a single
            # DifferentiableGraph, so it is valid to merge relu and t into one
            # differentiable graph
            input_str = """
            graph(%a : Tensor):
                %b : Tensor = aten::relu(%a)
                %2 : Tensor = aten::t(%b)
                return (%2)
            """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("with prim::DifferentiableGraph").check(
                "aten::relu"
            ).check("aten::t").run(graph)

            # Case 2: the outputs of relu and split_with_sizes alias each other
            # and both are outputs of the graph, so it is invalid to merge them
            # into one differentiable graph, i.e. relu and split_with_sizes
            # should end up in different differentiable graphs
            input_str = """
            graph(%a : Tensor):
                %b : Tensor = aten::relu(%a)
                %0 : int[] = prim::Constant[value=[2, 2, 1]]()
                %1 : int = prim::Constant[value=0]()
                %2 : Tensor[] = aten::split_with_sizes(%b, %0, %1)
                %3 : (Tensor[], Tensor[]) = prim::TupleConstruct(%b, %2)
                return (%3)
            """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
                graph
            )

            # Case 3: two aliased nodes in a graph.
            # Both `split_with_sizes` nodes should be unfused
            input_str = """
            graph(%a : Tensor):
                %b : Tensor = aten::relu(%a)
                %s1 : int[] = prim::Constant[value=[2, 2, 1]]()
                %s2 : int[] = prim::Constant[value=[3, 1]]()
                %1 : int = prim::Constant[value=0]()
                %2 : Tensor[] = aten::split_with_sizes(%b, %s1, %1)
                %3 : Tensor[] = aten::split_with_sizes(%b, %s2, %1)
                %4 : (Tensor, Tensor[]) = prim::TupleConstruct(%b, %2, %3)
                return (%4)
            """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::split_with_sizes").run(
                graph
            )

            # Case 4: the aliased output has a descendant
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
            graph(%a : Tensor):
                %b : Tensor = aten::relu(%a)
                %0 : int[] = prim::Constant[value=[2, 2, 1]]()
                %1 : int = prim::Constant[value=0]()
                %2 : Tensor = aten::t(%b)
                %3 : Tensor = aten::relu(%2)
                %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2)
                return (%4)
            """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::t").run(graph)

            # Case 5: multiple aliased groups
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
            graph(%a : Tensor):
                %b : Tensor = aten::relu(%a)
                %c : Tensor = aten::abs(%a)
                %0 : int[] = prim::Constant[value=[2, 2, 1]]()
                %1 : int = prim::Constant[value=0]()
                %d : Tensor = aten::t(%c)
                %2 : Tensor = aten::t(%b)
                %3 : Tensor = aten::relu(%2)
                %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
                return (%4)
            """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph").check(
                "with prim::DifferentiableGraph"
            ).check("Tensor = aten::relu").check_not("aten::t").run(graph)

    def test_has_profiled_info_aliasing_outputs(self):
        # The expectation is that CallFunction will prevent the final profile node from
        # getting merged into the DifferentiableGraph, and that create_autodiff_subgraphs
        # will instead add this to the type for %4.
        ir = """
        graph(%a : Tensor):
            %1 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%a)
            %2 : Tensor = aten::relu(%1)
            %3 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%2)
            %4 : Tensor = aten::relu(%3)
            %5 : Tensor = prim::CallFunction(%4)
            %6 : Tensor = prim::profile[profiled_type=Float(requires_grad=0)](%4)
            return (%6)
        """

        graph = torch._C.parse_ir(ir)
        torch._C._jit_pass_create_autodiff_subgraphs(graph)

        for n in graph.nodes():
            if n.kind() == "prim::DifferentiableGraph":
                diff_graph = n.g("Subgraph")

        outputs = list(diff_graph.outputs())
        self.assertEqual(1, len(outputs))
        output = outputs[0]
        self.assertEqual(False, output.requiresGrad())

        FileCheck().check("= prim::DifferentiableGraph").check(
            "with prim::DifferentiableGraph"
        ).check(" = aten::relu").check("requires_grad=0").check("aten::relu").run(graph)


if __name__ == "__main__":
    raise_on_run_directly("test/test_jit.py")