Add new aten::device variant to TorchScript (#97023)

Fixes #96627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97023
Approved by: https://github.com/jgong5, https://github.com/BowenBao, https://github.com/davidberard98

Committed by: PyTorch MergeBot
Parent: d1e7434bcf
Commit: bbf180af9f
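The user-visible effect, sketched outside the diff (illustrative only, not part of the commit; it assumes the scripted two-argument call resolves to the overload added below, and pick_device is a made-up name):

    import torch

    @torch.jit.script
    def pick_device() -> torch.device:
        # The one-argument string form torch.device("cuda:0") already
        # scripted fine; the two-argument (type, index) form is what
        # this commit adds support for.
        return torch.device("cuda", 0)

    print(pick_device.graph)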
@@ -1241,6 +1241,30 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
                 double_type_count += 1
         self.assertNotEqual(double_type_count, 0)
 
+    @pytorch_test_common.skipIfNoCuda
+    def test_aten_device_with_index(self):
+        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+        model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+        model = torch.compile(model, backend="onnxrt")
+        model = model.eval()
+        device = "cuda:0"
+        model = model.to(device)
+        ids = tokenizer.batch_encode_plus(["This is a test"], return_tensors="pt").to(
+            device
+        )
+
+        with torch.no_grad():
+            _ = model(
+                **{
+                    "input_ids": ids["input_ids"],
+                    "attention_mask": ids["attention_mask"],
+                    "decoder_input_ids": ids["input_ids"],
+                    "decoder_attention_mask": ids["attention_mask"],
+                }
+            )
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
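The new test drives the change end to end through torch.compile with the onnxrt backend, so it needs CUDA, the transformers package, and ONNX Runtime. A lighter sketch that reaches the same overload without those dependencies (move and x are made-up names; it assumes scripting accepts a runtime index, which is what the runtime op added below exists to support):

    import torch

    @torch.jit.script
    def move(x: torch.Tensor, index: int) -> torch.Tensor:
        # A runtime index cannot be folded to a constant by the peephole
        # pass, so the interpreter op (device_with_index) must handle it.
        return x.to(torch.device("cuda", index))

    print(move.graph)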
@@ -92,6 +92,14 @@ void device(Stack& stack) {
   push(stack, pop(stack).toTensor().device());
 }
 
+void device_with_index(Stack& stack) {
+  std::string type = pop(stack).toStringRef();
+  int index = pop(stack).toInt();
+  std::string device_str = type + ":" + std::to_string(index);
+  auto device = c10::Device(device_str);
+  push(stack, device);
+}
+
 void dtype(Stack& stack) {
   at::Tensor a;
   pop(stack, a);
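device_with_index pops the type string and the index off the interpreter stack, joins them as "type:index", and pushes the parsed c10::Device. The string form and the two-argument form name the same device, as the eager API already demonstrates:

    import torch

    # "type:index" strings and (type, index) pairs are equivalent, which is
    # exactly the concatenation device_with_index performs.
    assert torch.device("cuda", 1) == torch.device("cuda:1")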
@@ -33,6 +33,8 @@ void sym_stride(Stack& stack);
 
 void device(Stack& stack);
 
+void device_with_index(Stack& stack);
+
 void dtype(Stack& stack);
 
 void layout(Stack& stack);
@@ -269,6 +269,28 @@ struct PeepholeOptimizeImpl {
         node->output()->replaceAllUsesWith(output);
         changed = true;
       }
+    } else if (
+        node->matches("aten::device(str type, int index) -> Device") &&
+        shape_peepholes_) {
+      auto string_type = node->inputs().at(0)->type()->expect<StringType>();
+      if (string_type) {
+        WithInsertPoint guard(node);
+        std::string type_str = node->inputs().at(0)->node()->s(attr::value);
+        auto maybe_index = toIValue(node->inputs().at(1));
+        int64_t index = 0;
+        if (maybe_index) {
+          index = maybe_index->toInt();
+        }
+        auto device = c10::Device(type_str + ":" + std::to_string(index));
+        auto output = node->owningGraph()->insertConstant(device);
+        GRAPH_UPDATE(
+            "Replacing ",
+            getHeader(node),
+            " with a device constant ",
+            output->debugName());
+        node->output()->replaceAllUsesWith(output);
+        changed = true;
+      }
     } else if (
         node->matches("aten::dim(Tensor self) -> int") && shape_peepholes_) {
       auto ptt = node->input()->type()->expect<TensorType>();
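This peephole branch constant-folds the two-argument aten::device call: the type string is read off its producing constant node, and a missing index constant falls back to 0 when folding. A sketch of the fold driven manually (torch._C._jit_pass_peephole is the Python binding for this pass; the graph inspection is illustrative, and pinned is a made-up name):

    import torch

    @torch.jit.script
    def pinned(x: torch.Tensor) -> torch.Tensor:
        # Constant type string and constant index: foldable by the pass.
        return x.to(torch.device("cuda", 0))

    graph = pinned.graph.copy()
    torch._C._jit_pass_peephole(graph)
    print(graph)  # expect a Device constant in place of the aten::device call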
@@ -2292,6 +2292,11 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
           push(stack, c10::Device(pop(stack).toStringRef()));
         },
         aliasAnalysisFromSchema()),
+    OperatorGeneratorArgs(
+        TORCH_SELECTIVE_SCHEMA(
+            "aten::device.with_index(str type, int index) -> Device"),
+        device_with_index,
+        aliasAnalysisFromSchema()),
     OperatorGeneratorArgs(
         TORCH_SELECTIVE_SCHEMA("aten::percentFormat(str self, ...) -> str"),
         [](Stack& stack) {
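With the operator registered, the overload should be visible through the usual schema lookup (a quick sanity check; assumes a build that contains this patch):

    import torch

    # List every registered overload of aten::device; the output should
    # include "aten::device.with_index(str type, int index) -> Device".
    for schema in torch._C._jit_get_schemas_for_operator("aten::device"):
        print(schema)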