Add new aten::device variant to TorchScript (#97023)

Fixes #96627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97023
Approved by: https://github.com/jgong5, https://github.com/BowenBao, https://github.com/davidberard98
This commit is contained in:
Thiago Crepaldi
2023-04-06 14:19:00 +00:00
committed by PyTorch MergeBot
parent d1e7434bcf
commit bbf180af9f
5 changed files with 61 additions and 0 deletions

View File

@@ -1241,6 +1241,30 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
double_type_count += 1
self.assertNotEqual(double_type_count, 0)
# Regression test for aten::device with an explicit index ("cuda:0"):
# runs a real HF seq2seq model through torch.compile with the onnxrt
# backend, which exercises the new aten::device.with_index lowering.
# NOTE(review): needs CUDA (skipIfNoCuda) and the `transformers` package;
# downloads google/flan-t5-small on first run — confirm CI caches it.
@pytorch_test_common.skipIfNoCuda
def test_aten_device_with_index(self):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Small checkpoint keeps download/compile cost low.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model = torch.compile(model, backend="onnxrt")
model = model.eval()
# Device string WITH an index is the point of the test.
device = "cuda:0"
model = model.to(device)
ids = tokenizer.batch_encode_plus(["This is a test"], return_tensors="pt").to(
device
)
# Inference only; the test passes if export + execution do not raise.
with torch.no_grad():
_ = model(
**{
"input_ids": ids["input_ids"],
"attention_mask": ids["attention_mask"],
"decoder_input_ids": ids["input_ids"],
"decoder_attention_mask": ids["attention_mask"],
}
)
# Standard PyTorch test-file entry point.
if __name__ == "__main__":
common_utils.run_tests()

View File

@@ -92,6 +92,14 @@ void device(Stack& stack) {
push(stack, pop(stack).toTensor().device());
}
// Implements aten::device.with_index(str type, int index) -> Device.
// Pops a device-type string and an integer index off the interpreter
// stack and pushes the corresponding c10::Device (e.g. "cuda:0").
void device_with_index(Stack& stack) {
  // Arguments are pushed in schema order, so the TOP of the stack holds
  // the LAST argument. For (str type, int index) that is `index`: pop it
  // first, then the type string. Popping `type` first (as the original
  // did) would call toStringRef() on the int IValue.
  const int64_t index = pop(stack).toInt();
  const std::string type = pop(stack).toStringRef();
  // c10::Device parses the canonical "<type>:<index>" spelling.
  const std::string device_str = type + ":" + std::to_string(index);
  push(stack, c10::Device(device_str));
}
void dtype(Stack& stack) {
at::Tensor a;
pop(stack, a);

View File

@@ -33,6 +33,8 @@ void sym_stride(Stack& stack);
void device(Stack& stack);
void device_with_index(Stack& stack);
void dtype(Stack& stack);
void layout(Stack& stack);

View File

@@ -269,6 +269,28 @@ struct PeepholeOptimizeImpl {
node->output()->replaceAllUsesWith(output);
changed = true;
}
// Constant-fold aten::device(str type, int index) into a Device constant
// so downstream passes (e.g. ONNX export) see a literal device rather
// than a runtime op.
} else if (
node->matches("aten::device(str type, int index) -> Device") &&
shape_peepholes_) {
// expect<StringType>() throws on a type mismatch, so string_type is
// never null here — the `if` below is effectively always taken.
auto string_type = node->inputs().at(0)->type()->expect<StringType>();
if (string_type) {
WithInsertPoint guard(node);
// NOTE(review): reads attr::value directly from the node producing
// input 0, which assumes that node is a prim::Constant string —
// confirm a non-constant string input cannot reach this branch,
// otherwise s(attr::value) would fail.
std::string type_str = node->inputs().at(0)->node()->s(attr::value);
// The index may not be a compile-time constant; fall back to 0.
auto maybe_index = toIValue(node->inputs().at(1));
int64_t index = 0;
if (maybe_index) {
index = maybe_index->toInt();
}
// Build the folded device value and splice it in for the op's output.
auto device = c10::Device(type_str + ":" + std::to_string(index));
auto output = node->owningGraph()->insertConstant(device);
GRAPH_UPDATE(
"Replacing ",
getHeader(node),
" with a device constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
changed = true;
}
} else if (
node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
auto ptt = node->input()->type()->expect<TensorType>();

View File

@@ -2292,6 +2292,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
push(stack, c10::Device(pop(stack).toStringRef()));
},
aliasAnalysisFromSchema()),
// Registers the new indexed overload of aten::device; the actual work is
// done by the device_with_index() helper declared alongside device().
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::device.with_index(str type, int index) -> Device"),
device_with_index,
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
[](Stack& stack) {