mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 08:11:06 +08:00 
			
		
		
		
	Compare commits
	
		
			2 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 3ddec713b8 | |||
| 85eeb90d2c | 
| @ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point) | ||||
|       run(plan_desc); | ||||
|       execution_plan_cache[key] = plan_desc; | ||||
|       return quantized_output.view(orig_sizes); | ||||
|     } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} | ||||
|     } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn"); | ||||
|  | ||||
| @ -252,7 +252,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua | ||||
|       run(plan); | ||||
|       execution_plan_cache.emplace(key, plan); | ||||
|       return; | ||||
|     } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} | ||||
|     } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn"); | ||||
|  | ||||
| @ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp | ||||
|       run(plan); | ||||
|       execution_plan_cache.emplace(key, plan); | ||||
|       return; | ||||
|     } catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {} | ||||
|     } catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;} | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn"); | ||||
|  | ||||
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,46 | ||||
| hf_BigBird,pass,43 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass, 52 | ||||
| hf_BigBird,pass,49 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,46 | ||||
| hf_BigBird,pass,43 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,52 | ||||
| hf_BigBird,pass,49 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,fail_accuracy,46 | ||||
| hf_BigBird,fail_accuracy,43 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,52 | ||||
| hf_BigBird,pass,49 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,46 | ||||
| hf_BigBird,pass,43 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,52 | ||||
| hf_BigBird,pass,49 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForCausalLM,pass,12 | ||||
| BartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BartForConditionalGeneration,pass,24 | ||||
| BartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForCausalLM,pass,12 | ||||
| BlenderbotSmallForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| BlenderbotSmallForConditionalGeneration,pass,24 | ||||
| BlenderbotSmallForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForCausalLM,pass,12 | ||||
| MBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| MBartForConditionalGeneration,pass,24 | ||||
| MBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
| OPTForCausalLM,pass,12 | ||||
| OPTForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForCausalLM,pass,12 | ||||
| PLBartForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PLBartForConditionalGeneration,pass,29 | ||||
| PLBartForConditionalGeneration,pass,8 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForCausalLM,pass,12 | ||||
| PegasusForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| PegasusForConditionalGeneration,pass,23 | ||||
| PegasusForConditionalGeneration,pass,7 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| Speech2Text2ForCausalLM,pass,12 | ||||
| Speech2Text2ForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -170,11 +170,11 @@ T5Small,pass,5 | ||||
|  | ||||
|  | ||||
|  | ||||
| TrOCRForCausalLM,pass,12 | ||||
| TrOCRForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| XGLMForCausalLM,pass,12 | ||||
| XGLMForCausalLM,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -150,7 +150,7 @@ hf_Bert_large,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,fail_accuracy,46 | ||||
| hf_BigBird,fail_accuracy,43 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -98,7 +98,7 @@ hf_Bert_large,pass,6 | ||||
|  | ||||
|  | ||||
|  | ||||
| hf_BigBird,pass,52 | ||||
| hf_BigBird,pass,49 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| 
 | 
| @ -101,6 +101,15 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): | ||||
|  | ||||
|         self._common(fn, 2) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_getattr_missing(self): | ||||
|         def fn(obj: BaseModelOutput): | ||||
|             if getattr(obj, "asdf", None) is not None: | ||||
|                 obj.asdf += 1 | ||||
|             return obj.attentions + 1 | ||||
|  | ||||
|         self._common(fn, 1) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_getitem(self): | ||||
|         def fn(obj: BaseModelOutput): | ||||
| @ -166,6 +175,59 @@ class TestModelOutput(torch._dynamo.test_case.TestCase): | ||||
|         self.assertEqual(cnts.frame_count, 1) | ||||
|         self.assertEqual(cnts.op_count, 2) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_init2(self): | ||||
|         # this ModelOutput subclass runs a different __post_init__ codepath | ||||
|         @dataclasses.dataclass | ||||
|         class MyDataClass(ModelOutput): | ||||
|             x: torch.FloatTensor = None | ||||
|  | ||||
|         def fn(x): | ||||
|             obj = MyDataClass(x=x) | ||||
|             return obj | ||||
|  | ||||
|         inp = torch.randn(3, 3) | ||||
|         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||
|         self.assertEqual(fn(inp).x, opt_fn(inp).x) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_init_with_disable(self): | ||||
|         # Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>" | ||||
|         # graph breaks (although it may not be the first) | ||||
|         # Minimal repro for https://github.com/pytorch/pytorch/issues/126028 | ||||
|         @dataclasses.dataclass | ||||
|         class MyDataClass(ModelOutput): | ||||
|             x: torch.FloatTensor = None | ||||
|  | ||||
|         @torch._dynamo.disable(recursive=False) | ||||
|         def fn(x): | ||||
|             return MyDataClass(x=x) | ||||
|  | ||||
|         inp = torch.randn(3, 3) | ||||
|         opt_fn = torch._dynamo.optimize("eager")(fn) | ||||
|         self.assertEqual(fn(inp).x, opt_fn(inp).x) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_newkey(self): | ||||
|         obj = BaseModelOutput() | ||||
|  | ||||
|         def fn(obj): | ||||
|             return obj["wwww"] + 1 | ||||
|  | ||||
|         inp = torch.randn(3, 3) | ||||
|         obj["wwww"] = inp | ||||
|         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||
|         self.assertEqual(fn(obj), opt_fn(obj)) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_mo_from_outside(self): | ||||
|         def fn(obj): | ||||
|             return obj.attentions + 1 | ||||
|  | ||||
|         obj = BaseModelOutput(attentions=torch.randn(3, 3)) | ||||
|         opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) | ||||
|         self.assertEqual(fn(obj), opt_fn(obj)) | ||||
|  | ||||
|     @maybe_skip | ||||
|     def test_HF_bert_model_output(self): | ||||
|         class BertPooler(torch.nn.Module): | ||||
|  | ||||
| @ -4052,6 +4052,7 @@ class TestQuantizedLinear(TestCase): | ||||
|            use_channelwise=st.sampled_from([False]))  # channelwise currently not supported for qlinear cudnn | ||||
|     @skipIfNoFBGEMM | ||||
|     @unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.") | ||||
|     @unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0") | ||||
|     @unittest.skipIf(not SM80OrLater, "requires sm80 or later.") | ||||
|     @unittest.skipIf(TEST_ROCM, "not supported on rocm.") | ||||
|     # TODO: check with yang regarding CUDNN flags | ||||
|  | ||||
| @ -7,6 +7,7 @@ import torch.nn | ||||
|  | ||||
| from . import utils, variables | ||||
| from .bytecode_transformation import ( | ||||
|     bytecode_from_template, | ||||
|     create_call_function, | ||||
|     create_call_method, | ||||
|     create_instruction, | ||||
| @ -59,6 +60,11 @@ class AttributeMutationNew(AttributeMutation): | ||||
|         self.cls_source = cls_source | ||||
|  | ||||
|  | ||||
| def _manual_update_dict(dict_from, dict_to): | ||||
|     for k, v in dict_from.items(): | ||||
|         dict_to[k] = v | ||||
|  | ||||
|  | ||||
| class SideEffects: | ||||
|     """ | ||||
|     Track side effects (list mutation, setattr, etc) that need to be | ||||
| @ -480,6 +486,39 @@ class SideEffects: | ||||
|                     ] | ||||
|                 ) | ||||
|                 suffixes.append([create_instruction("STORE_SUBSCR")]) | ||||
|             elif isinstance(var, variables.CustomizedDictVariable): | ||||
|                 # need to update the dict manually since update method may be invalid | ||||
|                 varname_map = {} | ||||
|                 for name in _manual_update_dict.__code__.co_varnames: | ||||
|                     varname_map[name] = cg.tx.output.new_var() | ||||
|  | ||||
|                 cg(var.mutable_local.source)  # type: ignore[attr-defined] | ||||
|                 cg.extend_output( | ||||
|                     [create_instruction("STORE_FAST", argval=varname_map["dict_to"])] | ||||
|                 ) | ||||
|  | ||||
|                 cg(var, allow_cache=False) | ||||
|                 cg.extend_output( | ||||
|                     [create_instruction("STORE_FAST", argval=varname_map["dict_from"])] | ||||
|                 ) | ||||
|  | ||||
|                 cg(var.mutable_local.source)  # type: ignore[attr-defined] | ||||
|                 cg.extend_output([create_load_method("clear")]) | ||||
|  | ||||
|                 # unfortunately can't just use DICT_MERGE due to possible custom behaviors | ||||
|                 dict_update_insts = bytecode_from_template( | ||||
|                     _manual_update_dict, varname_map=varname_map | ||||
|                 ) | ||||
|  | ||||
|                 suffixes.append( | ||||
|                     [ | ||||
|                         *create_call_method(0),  # clear | ||||
|                         create_instruction("POP_TOP"), | ||||
|                         *dict_update_insts, | ||||
|                         create_instruction("POP_TOP"), | ||||
|                     ] | ||||
|                 ) | ||||
|  | ||||
|             elif isinstance(var, variables.ConstDictVariable): | ||||
|                 cg.tx.output.update_co_names("clear") | ||||
|                 cg.tx.output.update_co_names("update") | ||||
|  | ||||
| @ -23,7 +23,6 @@ from .ctx_manager import ( | ||||
| from .dicts import ( | ||||
|     ConstDictVariable, | ||||
|     CustomizedDictVariable, | ||||
|     DataClassVariable, | ||||
|     DefaultDictVariable, | ||||
|     SetVariable, | ||||
| ) | ||||
| @ -113,7 +112,6 @@ __all__ = [ | ||||
|     "CountIteratorVariable", | ||||
|     "CustomizedDictVariable", | ||||
|     "CycleIteratorVariable", | ||||
|     "DataClassVariable", | ||||
|     "DefaultDictVariable", | ||||
|     "DeletedVariable", | ||||
|     "DeterministicAlgorithmsVariable", | ||||
|  | ||||
| @ -111,7 +111,7 @@ from .ctx_manager import ( | ||||
| ) | ||||
| from .dicts import ( | ||||
|     ConstDictVariable, | ||||
|     DataClassVariable, | ||||
|     CustomizedDictVariable, | ||||
|     DefaultDictVariable, | ||||
|     HFPretrainedConfigVariable, | ||||
|     PythonSysModulesVariable, | ||||
| @ -493,6 +493,11 @@ class VariableBuilder: | ||||
|         elif value is sys.modules: | ||||
|             self.install_guards(GuardBuilder.FUNCTION_MATCH) | ||||
|             return PythonSysModulesVariable(source=self.source) | ||||
|         elif CustomizedDictVariable.is_matching_cls_hf(type(value)): | ||||
|             self.install_guards(GuardBuilder.TYPE_MATCH) | ||||
|             result = CustomizedDictVariable.wrap(self, value) | ||||
|             result.source = self.source | ||||
|             return self.tx.output.side_effects.track_object_existing(value, result) | ||||
|         elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): | ||||
|             if not value and self.get_source().is_nn_module(): | ||||
|                 # It is faster to guard on 'false' property than to guard | ||||
| @ -711,9 +716,6 @@ class VariableBuilder: | ||||
|             ) | ||||
|         elif np and isinstance(value, np.number): | ||||
|             return self.wrap_unspecialized_primitive(value) | ||||
|         elif DataClassVariable.is_matching_object(value): | ||||
|             self.install_guards(GuardBuilder.TYPE_MATCH) | ||||
|             return DataClassVariable.wrap(self, value) | ||||
|         elif HFPretrainedConfigVariable.is_matching_object(value): | ||||
|             self.install_guards(GuardBuilder.TYPE_MATCH) | ||||
|             return HFPretrainedConfigVariable(value) | ||||
| @ -1701,7 +1703,7 @@ class VariableBuilder: | ||||
| def _dataclasses_fields_lambda(obj): | ||||
|     if isinstance(obj, UserDefinedObjectVariable): | ||||
|         value = obj.value | ||||
|     elif isinstance(obj, DataClassVariable): | ||||
|     elif isinstance(obj, CustomizedDictVariable): | ||||
|         value = obj.user_cls | ||||
|     else: | ||||
|         unimplemented(f"Dataclass fields handling fails for type {obj}") | ||||
|  | ||||
| @ -1633,7 +1633,6 @@ class BuiltinVariable(VariableTracker): | ||||
|         if isinstance( | ||||
|             obj, | ||||
|             ( | ||||
|                 variables.DataClassVariable, | ||||
|                 variables.CustomizedDictVariable, | ||||
|                 variables.PlacementVariable, | ||||
|                 variables.UserDefinedObjectVariable, | ||||
|  | ||||
| @ -545,6 +545,8 @@ class DictValues(DictView): | ||||
|  | ||||
| def _is_matching_transformers_cls(cls) -> bool: | ||||
|     mod = sys.modules.get("transformers.file_utils") | ||||
|     if mod is None: | ||||
|         mod = sys.modules.get("transformers.utils.generic") | ||||
|     return mod is not None and issubclass(cls, mod.ModelOutput) | ||||
|  | ||||
|  | ||||
| @ -555,12 +557,20 @@ def _is_matching_diffusers_cls(cls) -> bool: | ||||
|  | ||||
| def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": | ||||
|     """Shared method between DataClassVariable and CustomizedDictVariable where items are attrs""" | ||||
|     if tx.output.side_effects.is_attribute_mutation(self): | ||||
|         try: | ||||
|             result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) | ||||
|             return variables.ConstantVariable.create( | ||||
|                 not isinstance(result, variables.DeletedVariable) | ||||
|             ) | ||||
|         except KeyError: | ||||
|             pass | ||||
|     if name in self.items or hasattr(self.user_cls, name): | ||||
|         return ConstantVariable(True) | ||||
|     elif istype(self.mutable_local, MutableLocal) and self.source is None: | ||||
|         # Something created locally can't have any extra fields on it | ||||
|         return ConstantVariable(False) | ||||
|     elif self.mutable_local is None and self.source: | ||||
|     elif self.source: | ||||
|         # Maybe add a guard | ||||
|         try: | ||||
|             example = tx.output.root_tx.get_example_value(self.source) | ||||
| @ -577,152 +587,27 @@ def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker": | ||||
|  | ||||
| class DataClassVariable(ConstDictVariable): | ||||
|     """ | ||||
|     This is a bit of a hack to deal with | ||||
|     transformers.file_utils.ModelOutput() from huggingface. | ||||
|     This class doesn't appear to be used anywhere. | ||||
|     It used to be used to deal with transformers.file_utils.ModelOutput | ||||
|     from huggingface. | ||||
|  | ||||
|     ModelOutput causes trouble because it a a mix of a dataclass and a | ||||
|     OrderedDict and it calls super() methods implemented in C. | ||||
|     Keeping since we wish to support dataclasses in general in the future | ||||
|     """ | ||||
|  | ||||
|     # ModelOutput() excludes None, though generic datclasses don't | ||||
|     include_none = False | ||||
|  | ||||
|     @staticmethod | ||||
|     @functools.lru_cache(None) | ||||
|     def _patch_once(): | ||||
|         try: | ||||
|             from transformers.file_utils import ModelOutput | ||||
|  | ||||
|             for obj in ModelOutput.__dict__.values(): | ||||
|                 if callable(obj): | ||||
|                     skip_code(obj.__code__) | ||||
|         except ImportError: | ||||
|             pass | ||||
|  | ||||
|         try: | ||||
|             from diffusers.utils import BaseOutput | ||||
|  | ||||
|             for obj in BaseOutput.__dict__.values(): | ||||
|                 if callable(obj): | ||||
|                     skip_code(obj.__code__) | ||||
|         except ImportError: | ||||
|             pass | ||||
|  | ||||
|     @staticmethod | ||||
|     def is_matching_cls(cls): | ||||
|         return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) | ||||
|  | ||||
|     @classmethod | ||||
|     def is_matching_object(cls, obj): | ||||
|         return cls.is_matching_cls(type(obj)) | ||||
|  | ||||
|     @classmethod | ||||
|     def create(cls, user_cls, args, kwargs, options): | ||||
|         DataClassVariable._patch_once() | ||||
|  | ||||
|         skip_code(user_cls.__init__.__code__) | ||||
|         keys = [f.name for f in dataclasses.fields(user_cls)] | ||||
|         bound = inspect.signature(user_cls).bind(*args, **kwargs) | ||||
|         bound.apply_defaults() | ||||
|         assert set(bound.arguments.keys()) == set(keys) | ||||
|         items = {} | ||||
|         for key in keys: | ||||
|             val = bound.arguments[key] | ||||
|             key = ConstantVariable.create(key) | ||||
|             if isinstance(val, VariableTracker): | ||||
|                 items[key] = val | ||||
|             else: | ||||
|                 if cls.include_none: | ||||
|                     assert variables.ConstantVariable.is_literal(val) | ||||
|                     items[key] = variables.ConstantVariable.create(val) | ||||
|                 else: | ||||
|                     assert val is None, f"unexpected {val}" | ||||
|  | ||||
|         if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable): | ||||
|             unimplemented("DataClassVariable iterator constructor") | ||||
|             # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__ | ||||
|  | ||||
|         return cls(items, user_cls, **options) | ||||
|  | ||||
|     @classmethod | ||||
|     def wrap(cls, builder, obj): | ||||
|         user_cls = type(obj) | ||||
|         keys = [f.name for f in dataclasses.fields(user_cls)] | ||||
|  | ||||
|         excluded = [] | ||||
|         items = {} | ||||
|         for key in keys: | ||||
|             # __init__ function of a dataclass might not have yet defined the key | ||||
|             if hasattr(obj, key): | ||||
|                 val = getattr(obj, key) | ||||
|                 var = builder.__class__( | ||||
|                     tx=builder.tx, source=AttrSource(builder.source, key) | ||||
|                 )(val) | ||||
|                 if val is not None or cls.include_none: | ||||
|                     key = ConstantVariable.create(key) | ||||
|                     items[key] = var | ||||
|                 else: | ||||
|                     excluded.append(var) | ||||
|         return cls(items, user_cls) | ||||
|  | ||||
|     def __init__(self, items, user_cls, **options): | ||||
|         super().__init__(items, user_cls, **options) | ||||
|         assert self.is_matching_cls(user_cls) | ||||
|  | ||||
|     def as_proxy(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def reconstruct(self, codegen): | ||||
|         codegen.extend_output([codegen._create_load_const(self.user_cls)]) | ||||
|         # All the keys are just wrapped strings | ||||
|         d = self.keys_as_python_constant() | ||||
|         codegen.foreach(d.values()) | ||||
|         keys = tuple(d.keys()) | ||||
|         codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True)) | ||||
|  | ||||
|     def call_method( | ||||
|         self, | ||||
|         tx, | ||||
|         name, | ||||
|         args: "List[VariableTracker]", | ||||
|         kwargs: "Dict[str, VariableTracker]", | ||||
|     ) -> "VariableTracker": | ||||
|         if name == "__getitem__": | ||||
|             assert not kwargs and len(args) == 1 | ||||
|             val = args[0] | ||||
|             if val.python_type() == str: | ||||
|                 return self.getitem_const(val) | ||||
|             else: | ||||
|                 return self.call_method(tx, "to_tuple", [], {}).call_method( | ||||
|                     tx, "__getitem__", args, kwargs | ||||
|                 ) | ||||
|         elif name == "to_tuple": | ||||
|             assert not (args or kwargs) | ||||
|             return variables.TupleVariable(list(self.items.values())) | ||||
|         elif name == "__setattr__": | ||||
|             name = "__setitem__" | ||||
|         return super().call_method(tx, name, args, kwargs) | ||||
|  | ||||
|     def var_getattr(self, tx, name: str) -> "VariableTracker": | ||||
|         name_vt = ConstantVariable.create(name) | ||||
|         if name_vt in self: | ||||
|             return self.call_method(tx, "__getitem__", [name_vt], {}) | ||||
|         elif not self.include_none: | ||||
|             defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} | ||||
|             if name in defaults: | ||||
|                 assert variables.ConstantVariable.is_literal(defaults[name]) | ||||
|                 return variables.ConstantVariable.create(defaults[name]) | ||||
|         super().var_getattr(tx, name) | ||||
|  | ||||
|     call_hasattr = _call_hasattr_customobj | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class CustomizedDictVariable(ConstDictVariable): | ||||
|     @staticmethod | ||||
|     def is_matching_cls_hf(cls): | ||||
|         return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) | ||||
|  | ||||
|     @staticmethod | ||||
|     def is_matching_cls(cls): | ||||
|         # True if using default OrderedDict.__init__ and did not implement __post_init__ | ||||
|         if ( | ||||
|             issubclass(cls, collections.OrderedDict) | ||||
|             and cls is not collections.OrderedDict | ||||
|             and cls.__init__ is collections.OrderedDict.__init__ | ||||
|             and not hasattr(cls, "__post_init__") | ||||
|         ): | ||||
| @ -730,7 +615,7 @@ class CustomizedDictVariable(ConstDictVariable): | ||||
|         # hack for HF usecase: | ||||
|         #   assume dataclass annotation for ModelOutput subclass | ||||
|         #   assume self.create is AA to ModelOutput.__post_init__ | ||||
|         return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls) | ||||
|         return CustomizedDictVariable.is_matching_cls_hf(cls) | ||||
|  | ||||
|     @classmethod | ||||
|     def is_matching_object(cls, obj): | ||||
| @ -764,9 +649,7 @@ class CustomizedDictVariable(ConstDictVariable): | ||||
|                     ) | ||||
|  | ||||
|             bound_args = {} | ||||
|             if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls( | ||||
|                 user_cls | ||||
|             ): | ||||
|             if cls.is_matching_cls_hf(user_cls): | ||||
|                 # Skip none | ||||
|                 for k, v in bound.arguments.items(): | ||||
|                     if isinstance(v, ConstantVariable) and v.value is None or v is None: | ||||
| @ -792,7 +675,27 @@ class CustomizedDictVariable(ConstDictVariable): | ||||
|     # called from builder.py | ||||
|     @classmethod | ||||
|     def wrap(cls, builder, obj): | ||||
|         raise NotImplementedError | ||||
|         user_cls = type(obj) | ||||
|  | ||||
|         if not cls.is_matching_cls_hf(user_cls): | ||||
|             unimplemented("custom non-hf dict subclass wrap unimplemented") | ||||
|  | ||||
|         items = builder.__class__(tx=builder.tx, source=builder.source)( | ||||
|             collections.OrderedDict(obj) | ||||
|         ).items | ||||
|  | ||||
|         keys = [f.name for f in dataclasses.fields(user_cls)] | ||||
|         for key in keys: | ||||
|             # __init__ function of a dataclass might not have yet defined the key | ||||
|             if hasattr(obj, key): | ||||
|                 val = getattr(obj, key) | ||||
|                 var = builder.__class__( | ||||
|                     tx=builder.tx, source=AttrSource(builder.source, key) | ||||
|                 )(val) | ||||
|                 if val is not None: | ||||
|                     key = ConstantVariable.create(key) | ||||
|                     items[key] = var | ||||
|         return cls(items, user_cls) | ||||
|  | ||||
|     def __init__(self, items, user_cls, **options): | ||||
|         super().__init__(items, user_cls, **options) | ||||
| @ -804,9 +707,7 @@ class CustomizedDictVariable(ConstDictVariable): | ||||
|     # 'RETURN_VALUE triggered compile' | ||||
|     # called from torch/_dynamo/codegen.py | ||||
|     def reconstruct(self, codegen): | ||||
|         is_hf_model_output = _is_matching_transformers_cls( | ||||
|             self.user_cls | ||||
|         ) or _is_matching_diffusers_cls(self.user_cls) | ||||
|         is_hf_model_output = self.is_matching_cls_hf(self.user_cls) | ||||
|  | ||||
|         # If the user class is a ModelOutput, then wrap the instance creation in | ||||
|         # torch._dynamo.disable(). Even though we mark the __post_init__ as skip | ||||
| @ -848,21 +749,34 @@ class CustomizedDictVariable(ConstDictVariable): | ||||
|         ): | ||||
|             # for python dict method without overridden | ||||
|             return super().call_method(tx, name, args, kwargs) | ||||
|         elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"): | ||||
|         elif name in ( | ||||
|             "__getitem__", | ||||
|             "to_tuple", | ||||
|             "__setitem__", | ||||
|             "__setattr__", | ||||
|             "__post_init__", | ||||
|         ): | ||||
|             # for user overridden method | ||||
|             return tx.inline_user_function_return( | ||||
|                 variables.UserFunctionVariable(fn, source=source), | ||||
|                 [self] + list(args), | ||||
|                 kwargs, | ||||
|             ) | ||||
|         elif fn is getattr(collections.OrderedDict, name, None): | ||||
|             return super().call_method(tx, name, args, kwargs) | ||||
|  | ||||
|         unimplemented("custom dict: call_method unimplemented name=%s", name) | ||||
|         unimplemented(f"custom dict: call_method unimplemented name={name}") | ||||
|  | ||||
|     def var_getattr(self, tx, name: str) -> "VariableTracker": | ||||
|         name_vt = ConstantVariable.create(name) | ||||
|         if name_vt in self: | ||||
|             return self.call_method(tx, "__getitem__", [name_vt], {}) | ||||
|         super().var_getattr(tx, name) | ||||
|         if dataclasses.is_dataclass(self.user_cls): | ||||
|             defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)} | ||||
|             if name in defaults: | ||||
|                 assert variables.ConstantVariable.is_literal(defaults[name]) | ||||
|                 return variables.ConstantVariable.create(defaults[name]) | ||||
|         return super().var_getattr(tx, name) | ||||
|  | ||||
|     call_hasattr = _call_hasattr_customobj | ||||
|  | ||||
|  | ||||
| @ -171,6 +171,12 @@ class SuperVariable(VariableTracker): | ||||
|             return super(variables.CustomizedDictVariable, self.objvar).call_method( | ||||
|                 tx, "__setitem__", args, kwargs | ||||
|             ) | ||||
|         elif inner_fn is collections.OrderedDict.__getitem__ and isinstance( | ||||
|             self.objvar, variables.CustomizedDictVariable | ||||
|         ): | ||||
|             return super(variables.CustomizedDictVariable, self.objvar).call_method( | ||||
|                 tx, "__getitem__", args, kwargs | ||||
|             ) | ||||
|         elif is_standard_setattr(inner_fn) and isinstance( | ||||
|             self.objvar, UserDefinedObjectVariable | ||||
|         ): | ||||
|  | ||||
| @ -396,9 +396,6 @@ class UserDefinedClassVariable(UserDefinedVariable): | ||||
|             return variables.CustomizedDictVariable.create( | ||||
|                 self.value, args, kwargs, options | ||||
|             ) | ||||
|         elif variables.DataClassVariable.is_matching_cls(self.value): | ||||
|             options = {"mutable_local": MutableLocal()} | ||||
|             return variables.DataClassVariable.create(self.value, args, kwargs, options) | ||||
|         elif ( | ||||
|             variables.RestrictedListSubclassVariable.is_matching_cls(self.value) | ||||
|             and self.source | ||||
|  | ||||
| @ -3247,13 +3247,13 @@ void install_tensor_aliasing_guard( | ||||
|  | ||||
| void install_no_tensor_aliasing_guard( | ||||
|     const py::list& guard_managers, | ||||
|     py::list tensor_names, | ||||
|     const py::list& tensor_names, | ||||
|     py::object verbose_code_parts) { | ||||
|   // Adds a guard that checks none of tensors alias. This is a an example of | ||||
|   // relational guard. There is one guard object that is shared between multiple | ||||
|   // guard managers. | ||||
|   std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>( | ||||
|       std::move(tensor_names), std::move(verbose_code_parts)); | ||||
|       tensor_names, std::move(verbose_code_parts)); | ||||
|  | ||||
|   // Register the resetter on the toor guard mananger, so that it can reset | ||||
|   // the newly added relational guard when the guard eval fails. | ||||
| @ -4006,7 +4006,15 @@ PyObject* torch_c_dynamo_guards_init() { | ||||
|       DictSubclassGuardManager, | ||||
|       DictGuardManager, | ||||
|       std::unique_ptr<DictSubclassGuardManager>>( | ||||
|       py_m, "DictSubclassGuardManager"); // NOLINT | ||||
|       py_m, "DictSubclassGuardManager") // NOLINT | ||||
|       .def( | ||||
|           "add_no_hasattr_guard", | ||||
|           [](DictSubclassGuardManager& self, | ||||
|              py::object attr_name, | ||||
|              py::object verbose_code_parts) -> void { | ||||
|             self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>( | ||||
|                 std::move(attr_name), std::move(verbose_code_parts))); | ||||
|           }); | ||||
|  | ||||
|   py_m.def("install_tensor_aliasing_guard", install_tensor_aliasing_guard); | ||||
|   py_m.def( | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	