Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add pad and narrow to torch/csrc/stable/ops.h (#159328)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159328 Approved by: https://github.com/janeyx99 ghstack dependencies: #159507
This commit (4d419a7461) was committed by PyTorch MergeBot; parent: 655137b678.
@@ -291,10 +291,43 @@ void boxed_fill_infinity(
  stack[0] = from(res);
}

Tensor my_pad(Tensor t) {
  std::vector<int64_t> padding = {1, 2, 2, 1};
  std::string mode = "constant";
  double value = 0.0;
  return pad(t, padding, mode, value);
}

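For context, the pad call above follows the semantics of torch.nn.functional.pad: the padding list is consumed from the last dimension backward, so {1, 2, 2, 1} pads the last dimension by 1 (left) and 2 (right), and the second-to-last dimension by 2 (top) and 1 (bottom). A minimal sketch of the equivalent public Python API (my example, not part of this diff):

import torch
import torch.nn.functional as F

t = torch.rand(2, 3)
out = F.pad(t, [1, 2, 2, 1], mode="constant", value=0.0)
# dim 1 grows by 1 + 2, dim 0 grows by 2 + 1: (2, 3) -> (5, 6)
assert out.shape == (5, 6)
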
void boxed_my_pad(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  // Boxed wrapper: unpack the tensor argument from the StableIValue
  // stack, call my_pad, and write the result back to stack[0].
  auto res = my_pad(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
  return narrow(t, dim, start, length);
}

void boxed_my_narrow(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  auto res = my_narrow(
      to<Tensor>(stack[0]),
      to<int64_t>(stack[1]),
      to<int64_t>(stack[2]),
      to<int64_t>(stack[3]));
  stack[0] = from(res);
}

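narrow returns a view of the input covering start to start + length along dim, without copying. A minimal sketch of the behavior via the public Python API (my example, not part of this diff):

import torch

t = torch.randn(2, 5)
v = torch.narrow(t, 0, 0, 1)  # rows [0, 1) -> shape (1, 5)
assert v.shape == (1, 5)
# narrow is a view: writes through v are visible in t
v.fill_(0)
assert (t[0] == 0).all()
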
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
  m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
  m.def("my_empty_like(Tensor t) -> Tensor");
  m.def("fill_infinity(Tensor(a!) t) -> Tensor(a!)");
  m.def("my_pad(Tensor t) -> Tensor");
  m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
}

@@ -303,6 +336,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
  m.impl("fill_infinity", &boxed_fill_infinity);
}

STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
  m.impl("my_pad", &boxed_my_pad);
  m.impl("my_narrow", &boxed_my_narrow);
}

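Once this extension is loaded into a Python process, the registered ops become callable under torch.ops.libtorch_agnostic, and their schemas can be inspected through the private _schema attribute. A minimal sketch, assuming a built extension at a hypothetical path:

import torch

# Hypothetical path to the built extension; adjust for your build.
torch.ops.load_library("build/libtorch_agnostic.so")

t = torch.rand(2, 3)
out = torch.ops.libtorch_agnostic.my_pad.default(t)
print(torch.ops.libtorch_agnostic.my_narrow.default._schema)
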
Tensor my_zero_(Tensor t) {
  return zero_(t);

@@ -176,3 +176,30 @@ def test_default_constructor(defined) -> bool:
    Returns: bool - result of calling .defined() on the tensor
    """
    return torch.ops.libtorch_agnostic.test_default_constructor.default(defined)

def my_pad(t) -> Tensor:
    """
    Pads the input tensor with hardcoded padding parameters.

    Args:
        t: Input tensor

    Returns: Padded tensor with padding [1, 2, 2, 1], mode "constant", value 0.0
    """
    return torch.ops.libtorch_agnostic.my_pad.default(t)

def my_narrow(t, dim, start, length) -> Tensor:
    """
    Returns a new tensor that is a narrowed version of the input tensor.

    Args:
        t: Input tensor
        dim: Dimension along which to narrow
        start: Starting position
        length: Length of the narrowed section

    Returns: Narrowed tensor
    """
    return torch.ops.libtorch_agnostic.my_narrow.default(t, dim, start, length)

@@ -232,6 +232,26 @@ if not IS_WINDOWS:
            )
            self.assertFalse(undefined_tensor_is_defined)

        def test_my_pad(self, device):
            import libtorch_agnostic

            t = torch.rand(2, 3, device=device)
            out = libtorch_agnostic.ops.my_pad(t)
            expected = torch.nn.functional.pad(t, [1, 2, 2, 1], "constant", 0.0)
            self.assertEqual(out, expected)

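To visualize the constant border this test exercises, here is the same hardcoded padding applied to a small all-ones input (my example, not part of this diff):

import torch
import torch.nn.functional as F

t = torch.ones(2, 3)
print(F.pad(t, [1, 2, 2, 1], "constant", 0.0))
# tensor([[0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0.],
#         [0., 1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0., 0.],
#         [0., 0., 0., 0., 0., 0.]])
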
        def test_my_narrow(self, device):
            import libtorch_agnostic

            t = torch.randn(2, 5, device=device)

            dim0 = 0
            start0 = 0
            length0 = 1
            out0 = libtorch_agnostic.ops.my_narrow(t, dim0, start0, length0)
            expected0 = torch.narrow(t, dim0, start0, length0)
            self.assertEqual(out0, expected0)

    instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)

if __name__ == "__main__":