mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
3ddec713b8 | |||
85eeb90d2c |
@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point)
|
||||
run(plan_desc);
|
||||
execution_plan_cache[key] = plan_desc;
|
||||
return quantized_output.view(orig_sizes);
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn");
|
||||
|
@ -252,7 +252,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
|
||||
run(plan);
|
||||
execution_plan_cache.emplace(key, plan);
|
||||
return;
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");
|
||||
|
@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp
|
||||
run(plan);
|
||||
execution_plan_cache.emplace(key, plan);
|
||||
return;
|
||||
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
|
||||
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn");
|
||||
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,pass,43
|
||||
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass, 52
|
||||
hf_BigBird,pass,49
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,pass,43
|
||||
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,49
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,fail_accuracy,46
|
||||
hf_BigBird,fail_accuracy,43
|
||||
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,49
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,46
|
||||
hf_BigBird,pass,43
|
||||
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,49
|
||||
|
||||
|
||||
|
||||
|
|
@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
|
||||
|
||||
|
||||
|
||||
BartForCausalLM,pass,12
|
||||
BartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BartForConditionalGeneration,pass,24
|
||||
BartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForCausalLM,pass,12
|
||||
BlenderbotSmallForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration,pass,24
|
||||
BlenderbotSmallForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
|
||||
|
||||
|
||||
|
||||
MBartForCausalLM,pass,12
|
||||
MBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
MBartForConditionalGeneration,pass,24
|
||||
MBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
|
||||
|
||||
|
||||
|
||||
OPTForCausalLM,pass,12
|
||||
OPTForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForCausalLM,pass,12
|
||||
PLBartForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PLBartForConditionalGeneration,pass,29
|
||||
PLBartForConditionalGeneration,pass,8
|
||||
|
||||
|
||||
|
||||
PegasusForCausalLM,pass,12
|
||||
PegasusForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
PegasusForConditionalGeneration,pass,23
|
||||
PegasusForConditionalGeneration,pass,7
|
||||
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
|
||||
|
||||
|
||||
|
||||
Speech2Text2ForCausalLM,pass,12
|
||||
Speech2Text2ForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
@ -170,11 +170,11 @@ T5Small,pass,5
|
||||
|
||||
|
||||
|
||||
TrOCRForCausalLM,pass,12
|
||||
TrOCRForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
XGLMForCausalLM,pass,12
|
||||
XGLMForCausalLM,pass,6
|
||||
|
||||
|
||||
|
||||
|
|
@ -150,7 +150,7 @@ hf_Bert_large,pass,0
|
||||
|
||||
|
||||
|
||||
hf_BigBird,fail_accuracy,46
|
||||
hf_BigBird,fail_accuracy,43
|
||||
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ hf_Bert_large,pass,6
|
||||
|
||||
|
||||
|
||||
hf_BigBird,pass,52
|
||||
hf_BigBird,pass,49
|
||||
|
||||
|
||||
|
||||
|
|
@ -101,6 +101,15 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
|
||||
|
||||
self._common(fn, 2)
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_getattr_missing(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
if getattr(obj, "asdf", None) is not None:
|
||||
obj.asdf += 1
|
||||
return obj.attentions + 1
|
||||
|
||||
self._common(fn, 1)
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_getitem(self):
|
||||
def fn(obj: BaseModelOutput):
|
||||
@ -166,6 +175,59 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 2)
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_init2(self):
|
||||
# this ModelOutput subclass runs a different __post_init__ codepath
|
||||
@dataclasses.dataclass
|
||||
class MyDataClass(ModelOutput):
|
||||
x: torch.FloatTensor = None
|
||||
|
||||
def fn(x):
|
||||
obj = MyDataClass(x=x)
|
||||
return obj
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
||||
self.assertEqual(fn(inp).x, opt_fn(inp).x)
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_init_with_disable(self):
|
||||
# Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
|
||||
# graph breaks (although it may not be the first)
|
||||
# Minimal repro for https://github.com/pytorch/pytorch/issues/126028
|
||||
@dataclasses.dataclass
|
||||
class MyDataClass(ModelOutput):
|
||||
x: torch.FloatTensor = None
|
||||
|
||||
@torch._dynamo.disable(recursive=False)
|
||||
def fn(x):
|
||||
return MyDataClass(x=x)
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
opt_fn = torch._dynamo.optimize("eager")(fn)
|
||||
self.assertEqual(fn(inp).x, opt_fn(inp).x)
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_newkey(self):
|
||||
obj = BaseModelOutput()
|
||||
|
||||
def fn(obj):
|
||||
return obj["wwww"] + 1
|
||||
|
||||
inp = torch.randn(3, 3)
|
||||
obj["wwww"] = inp
|
||||
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
||||
self.assertEqual(fn(obj), opt_fn(obj))
|
||||
|
||||
@maybe_skip
|
||||
def test_mo_from_outside(self):
|
||||
def fn(obj):
|
||||
return obj.attentions + 1
|
||||
|
||||
obj = BaseModelOutput(attentions=torch.randn(3, 3))
|
||||
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
||||
self.assertEqual(fn(obj), opt_fn(obj))
|
||||
|
||||
@maybe_skip
|
||||
def test_HF_bert_model_output(self):
|
||||
class BertPooler(torch.nn.Module):
|
||||
|
@ -4052,6 +4052,7 @@ class TestQuantizedLinear(TestCase):
|
||||
use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn
|
||||
@skipIfNoFBGEMM
|
||||
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
|
||||
@unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0")
|
||||
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
|
||||
@unittest.skipIf(TEST_ROCM, "not supported on rocm.")
|
||||
# TODO: check with yang regarding CUDNN flags
|
||||
|
@ -7,6 +7,7 @@ import torch.nn
|
||||
|
||||
from . import utils, variables
|
||||
from .bytecode_transformation import (
|
||||
bytecode_from_template,
|
||||
create_call_function,
|
||||
create_call_method,
|
||||
create_instruction,
|
||||
@ -59,6 +60,11 @@ class AttributeMutationNew(AttributeMutation):
|
||||
self.cls_source = cls_source
|
||||
|
||||
|
||||
def _manual_update_dict(dict_from, dict_to):
|
||||
for k, v in dict_from.items():
|
||||
dict_to[k] = v
|
||||
|
||||
|
||||
class SideEffects:
|
||||
"""
|
||||
Track side effects (list mutation, setattr, etc) that need to be
|
||||
@ -480,6 +486,39 @@ class SideEffects:
|
||||
]
|
||||
)
|
||||
suffixes.append([create_instruction("STORE_SUBSCR")])
|
||||
elif isinstance(var, variables.CustomizedDictVariable):
|
||||
# need to update the dict manually since update method may be invalid
|
||||
varname_map = {}
|
||||
for name in _manual_update_dict.__code__.co_varnames:
|
||||
varname_map[name] = cg.tx.output.new_var()
|
||||
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output(
|
||||
[create_instruction("STORE_FAST", argval=varname_map["dict_to"])]
|
||||
)
|
||||
|
||||
cg(var, allow_cache=False)
|
||||
cg.extend_output(
|
||||
[create_instruction("STORE_FAST", argval=varname_map["dict_from"])]
|
||||
)
|
||||
|
||||
cg(var.mutable_local.source) # type: ignore[attr-defined]
|
||||
cg.extend_output([create_load_method("clear")])
|
||||
|
||||
# unfortunately can't just use DICT_MERGE due to possible custom behaviors
|
||||
dict_update_insts = bytecode_from_template(
|
||||
_manual_update_dict, varname_map=varname_map
|
||||
)
|
||||
|
||||
suffixes.append(
|
||||
[
|
||||
*create_call_method(0), # clear
|
||||
create_instruction("POP_TOP"),
|
||||
*dict_update_insts,
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
|
||||
elif isinstance(var, variables.ConstDictVariable):
|
||||
cg.tx.output.update_co_names("clear")
|
||||
cg.tx.output.update_co_names("update")
|
||||
|
@ -23,7 +23,6 @@ from .ctx_manager import (
|
||||
from .dicts import (
|
||||
ConstDictVariable,
|
||||
CustomizedDictVariable,
|
||||
DataClassVariable,
|
||||
DefaultDictVariable,
|
||||
SetVariable,
|
||||
)
|
||||
@ -113,7 +112,6 @@ __all__ = [
|
||||
"CountIteratorVariable",
|
||||
"CustomizedDictVariable",
|
||||
"CycleIteratorVariable",
|
||||
"DataClassVariable",
|
||||
"DefaultDictVariable",
|
||||
"DeletedVariable",
|
||||
"DeterministicAlgorithmsVariable",
|
||||
|
@ -111,7 +111,7 @@ from .ctx_manager import (
|
||||
)
|
||||
from .dicts import (
|
||||
ConstDictVariable,
|
||||
DataClassVariable,
|
||||
CustomizedDictVariable,
|
||||
DefaultDictVariable,
|
||||
HFPretrainedConfigVariable,
|
||||
PythonSysModulesVariable,
|
||||
@ -493,6 +493,11 @@ class VariableBuilder:
|
||||
elif value is sys.modules:
|
||||
self.install_guards(GuardBuilder.FUNCTION_MATCH)
|
||||
return PythonSysModulesVariable(source=self.source)
|
||||
elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
result = CustomizedDictVariable.wrap(self, value)
|
||||
result.source = self.source
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
|
||||
if not value and self.get_source().is_nn_module():
|
||||
# It is faster to guard on 'false' property than to guard
|
||||
@ -711,9 +716,6 @@ class VariableBuilder:
|
||||
)
|
||||
elif np and isinstance(value, np.number):
|
||||
return self.wrap_unspecialized_primitive(value)
|
||||
elif DataClassVariable.is_matching_object(value):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return DataClassVariable.wrap(self, value)
|
||||
elif HFPretrainedConfigVariable.is_matching_object(value):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return HFPretrainedConfigVariable(value)
|
||||
@ -1701,7 +1703,7 @@ class VariableBuilder:
|
||||
def _dataclasses_fields_lambda(obj):
|
||||
if isinstance(obj, UserDefinedObjectVariable):
|
||||
value = obj.value
|
||||
elif isinstance(obj, DataClassVariable):
|
||||
elif isinstance(obj, CustomizedDictVariable):
|
||||
value = obj.user_cls
|
||||
else:
|
||||
unimplemented(f"Dataclass fields handling fails for type {obj}")
|
||||
|
@ -1633,7 +1633,6 @@ class BuiltinVariable(VariableTracker):
|
||||
if isinstance(
|
||||
obj,
|
||||
(
|
||||
variables.DataClassVariable,
|
||||
variables.CustomizedDictVariable,
|
||||
variables.PlacementVariable,
|
||||
variables.UserDefinedObjectVariable,
|
||||
|
@ -545,6 +545,8 @@ class DictValues(DictView):
|
||||
|
||||
def _is_matching_transformers_cls(cls) -> bool:
|
||||
mod = sys.modules.get("transformers.file_utils")
|
||||
if mod is None:
|
||||
mod = sys.modules.get("transformers.utils.generic")
|
||||
return mod is not None and issubclass(cls, mod.ModelOutput)
|
||||
|
||||
|
||||
@ -555,12 +557,20 @@ def _is_matching_diffusers_cls(cls) -> bool:
|
||||
|
||||
def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
|
||||
"""Shared method between DataClassVariable and CustomizedDictVariable where items are attrs"""
|
||||
if tx.output.side_effects.is_attribute_mutation(self):
|
||||
try:
|
||||
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
|
||||
return variables.ConstantVariable.create(
|
||||
not isinstance(result, variables.DeletedVariable)
|
||||
)
|
||||
except KeyError:
|
||||
pass
|
||||
if name in self.items or hasattr(self.user_cls, name):
|
||||
return ConstantVariable(True)
|
||||
elif istype(self.mutable_local, MutableLocal) and self.source is None:
|
||||
# Something created locally can't have any extra fields on it
|
||||
return ConstantVariable(False)
|
||||
elif self.mutable_local is None and self.source:
|
||||
elif self.source:
|
||||
# Maybe add a guard
|
||||
try:
|
||||
example = tx.output.root_tx.get_example_value(self.source)
|
||||
@ -577,152 +587,27 @@ def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
|
||||
|
||||
class DataClassVariable(ConstDictVariable):
|
||||
"""
|
||||
This is a bit of a hack to deal with
|
||||
transformers.file_utils.ModelOutput() from huggingface.
|
||||
This class doesn't appear to be used anywhere.
|
||||
It used to be used to deal with transformers.file_utils.ModelOutput
|
||||
from huggingface.
|
||||
|
||||
ModelOutput causes trouble because it a a mix of a dataclass and a
|
||||
OrderedDict and it calls super() methods implemented in C.
|
||||
Keeping since we wish to support dataclasses in general in the future
|
||||
"""
|
||||
|
||||
# ModelOutput() excludes None, though generic datclasses don't
|
||||
include_none = False
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def _patch_once():
|
||||
try:
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
for obj in ModelOutput.__dict__.values():
|
||||
if callable(obj):
|
||||
skip_code(obj.__code__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
for obj in BaseOutput.__dict__.values():
|
||||
if callable(obj):
|
||||
skip_code(obj.__code__)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def is_matching_cls(cls):
|
||||
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
|
||||
|
||||
@classmethod
|
||||
def is_matching_object(cls, obj):
|
||||
return cls.is_matching_cls(type(obj))
|
||||
|
||||
@classmethod
|
||||
def create(cls, user_cls, args, kwargs, options):
|
||||
DataClassVariable._patch_once()
|
||||
|
||||
skip_code(user_cls.__init__.__code__)
|
||||
keys = [f.name for f in dataclasses.fields(user_cls)]
|
||||
bound = inspect.signature(user_cls).bind(*args, **kwargs)
|
||||
bound.apply_defaults()
|
||||
assert set(bound.arguments.keys()) == set(keys)
|
||||
items = {}
|
||||
for key in keys:
|
||||
val = bound.arguments[key]
|
||||
key = ConstantVariable.create(key)
|
||||
if isinstance(val, VariableTracker):
|
||||
items[key] = val
|
||||
else:
|
||||
if cls.include_none:
|
||||
assert variables.ConstantVariable.is_literal(val)
|
||||
items[key] = variables.ConstantVariable.create(val)
|
||||
else:
|
||||
assert val is None, f"unexpected {val}"
|
||||
|
||||
if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
|
||||
unimplemented("DataClassVariable iterator constructor")
|
||||
# TODO(jansel): implement unpacking logic in ModelOutput.__post_init__
|
||||
|
||||
return cls(items, user_cls, **options)
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, builder, obj):
|
||||
user_cls = type(obj)
|
||||
keys = [f.name for f in dataclasses.fields(user_cls)]
|
||||
|
||||
excluded = []
|
||||
items = {}
|
||||
for key in keys:
|
||||
# __init__ function of a dataclass might not have yet defined the key
|
||||
if hasattr(obj, key):
|
||||
val = getattr(obj, key)
|
||||
var = builder.__class__(
|
||||
tx=builder.tx, source=AttrSource(builder.source, key)
|
||||
)(val)
|
||||
if val is not None or cls.include_none:
|
||||
key = ConstantVariable.create(key)
|
||||
items[key] = var
|
||||
else:
|
||||
excluded.append(var)
|
||||
return cls(items, user_cls)
|
||||
|
||||
def __init__(self, items, user_cls, **options):
|
||||
super().__init__(items, user_cls, **options)
|
||||
assert self.is_matching_cls(user_cls)
|
||||
|
||||
def as_proxy(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
codegen.extend_output([codegen._create_load_const(self.user_cls)])
|
||||
# All the keys are just wrapped strings
|
||||
d = self.keys_as_python_constant()
|
||||
codegen.foreach(d.values())
|
||||
keys = tuple(d.keys())
|
||||
codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "List[VariableTracker]",
|
||||
kwargs: "Dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
if name == "__getitem__":
|
||||
assert not kwargs and len(args) == 1
|
||||
val = args[0]
|
||||
if val.python_type() == str:
|
||||
return self.getitem_const(val)
|
||||
else:
|
||||
return self.call_method(tx, "to_tuple", [], {}).call_method(
|
||||
tx, "__getitem__", args, kwargs
|
||||
)
|
||||
elif name == "to_tuple":
|
||||
assert not (args or kwargs)
|
||||
return variables.TupleVariable(list(self.items.values()))
|
||||
elif name == "__setattr__":
|
||||
name = "__setitem__"
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||||
name_vt = ConstantVariable.create(name)
|
||||
if name_vt in self:
|
||||
return self.call_method(tx, "__getitem__", [name_vt], {})
|
||||
elif not self.include_none:
|
||||
defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
|
||||
if name in defaults:
|
||||
assert variables.ConstantVariable.is_literal(defaults[name])
|
||||
return variables.ConstantVariable.create(defaults[name])
|
||||
super().var_getattr(tx, name)
|
||||
|
||||
call_hasattr = _call_hasattr_customobj
|
||||
pass
|
||||
|
||||
|
||||
class CustomizedDictVariable(ConstDictVariable):
|
||||
@staticmethod
|
||||
def is_matching_cls_hf(cls):
|
||||
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
|
||||
|
||||
@staticmethod
|
||||
def is_matching_cls(cls):
|
||||
# True if using default OrderedDict.__init__ and did not implement __post_init__
|
||||
if (
|
||||
issubclass(cls, collections.OrderedDict)
|
||||
and cls is not collections.OrderedDict
|
||||
and cls.__init__ is collections.OrderedDict.__init__
|
||||
and not hasattr(cls, "__post_init__")
|
||||
):
|
||||
@ -730,7 +615,7 @@ class CustomizedDictVariable(ConstDictVariable):
|
||||
# hack for HF usecase:
|
||||
# assume dataclass annotation for ModelOutput subclass
|
||||
# assume self.create is AA to ModelOutput.__post_init__
|
||||
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
|
||||
return CustomizedDictVariable.is_matching_cls_hf(cls)
|
||||
|
||||
@classmethod
|
||||
def is_matching_object(cls, obj):
|
||||
@ -764,9 +649,7 @@ class CustomizedDictVariable(ConstDictVariable):
|
||||
)
|
||||
|
||||
bound_args = {}
|
||||
if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls(
|
||||
user_cls
|
||||
):
|
||||
if cls.is_matching_cls_hf(user_cls):
|
||||
# Skip none
|
||||
for k, v in bound.arguments.items():
|
||||
if isinstance(v, ConstantVariable) and v.value is None or v is None:
|
||||
@ -792,7 +675,27 @@ class CustomizedDictVariable(ConstDictVariable):
|
||||
# called from builder.py
|
||||
@classmethod
|
||||
def wrap(cls, builder, obj):
|
||||
raise NotImplementedError
|
||||
user_cls = type(obj)
|
||||
|
||||
if not cls.is_matching_cls_hf(user_cls):
|
||||
unimplemented("custom non-hf dict subclass wrap unimplemented")
|
||||
|
||||
items = builder.__class__(tx=builder.tx, source=builder.source)(
|
||||
collections.OrderedDict(obj)
|
||||
).items
|
||||
|
||||
keys = [f.name for f in dataclasses.fields(user_cls)]
|
||||
for key in keys:
|
||||
# __init__ function of a dataclass might not have yet defined the key
|
||||
if hasattr(obj, key):
|
||||
val = getattr(obj, key)
|
||||
var = builder.__class__(
|
||||
tx=builder.tx, source=AttrSource(builder.source, key)
|
||||
)(val)
|
||||
if val is not None:
|
||||
key = ConstantVariable.create(key)
|
||||
items[key] = var
|
||||
return cls(items, user_cls)
|
||||
|
||||
def __init__(self, items, user_cls, **options):
|
||||
super().__init__(items, user_cls, **options)
|
||||
@ -804,9 +707,7 @@ class CustomizedDictVariable(ConstDictVariable):
|
||||
# 'RETURN_VALUE triggered compile'
|
||||
# called from torch/_dynamo/codegen.py
|
||||
def reconstruct(self, codegen):
|
||||
is_hf_model_output = _is_matching_transformers_cls(
|
||||
self.user_cls
|
||||
) or _is_matching_diffusers_cls(self.user_cls)
|
||||
is_hf_model_output = self.is_matching_cls_hf(self.user_cls)
|
||||
|
||||
# If the user class is a ModelOutput, then wrap the instance creation in
|
||||
# torch._dynamo.disable(). Even though we mark the __post_init__ as skip
|
||||
@ -848,21 +749,34 @@ class CustomizedDictVariable(ConstDictVariable):
|
||||
):
|
||||
# for python dict method without overridden
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
|
||||
elif name in (
|
||||
"__getitem__",
|
||||
"to_tuple",
|
||||
"__setitem__",
|
||||
"__setattr__",
|
||||
"__post_init__",
|
||||
):
|
||||
# for user overridden method
|
||||
return tx.inline_user_function_return(
|
||||
variables.UserFunctionVariable(fn, source=source),
|
||||
[self] + list(args),
|
||||
kwargs,
|
||||
)
|
||||
elif fn is getattr(collections.OrderedDict, name, None):
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
unimplemented("custom dict: call_method unimplemented name=%s", name)
|
||||
unimplemented(f"custom dict: call_method unimplemented name={name}")
|
||||
|
||||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||||
name_vt = ConstantVariable.create(name)
|
||||
if name_vt in self:
|
||||
return self.call_method(tx, "__getitem__", [name_vt], {})
|
||||
super().var_getattr(tx, name)
|
||||
if dataclasses.is_dataclass(self.user_cls):
|
||||
defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
|
||||
if name in defaults:
|
||||
assert variables.ConstantVariable.is_literal(defaults[name])
|
||||
return variables.ConstantVariable.create(defaults[name])
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
call_hasattr = _call_hasattr_customobj
|
||||
|
||||
|
@ -171,6 +171,12 @@ class SuperVariable(VariableTracker):
|
||||
return super(variables.CustomizedDictVariable, self.objvar).call_method(
|
||||
tx, "__setitem__", args, kwargs
|
||||
)
|
||||
elif inner_fn is collections.OrderedDict.__getitem__ and isinstance(
|
||||
self.objvar, variables.CustomizedDictVariable
|
||||
):
|
||||
return super(variables.CustomizedDictVariable, self.objvar).call_method(
|
||||
tx, "__getitem__", args, kwargs
|
||||
)
|
||||
elif is_standard_setattr(inner_fn) and isinstance(
|
||||
self.objvar, UserDefinedObjectVariable
|
||||
):
|
||||
|
@ -396,9 +396,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return variables.CustomizedDictVariable.create(
|
||||
self.value, args, kwargs, options
|
||||
)
|
||||
elif variables.DataClassVariable.is_matching_cls(self.value):
|
||||
options = {"mutable_local": MutableLocal()}
|
||||
return variables.DataClassVariable.create(self.value, args, kwargs, options)
|
||||
elif (
|
||||
variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
|
||||
and self.source
|
||||
|
@ -3247,13 +3247,13 @@ void install_tensor_aliasing_guard(
|
||||
|
||||
void install_no_tensor_aliasing_guard(
|
||||
const py::list& guard_managers,
|
||||
py::list tensor_names,
|
||||
const py::list& tensor_names,
|
||||
py::object verbose_code_parts) {
|
||||
// Adds a guard that checks none of tensors alias. This is a an example of
|
||||
// relational guard. There is one guard object that is shared between multiple
|
||||
// guard managers.
|
||||
std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>(
|
||||
std::move(tensor_names), std::move(verbose_code_parts));
|
||||
tensor_names, std::move(verbose_code_parts));
|
||||
|
||||
// Register the resetter on the toor guard mananger, so that it can reset
|
||||
// the newly added relational guard when the guard eval fails.
|
||||
@ -4006,7 +4006,15 @@ PyObject* torch_c_dynamo_guards_init() {
|
||||
DictSubclassGuardManager,
|
||||
DictGuardManager,
|
||||
std::unique_ptr<DictSubclassGuardManager>>(
|
||||
py_m, "DictSubclassGuardManager"); // NOLINT
|
||||
py_m, "DictSubclassGuardManager") // NOLINT
|
||||
.def(
|
||||
"add_no_hasattr_guard",
|
||||
[](DictSubclassGuardManager& self,
|
||||
py::object attr_name,
|
||||
py::object verbose_code_parts) -> void {
|
||||
self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>(
|
||||
std::move(attr_name), std::move(verbose_code_parts)));
|
||||
});
|
||||
|
||||
py_m.def("install_tensor_aliasing_guard", install_tensor_aliasing_guard);
|
||||
py_m.def(
|
||||
|
Reference in New Issue
Block a user