mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Enable numel tracing
clang-format resolve onnx test failure update expect file Pull Request resolved: https://github.com/pytorch/pytorch/pull/74081 Approved by: https://github.com/garymm, https://github.com/eellison, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
db20a3b014
commit
aa51ee2345
@ -377,6 +377,17 @@ class TestTracer(JitTestCase):
|
||||
def test_trace_size_with_grad(self):
|
||||
self.do_trace_size(True)
|
||||
|
||||
def test_trace_numel(self):
|
||||
def fn(x):
|
||||
return x.numel()
|
||||
|
||||
x = torch.randn(2, 3, 4)
|
||||
y = torch.randn(4, 5, 6)
|
||||
|
||||
traced_fn = torch.jit.trace(fn, x)
|
||||
self.assertEqual(traced_fn(y), fn(y))
|
||||
self.assertEqual(traced_fn(x), fn(x))
|
||||
|
||||
def do_trace_arange(self, requires_grad):
|
||||
def arange(x):
|
||||
return torch.arange(x.shape[0])
|
||||
|
@ -3,21 +3,113 @@ producer_name: "pytorch"
|
||||
producer_version: "CURRENT_VERSION"
|
||||
graph {
|
||||
node {
|
||||
input: "onnx::Reshape_0"
|
||||
input: "onnx::Reshape_11"
|
||||
output: "8"
|
||||
name: "Reshape_0"
|
||||
input: "onnx::Shape_0"
|
||||
output: "onnx::ReduceProd_2"
|
||||
name: "Shape_0"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
input: "onnx::ReduceProd_2"
|
||||
output: "onnx::Div_3"
|
||||
name: "ReduceProd_1"
|
||||
op_type: "ReduceProd"
|
||||
attribute {
|
||||
name: "keepdims"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "onnx::Div_4"
|
||||
name: "Constant_2"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
data_type: 7
|
||||
raw_data: "\001\000\000\000\000\000\000\000"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Div_3"
|
||||
input: "onnx::Div_4"
|
||||
output: "onnx::Cast_5"
|
||||
name: "Div_3"
|
||||
op_type: "Div"
|
||||
}
|
||||
node {
|
||||
input: "onnx::Cast_5"
|
||||
output: "onnx::Cast_6"
|
||||
name: "Cast_4"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 7
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Cast_6"
|
||||
output: "onnx::Unsqueeze_7"
|
||||
name: "Cast_5"
|
||||
op_type: "Cast"
|
||||
attribute {
|
||||
name: "to"
|
||||
i: 7
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "onnx::Unsqueeze_10"
|
||||
name: "Constant_6"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
raw_data: "\000\000\000\000\000\000\000\000"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Unsqueeze_7"
|
||||
input: "onnx::Unsqueeze_10"
|
||||
output: "onnx::Concat_11"
|
||||
name: "Unsqueeze_7"
|
||||
op_type: "Unsqueeze"
|
||||
}
|
||||
node {
|
||||
input: "onnx::Concat_14"
|
||||
input: "onnx::Concat_11"
|
||||
output: "onnx::Reshape_12"
|
||||
name: "Concat_8"
|
||||
op_type: "Concat"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Shape_0"
|
||||
input: "onnx::Reshape_12"
|
||||
output: "13"
|
||||
name: "Reshape_9"
|
||||
op_type: "Reshape"
|
||||
}
|
||||
name: "torch_jit"
|
||||
initializer {
|
||||
dims: 2
|
||||
dims: 1
|
||||
data_type: 7
|
||||
name: "onnx::Reshape_11"
|
||||
raw_data: "\001\000\000\000\000\000\000\000\030\000\000\000\000\000\000\000"
|
||||
name: "onnx::Concat_14"
|
||||
raw_data: "\001\000\000\000\000\000\000\000"
|
||||
}
|
||||
input {
|
||||
name: "onnx::Reshape_0"
|
||||
name: "onnx::Shape_0"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
@ -39,16 +131,16 @@ graph {
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "8"
|
||||
name: "13"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
dim_param: "Reshape13_dim_0"
|
||||
}
|
||||
dim {
|
||||
dim_value: 24
|
||||
dim_param: "Reshape13_dim_1"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6531,6 +6531,9 @@ class _TestONNXRuntime:
|
||||
offset = torch.tensor([0, 2, 5, 6])
|
||||
self.run_test(model, (input, offset))
|
||||
|
||||
@disableScriptTest() # error in propagate as assign input shape
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
def test_embedding_bag_with_offset(self):
|
||||
model = torch.nn.EmbeddingBag(10, 5, mode="max")
|
||||
input = torch.randint(10, (7, 5))
|
||||
self.run_test(model, (input))
|
||||
@ -6560,7 +6563,12 @@ class _TestONNXRuntime:
|
||||
model = EmbeddingModel()
|
||||
x = torch.randint(7, (2, 3))
|
||||
w = torch.randn(2, 3)
|
||||
self.run_test(model, (embedding_matrix, x, w))
|
||||
|
||||
x2 = torch.randint(7, (4, 3))
|
||||
w2 = torch.randn(4, 3)
|
||||
self.run_test(model, (embedding_matrix, x, w),
|
||||
input_names=['embed', 'x', 'w'], dynamic_axes={'x': [0], 'w': [0]},
|
||||
test_with_inputs=[(embedding_matrix, x2, w2)])
|
||||
|
||||
@disableScriptTest() # scripting prim::Uninitialized, prim::dtype, prim::unchecked_cast
|
||||
@skipIfUnsupportedMinOpsetVersion(11)
|
||||
@ -6650,24 +6658,28 @@ class _TestONNXRuntime:
|
||||
self.run_test(model, (x, batch1, batch2, alpha, beta))
|
||||
|
||||
def test_numel(self):
|
||||
class MyModule(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.numel() * input
|
||||
|
||||
x = torch.randn(2, 3, 5)
|
||||
x2 = torch.randn(4, 5, 6)
|
||||
model = MyModule()
|
||||
self.run_test(model, (x,))
|
||||
self.run_test(model, (x,),
|
||||
input_names=['x'], dynamic_axes={'x': [0, 1, 2]},
|
||||
test_with_inputs=[(x2,)])
|
||||
|
||||
def test_numel_empty(self):
|
||||
class MyModule(torch.jit.ScriptModule):
|
||||
@torch.jit.script_method
|
||||
class MyModule(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
return input.numel() * input
|
||||
|
||||
x = torch.randn(0)
|
||||
x2 = torch.randn(4)
|
||||
model = MyModule()
|
||||
self.run_test(model, (x,))
|
||||
self.run_test(model, (x,),
|
||||
input_names=['x'], dynamic_axes={'x': [0]},
|
||||
test_with_inputs=[(x2,)])
|
||||
|
||||
def test_dtype(self):
|
||||
class MyModel(torch.jit.ScriptModule):
|
||||
|
@ -231,7 +231,11 @@ static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
|
||||
return handle_torch_function(self, "numel", args);
|
||||
}
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
if (jit::tracer::isTracing()) {
|
||||
return wrap(jit::tracer::getNumelOf(self_));
|
||||
} else {
|
||||
return THPUtils_packInt64(self_.numel());
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
@ -917,6 +917,27 @@ autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
|
||||
return size_var;
|
||||
}
|
||||
|
||||
autograd::Variable getNumelOf(const autograd::Variable& var) {
|
||||
auto& tracing_state = getTracingState();
|
||||
auto& graph = tracing_state->graph;
|
||||
|
||||
Variable numel_var;
|
||||
{
|
||||
// Make sure this scalar to tensor isn't traced!
|
||||
at::AutoDispatchBelowADInplaceOrView guard;
|
||||
numel_var = scalar_to_tensor(at::Scalar(var.numel()));
|
||||
}
|
||||
auto* value = getValueTrace(var);
|
||||
auto* node = graph->insertNode(graph->create(Symbol::aten("numel"), {value}));
|
||||
recordSourceLocation(node);
|
||||
node->output()->setType(jit::IntType::get());
|
||||
|
||||
auto ten =
|
||||
graph->insertNode(graph->createNumToTensor(node->output()))->output();
|
||||
setValueTrace(numel_var, ten);
|
||||
return numel_var;
|
||||
}
|
||||
|
||||
void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
|
||||
auto& state = getTracingState();
|
||||
if (state && state->force_outplace == false) {
|
||||
|
@ -390,6 +390,8 @@ TORCH_API autograd::Variable getSizeOf(
|
||||
const autograd::Variable& var,
|
||||
int64_t dim);
|
||||
|
||||
TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
|
||||
|
||||
} // namespace tracer
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user