Enable numel tracing

clang-format

resolve onnx test failure

update expect file

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74081

Approved by: https://github.com/garymm, https://github.com/eellison, https://github.com/malfet

commit aa51ee2345
parent db20a3b014
Author: BowenBao
Date: 2022-04-07 11:06:33 -07:00
Committed by: PyTorch MergeBot

6 changed files with 161 additions and 19 deletions
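
In short: with this change, `Tensor.numel()` calls made while tracing are recorded in the graph as an `aten::numel` op instead of being baked in as Python integer constants captured from the example input. A minimal sketch of the user-visible difference (shapes chosen to mirror the new test below):

    import torch

    def fn(x):
        return x.numel()

    traced_fn = torch.jit.trace(fn, torch.randn(2, 3, 4))

    # The traced function now follows the actual input's shape; before this
    # change it would have returned the constant 24 recorded at trace time.
    print(traced_fn(torch.randn(4, 5, 6)))  # 120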


@@ -377,6 +377,17 @@ class TestTracer(JitTestCase):
     def test_trace_size_with_grad(self):
         self.do_trace_size(True)
 
+    def test_trace_numel(self):
+        def fn(x):
+            return x.numel()
+
+        x = torch.randn(2, 3, 4)
+        y = torch.randn(4, 5, 6)
+        traced_fn = torch.jit.trace(fn, x)
+        self.assertEqual(traced_fn(y), fn(y))
+        self.assertEqual(traced_fn(x), fn(x))
+
     def do_trace_arange(self, requires_grad):
         def arange(x):
             return torch.arange(x.shape[0])
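
The assertion order in the new test is deliberate: `traced_fn(y)` runs first with a shape different from the tracing input, so a numel value baked in as a constant would fail immediately. To see what was recorded, one can print the trace; the value names and annotations below are illustrative, not verbatim output:

    print(traced_fn.graph)
    # Expected to contain, roughly:
    #   %1 : int = aten::numel(%x)
    #   %2 : Tensor = prim::NumToTensor(%1)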


@@ -3,21 +3,113 @@ producer_name: "pytorch"
 producer_version: "CURRENT_VERSION"
 graph {
   node {
-    input: "onnx::Reshape_0"
-    input: "onnx::Reshape_11"
-    output: "8"
-    name: "Reshape_0"
+    input: "onnx::Shape_0"
+    output: "onnx::ReduceProd_2"
+    name: "Shape_0"
+    op_type: "Shape"
+  }
+  node {
+    input: "onnx::ReduceProd_2"
+    output: "onnx::Div_3"
+    name: "ReduceProd_1"
+    op_type: "ReduceProd"
+    attribute {
+      name: "keepdims"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    output: "onnx::Div_4"
+    name: "Constant_2"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\001\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "onnx::Div_3"
+    input: "onnx::Div_4"
+    output: "onnx::Cast_5"
+    name: "Div_3"
+    op_type: "Div"
+  }
+  node {
+    input: "onnx::Cast_5"
+    output: "onnx::Cast_6"
+    name: "Cast_4"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 7
+      type: INT
+    }
+  }
+  node {
+    input: "onnx::Cast_6"
+    output: "onnx::Unsqueeze_7"
+    name: "Cast_5"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 7
+      type: INT
+    }
+  }
+  node {
+    output: "onnx::Unsqueeze_10"
+    name: "Constant_6"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "onnx::Unsqueeze_7"
+    input: "onnx::Unsqueeze_10"
+    output: "onnx::Concat_11"
+    name: "Unsqueeze_7"
+    op_type: "Unsqueeze"
+  }
+  node {
+    input: "onnx::Concat_14"
+    input: "onnx::Concat_11"
+    output: "onnx::Reshape_12"
+    name: "Concat_8"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "onnx::Shape_0"
+    input: "onnx::Reshape_12"
+    output: "13"
+    name: "Reshape_9"
     op_type: "Reshape"
   }
   name: "torch_jit"
   initializer {
-    dims: 2
+    dims: 1
     data_type: 7
-    name: "onnx::Reshape_11"
-    raw_data: "\001\000\000\000\000\000\000\000\030\000\000\000\000\000\000\000"
+    name: "onnx::Concat_14"
+    raw_data: "\001\000\000\000\000\000\000\000"
   }
   input {
-    name: "onnx::Reshape_0"
+    name: "onnx::Shape_0"
     type {
       tensor_type {
         elem_type: 1
@@ -39,16 +131,16 @@ graph {
     }
   }
   output {
-    name: "8"
+    name: "13"
     type {
       tensor_type {
         elem_type: 1
         shape {
           dim {
-            dim_value: 1
+            dim_param: "Reshape13_dim_0"
           }
           dim {
-            dim_value: 24
+            dim_param: "Reshape13_dim_1"
           }
         }
       }
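
The regenerated expect file shows how a traced numel lowers to ONNX: a Shape -> ReduceProd chain computes the element count at run time, and it feeds the Reshape target via Concat with the small constant initializer `onnx::Concat_14` (the value [1]), replacing the old hard-coded shape initializer [1, 24]. The output dims accordingly become symbolic (`dim_param`) instead of fixed. Roughly, in NumPy terms (a sketch of the dataflow, not the exporter's API):

    import numpy as np

    x = np.random.randn(2, 3, 4).astype(np.float32)
    shape = np.array(x.shape, dtype=np.int64)  # Shape
    numel = shape.prod(keepdims=False)         # ReduceProd (keepdims = 0)
    numel = numel // 1                         # Div by constant 1, then Cast
    target = np.concatenate([[1], [numel]])    # Constant + Unsqueeze + Concat
    y = x.reshape(target)                      # Reshape -> (1, 24)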


@@ -6531,6 +6531,9 @@ class _TestONNXRuntime:
         offset = torch.tensor([0, 2, 5, 6])
         self.run_test(model, (input, offset))
 
+    @disableScriptTest()  # error in propagate as assign input shape
+    @skipIfUnsupportedMinOpsetVersion(11)
     def test_embedding_bag_with_offset(self):
         model = torch.nn.EmbeddingBag(10, 5, mode="max")
         input = torch.randint(10, (7, 5))
         self.run_test(model, (input))
@@ -6560,7 +6563,12 @@ class _TestONNXRuntime:
         model = EmbeddingModel()
         x = torch.randint(7, (2, 3))
         w = torch.randn(2, 3)
-        self.run_test(model, (embedding_matrix, x, w))
+
+        x2 = torch.randint(7, (4, 3))
+        w2 = torch.randn(4, 3)
+        self.run_test(model, (embedding_matrix, x, w),
+                      input_names=['embed', 'x', 'w'], dynamic_axes={'x': [0], 'w': [0]},
+                      test_with_inputs=[(embedding_matrix, x2, w2)])
 
     @disableScriptTest()  # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
     @skipIfUnsupportedMinOpsetVersion(11)
@@ -6650,24 +6658,28 @@ class _TestONNXRuntime:
         self.run_test(model, (x, batch1, batch2, alpha, beta))
 
     def test_numel(self):
-        class MyModule(torch.jit.ScriptModule):
-            @torch.jit.script_method
+        class MyModule(torch.nn.Module):
             def forward(self, input):
                 return input.numel() * input
 
         x = torch.randn(2, 3, 5)
+        x2 = torch.randn(4, 5, 6)
         model = MyModule()
-        self.run_test(model, (x,))
+        self.run_test(model, (x,),
+                      input_names=['x'], dynamic_axes={'x': [0, 1, 2]},
+                      test_with_inputs=[(x2,)])
 
     def test_numel_empty(self):
-        class MyModule(torch.jit.ScriptModule):
-            @torch.jit.script_method
+        class MyModule(torch.nn.Module):
             def forward(self, input):
                 return input.numel() * input
 
         x = torch.randn(0)
+        x2 = torch.randn(4)
         model = MyModule()
-        self.run_test(model, (x,))
+        self.run_test(model, (x,),
+                      input_names=['x'], dynamic_axes={'x': [0]},
+                      test_with_inputs=[(x2,)])
 
     def test_dtype(self):
         class MyModel(torch.jit.ScriptModule):
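
Both numel tests now go through tracing (plain `torch.nn.Module` instead of `ScriptModule`), mark every axis dynamic, and re-run the exported model on a second input of a different shape via `test_with_inputs`, so a constant-folded numel would be caught. A standalone export along the same lines might look like this (file name and opset version are illustrative):

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, input):
            return input.numel() * input

    torch.onnx.export(
        MyModule(), (torch.randn(2, 3, 5),), "numel.onnx",
        input_names=["x"],
        dynamic_axes={"x": [0, 1, 2]},  # keep every dimension symbolic
        opset_version=11,
    )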


@@ -231,7 +231,11 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
     return handle_torch_function(self, "numel", args);
   }
   auto& self_ = THPVariable_Unpack(self);
-  return THPUtils_packInt64(self_.numel());
+  if (jit::tracer::isTracing()) {
+    return wrap(jit::tracer::getNumelOf(self_));
+  } else {
+    return THPUtils_packInt64(self_.numel());
+  }
   END_HANDLE_TH_ERRORS
 }
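
The Python binding now branches on the tracer state: while tracing it returns the value produced by `jit::tracer::getNumelOf`, otherwise the plain packed int64 as before. The same condition is observable from Python through an internal API (shown purely for illustration; not a public interface):

    import torch

    def probe(x):
        # the same condition the C++ code checks via jit::tracer::isTracing()
        print("tracing:", torch._C._get_tracing_state() is not None)
        return x * 2

    probe(torch.randn(1))                   # tracing: False
    torch.jit.trace(probe, torch.randn(1))  # prints "tracing: True" while tracing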


@@ -917,6 +917,27 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
   return size_var;
 }
 
+autograd::Variable getNumelOf(const autograd::Variable& var) {
+  auto& tracing_state = getTracingState();
+  auto& graph = tracing_state->graph;
+
+  Variable numel_var;
+  {
+    // Make sure this scalar to tensor isn't traced!
+    at::AutoDispatchBelowADInplaceOrView guard;
+    numel_var = scalar_to_tensor(at::Scalar(var.numel()));
+  }
+
+  auto* value = getValueTrace(var);
+  auto* node = graph->insertNode(graph->create(Symbol::aten("numel"), {value}));
+  recordSourceLocation(node);
+  node->output()->setType(jit::IntType::get());
+  auto ten =
+      graph->insertNode(graph->createNumToTensor(node->output()))->output();
+  setValueTrace(numel_var, ten);
+  return numel_var;
+}
+
 void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
   auto& state = getTracingState();
   if (state && state->force_outplace == false) {
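
`getNumelOf` computes the concrete count eagerly (with dispatch lowered so the `scalar_to_tensor` call is itself not traced), inserts an `aten::numel` node typed as int, wraps it with `prim::NumToTensor`, and registers the result as the value trace of the returned variable. That lets the count flow into later traced arithmetic, for example (a small sketch):

    import torch

    def f(x):
        return x.numel() * x  # int * Tensor: numel participates in the trace

    traced = torch.jit.trace(f, torch.randn(2, 3))
    print(torch.allclose(traced(torch.ones(4)), 4 * torch.ones(4)))  # True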


@@ -390,6 +390,8 @@ TORCH_API autograd::Variable getSizeOf(
     const autograd::Variable& var,
     int64_t dim);
 
+TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
+
 } // namespace tracer
 } // namespace jit
 } // namespace torch