Compare commits

..

2 Commits

Author SHA1 Message Date
3ddec713b8 Revert "[cuDNN][Quantization] Don't print when plan finalization fails in cuDNN quantization backend (#128177)"
This reverts commit cac7a22b92478d897488688010e562b7bd36b97f.

Reverted https://github.com/pytorch/pytorch/pull/128177 on behalf of https://github.com/clee2000 due to broke test/test_quantization.py::TestQuantizedLinear::test_qlinear_cudnn on sm86 tests cac7a22b92 https://github.com/pytorch/pytorch/actions/runs/9470648757/job/26100448913.  Probably a landrace, test ran on the PR and succeed ([comment](https://github.com/pytorch/pytorch/pull/128177#issuecomment-2161977110))
2024-06-12 02:20:15 +00:00
85eeb90d2c [dynamo] Fix graph breaks related to HF ModelOutput (#127780)
Fixes https://github.com/pytorch/pytorch/issues/126028 and https://github.com/pytorch/pytorch/issues/126027.

Changes:
- Support building `CustomizedDictVariable` in` VariableBuilder` (but only for HF `ModelOutput` subclasses)
- Remove `DataClassVariable` since it's not really being used anywhere (`CustomizedDictVariable` can be used instead)
- Support side effects for `CustomizedDictVariable`
- Allow `NO_HASATTR` leaf guard on `DictSubclassGuardManager`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127780
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-06-12 02:16:24 +00:00
30 changed files with 298 additions and 272 deletions

View File

@ -242,7 +242,7 @@ Tensor add(Tensor qa, Tensor qb, double output_scale, int64_t output_zero_point)
run(plan_desc);
execution_plan_cache[key] = plan_desc;
return quantized_output.view(orig_sizes);
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Add Cudnn");

View File

@ -252,7 +252,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
run(plan);
execution_plan_cache.emplace(key, plan);
return;
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation in Quantized Conv2D Cudnn");

View File

@ -286,7 +286,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp
run(plan);
execution_plan_cache.emplace(key, plan);
return;
} catch (cudnn_frontend::cudnnException &e) {} catch(c10::CuDNNError &e) {}
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation Quantized Linear Cudnn");

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43

1 name accuracy graph_breaks
150
151
152
153
154
155
156

View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass, 52
hf_BigBird,pass,49

1 name accuracy graph_breaks
98
99
100
101
102
103
104

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43

1 name accuracy graph_breaks
150
151
152
153
154
155
156

View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49

1 name accuracy graph_breaks
98
99
100
101
102
103
104

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,46
hf_BigBird,fail_accuracy,43

1 name accuracy graph_breaks
150
151
152
153
154
155
156

View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49

1 name accuracy graph_breaks
98
99
100
101
102
103
104

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43

1 name accuracy graph_breaks
150
151
152
153
154
155
156

View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49

1 name accuracy graph_breaks
98
99
100
101
102
103
104

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6

1 name accuracy graph_breaks
14 DebertaForQuestionAnswering pass 5
15 DebertaV2ForMaskedLM pass_due_to_skip 0
16 DebertaV2ForQuestionAnswering eager_1st_run_OOM 0
17 DistilBertForMaskedLM pass 5
18 DistilBertForQuestionAnswering pass 5
19 DistillGPT2 pass 5
20 ElectraForCausalLM pass 4
21 ElectraForQuestionAnswering pass 5
22 GPT2ForSequenceClassification pass 7
23 GoogleFnet pass 5
24 LayoutLMForMaskedLM pass 5
34 OPTForCausalLM pass 12 6
35 PLBartForCausalLM pass 12 6
36 PLBartForConditionalGeneration pass 29 8
37 PegasusForCausalLM pass 12 6
38 PegasusForConditionalGeneration pass 23 7
39 RobertaForCausalLM pass 5
40 RobertaForQuestionAnswering pass 5
41 Speech2Text2ForCausalLM pass 12 6
42 T5ForConditionalGeneration pass 5
43 T5Small pass 5
44 TrOCRForCausalLM pass 12 6
102
103
104
105
106
107
108
109
110
111
112
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
158
159
160
161
162
163
164
170
171
172
173
174
175
176
177
178
179
180

View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,46
hf_BigBird,fail_accuracy,43

1 name accuracy graph_breaks
150
151
152
153
154
155
156

View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49

1 name accuracy graph_breaks
98
99
100
101
102
103
104

View File

@ -101,6 +101,15 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
self._common(fn, 2)
@maybe_skip
def test_mo_getattr_missing(self):
def fn(obj: BaseModelOutput):
if getattr(obj, "asdf", None) is not None:
obj.asdf += 1
return obj.attentions + 1
self._common(fn, 1)
@maybe_skip
def test_mo_getitem(self):
def fn(obj: BaseModelOutput):
@ -166,6 +175,59 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
@maybe_skip
def test_mo_init2(self):
# this ModelOutput subclass runs a different __post_init__ codepath
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
def fn(x):
obj = MyDataClass(x=x)
return obj
inp = torch.randn(3, 3)
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(inp).x, opt_fn(inp).x)
@maybe_skip
def test_mo_init_with_disable(self):
# Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
# graph breaks (although it may not be the first)
# Minimal repro for https://github.com/pytorch/pytorch/issues/126028
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
@torch._dynamo.disable(recursive=False)
def fn(x):
return MyDataClass(x=x)
inp = torch.randn(3, 3)
opt_fn = torch._dynamo.optimize("eager")(fn)
self.assertEqual(fn(inp).x, opt_fn(inp).x)
@maybe_skip
def test_mo_newkey(self):
obj = BaseModelOutput()
def fn(obj):
return obj["wwww"] + 1
inp = torch.randn(3, 3)
obj["wwww"] = inp
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(obj), opt_fn(obj))
@maybe_skip
def test_mo_from_outside(self):
def fn(obj):
return obj.attentions + 1
obj = BaseModelOutput(attentions=torch.randn(3, 3))
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(obj), opt_fn(obj))
@maybe_skip
def test_HF_bert_model_output(self):
class BertPooler(torch.nn.Module):

View File

@ -4052,6 +4052,7 @@ class TestQuantizedLinear(TestCase):
use_channelwise=st.sampled_from([False])) # channelwise currently not supported for qlinear cudnn
@skipIfNoFBGEMM
@unittest.skipIf(not TEST_CUDNN, "cudnn is not enabled.")
@unittest.skipIf(TEST_CUDNN and torch.backends.cudnn.version() == 90100, "expected failure on cuDNN 9.1.0")
@unittest.skipIf(not SM80OrLater, "requires sm80 or later.")
@unittest.skipIf(TEST_ROCM, "not supported on rocm.")
# TODO: check with yang regarding CUDNN flags

View File

@ -7,6 +7,7 @@ import torch.nn
from . import utils, variables
from .bytecode_transformation import (
bytecode_from_template,
create_call_function,
create_call_method,
create_instruction,
@ -59,6 +60,11 @@ class AttributeMutationNew(AttributeMutation):
self.cls_source = cls_source
def _manual_update_dict(dict_from, dict_to):
for k, v in dict_from.items():
dict_to[k] = v
class SideEffects:
"""
Track side effects (list mutation, setattr, etc) that need to be
@ -480,6 +486,39 @@ class SideEffects:
]
)
suffixes.append([create_instruction("STORE_SUBSCR")])
elif isinstance(var, variables.CustomizedDictVariable):
# need to update the dict manually since update method may be invalid
varname_map = {}
for name in _manual_update_dict.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output(
[create_instruction("STORE_FAST", argval=varname_map["dict_to"])]
)
cg(var, allow_cache=False)
cg.extend_output(
[create_instruction("STORE_FAST", argval=varname_map["dict_from"])]
)
cg(var.mutable_local.source) # type: ignore[attr-defined]
cg.extend_output([create_load_method("clear")])
# unfortunately can't just use DICT_MERGE due to possible custom behaviors
dict_update_insts = bytecode_from_template(
_manual_update_dict, varname_map=varname_map
)
suffixes.append(
[
*create_call_method(0), # clear
create_instruction("POP_TOP"),
*dict_update_insts,
create_instruction("POP_TOP"),
]
)
elif isinstance(var, variables.ConstDictVariable):
cg.tx.output.update_co_names("clear")
cg.tx.output.update_co_names("update")

View File

@ -23,7 +23,6 @@ from .ctx_manager import (
from .dicts import (
ConstDictVariable,
CustomizedDictVariable,
DataClassVariable,
DefaultDictVariable,
SetVariable,
)
@ -113,7 +112,6 @@ __all__ = [
"CountIteratorVariable",
"CustomizedDictVariable",
"CycleIteratorVariable",
"DataClassVariable",
"DefaultDictVariable",
"DeletedVariable",
"DeterministicAlgorithmsVariable",

View File

@ -111,7 +111,7 @@ from .ctx_manager import (
)
from .dicts import (
ConstDictVariable,
DataClassVariable,
CustomizedDictVariable,
DefaultDictVariable,
HFPretrainedConfigVariable,
PythonSysModulesVariable,
@ -493,6 +493,11 @@ class VariableBuilder:
elif value is sys.modules:
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return PythonSysModulesVariable(source=self.source)
elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = CustomizedDictVariable.wrap(self, value)
result.source = self.source
return self.tx.output.side_effects.track_object_existing(value, result)
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
if not value and self.get_source().is_nn_module():
# It is faster to guard on 'false' property than to guard
@ -711,9 +716,6 @@ class VariableBuilder:
)
elif np and isinstance(value, np.number):
return self.wrap_unspecialized_primitive(value)
elif DataClassVariable.is_matching_object(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return DataClassVariable.wrap(self, value)
elif HFPretrainedConfigVariable.is_matching_object(value):
self.install_guards(GuardBuilder.TYPE_MATCH)
return HFPretrainedConfigVariable(value)
@ -1701,7 +1703,7 @@ class VariableBuilder:
def _dataclasses_fields_lambda(obj):
if isinstance(obj, UserDefinedObjectVariable):
value = obj.value
elif isinstance(obj, DataClassVariable):
elif isinstance(obj, CustomizedDictVariable):
value = obj.user_cls
else:
unimplemented(f"Dataclass fields handling fails for type {obj}")

View File

@ -1633,7 +1633,6 @@ class BuiltinVariable(VariableTracker):
if isinstance(
obj,
(
variables.DataClassVariable,
variables.CustomizedDictVariable,
variables.PlacementVariable,
variables.UserDefinedObjectVariable,

View File

@ -545,6 +545,8 @@ class DictValues(DictView):
def _is_matching_transformers_cls(cls) -> bool:
mod = sys.modules.get("transformers.file_utils")
if mod is None:
mod = sys.modules.get("transformers.utils.generic")
return mod is not None and issubclass(cls, mod.ModelOutput)
@ -555,12 +557,20 @@ def _is_matching_diffusers_cls(cls) -> bool:
def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
"""Shared method between DataClassVariable and CustomizedDictVariable where items are attrs"""
if tx.output.side_effects.is_attribute_mutation(self):
try:
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
return variables.ConstantVariable.create(
not isinstance(result, variables.DeletedVariable)
)
except KeyError:
pass
if name in self.items or hasattr(self.user_cls, name):
return ConstantVariable(True)
elif istype(self.mutable_local, MutableLocal) and self.source is None:
# Something created locally can't have any extra fields on it
return ConstantVariable(False)
elif self.mutable_local is None and self.source:
elif self.source:
# Maybe add a guard
try:
example = tx.output.root_tx.get_example_value(self.source)
@ -577,152 +587,27 @@ def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
class DataClassVariable(ConstDictVariable):
"""
This is a bit of a hack to deal with
transformers.file_utils.ModelOutput() from huggingface.
This class doesn't appear to be used anywhere.
It used to be used to deal with transformers.file_utils.ModelOutput
from huggingface.
ModelOutput causes trouble because it a a mix of a dataclass and a
OrderedDict and it calls super() methods implemented in C.
Keeping since we wish to support dataclasses in general in the future
"""
# ModelOutput() excludes None, though generic datclasses don't
include_none = False
@staticmethod
@functools.lru_cache(None)
def _patch_once():
try:
from transformers.file_utils import ModelOutput
for obj in ModelOutput.__dict__.values():
if callable(obj):
skip_code(obj.__code__)
except ImportError:
pass
try:
from diffusers.utils import BaseOutput
for obj in BaseOutput.__dict__.values():
if callable(obj):
skip_code(obj.__code__)
except ImportError:
pass
@staticmethod
def is_matching_cls(cls):
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
@classmethod
def is_matching_object(cls, obj):
return cls.is_matching_cls(type(obj))
@classmethod
def create(cls, user_cls, args, kwargs, options):
DataClassVariable._patch_once()
skip_code(user_cls.__init__.__code__)
keys = [f.name for f in dataclasses.fields(user_cls)]
bound = inspect.signature(user_cls).bind(*args, **kwargs)
bound.apply_defaults()
assert set(bound.arguments.keys()) == set(keys)
items = {}
for key in keys:
val = bound.arguments[key]
key = ConstantVariable.create(key)
if isinstance(val, VariableTracker):
items[key] = val
else:
if cls.include_none:
assert variables.ConstantVariable.is_literal(val)
items[key] = variables.ConstantVariable.create(val)
else:
assert val is None, f"unexpected {val}"
if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
unimplemented("DataClassVariable iterator constructor")
# TODO(jansel): implement unpacking logic in ModelOutput.__post_init__
return cls(items, user_cls, **options)
@classmethod
def wrap(cls, builder, obj):
user_cls = type(obj)
keys = [f.name for f in dataclasses.fields(user_cls)]
excluded = []
items = {}
for key in keys:
# __init__ function of a dataclass might not have yet defined the key
if hasattr(obj, key):
val = getattr(obj, key)
var = builder.__class__(
tx=builder.tx, source=AttrSource(builder.source, key)
)(val)
if val is not None or cls.include_none:
key = ConstantVariable.create(key)
items[key] = var
else:
excluded.append(var)
return cls(items, user_cls)
def __init__(self, items, user_cls, **options):
super().__init__(items, user_cls, **options)
assert self.is_matching_cls(user_cls)
def as_proxy(self):
raise NotImplementedError
def reconstruct(self, codegen):
codegen.extend_output([codegen._create_load_const(self.user_cls)])
# All the keys are just wrapped strings
d = self.keys_as_python_constant()
codegen.foreach(d.values())
keys = tuple(d.keys())
codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "__getitem__":
assert not kwargs and len(args) == 1
val = args[0]
if val.python_type() == str:
return self.getitem_const(val)
else:
return self.call_method(tx, "to_tuple", [], {}).call_method(
tx, "__getitem__", args, kwargs
)
elif name == "to_tuple":
assert not (args or kwargs)
return variables.TupleVariable(list(self.items.values()))
elif name == "__setattr__":
name = "__setitem__"
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name: str) -> "VariableTracker":
name_vt = ConstantVariable.create(name)
if name_vt in self:
return self.call_method(tx, "__getitem__", [name_vt], {})
elif not self.include_none:
defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
if name in defaults:
assert variables.ConstantVariable.is_literal(defaults[name])
return variables.ConstantVariable.create(defaults[name])
super().var_getattr(tx, name)
call_hasattr = _call_hasattr_customobj
pass
class CustomizedDictVariable(ConstDictVariable):
@staticmethod
def is_matching_cls_hf(cls):
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
@staticmethod
def is_matching_cls(cls):
# True if using default OrderedDict.__init__ and did not implement __post_init__
if (
issubclass(cls, collections.OrderedDict)
and cls is not collections.OrderedDict
and cls.__init__ is collections.OrderedDict.__init__
and not hasattr(cls, "__post_init__")
):
@ -730,7 +615,7 @@ class CustomizedDictVariable(ConstDictVariable):
# hack for HF usecase:
# assume dataclass annotation for ModelOutput subclass
# assume self.create is AA to ModelOutput.__post_init__
return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)
return CustomizedDictVariable.is_matching_cls_hf(cls)
@classmethod
def is_matching_object(cls, obj):
@ -764,9 +649,7 @@ class CustomizedDictVariable(ConstDictVariable):
)
bound_args = {}
if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls(
user_cls
):
if cls.is_matching_cls_hf(user_cls):
# Skip none
for k, v in bound.arguments.items():
if isinstance(v, ConstantVariable) and v.value is None or v is None:
@ -792,7 +675,27 @@ class CustomizedDictVariable(ConstDictVariable):
# called from builder.py
@classmethod
def wrap(cls, builder, obj):
    """Wrap a live HF ModelOutput instance into a CustomizedDictVariable.

    Builds variables for the dict items via the supplied builder, then
    augments them with dataclass fields read off the object itself
    (guarded with AttrSource so each field is tracked precisely).
    Non-HF dict subclasses are unsupported.
    """
    # BUG FIX: removed a stray `raise NotImplementedError` that made the
    # whole implementation below unreachable (diff-merge residue).
    user_cls = type(obj)
    if not cls.is_matching_cls_hf(user_cls):
        unimplemented("custom non-hf dict subclass wrap unimplemented")

    items = builder.__class__(tx=builder.tx, source=builder.source)(
        collections.OrderedDict(obj)
    ).items

    keys = [f.name for f in dataclasses.fields(user_cls)]
    for key in keys:
        # __init__ function of a dataclass might not have yet defined the key
        if hasattr(obj, key):
            val = getattr(obj, key)
            var = builder.__class__(
                tx=builder.tx, source=AttrSource(builder.source, key)
            )(val)
            # None-valued fields are skipped (they resolve to defaults
            # later via var_getattr).
            if val is not None:
                key = ConstantVariable.create(key)
                items[key] = var
    return cls(items, user_cls)
def __init__(self, items, user_cls, **options):
super().__init__(items, user_cls, **options)
@ -804,9 +707,7 @@ class CustomizedDictVariable(ConstDictVariable):
# 'RETURN_VALUE triggered compile'
# called from torch/_dynamo/codegen.py
def reconstruct(self, codegen):
is_hf_model_output = _is_matching_transformers_cls(
self.user_cls
) or _is_matching_diffusers_cls(self.user_cls)
is_hf_model_output = self.is_matching_cls_hf(self.user_cls)
# If the user class is a ModelOutput, then wrap the instance creation in
# torch._dynamo.disable(). Even though we mark the __post_init__ as skip
@ -848,21 +749,34 @@ class CustomizedDictVariable(ConstDictVariable):
):
# for python dict method without overridden
return super().call_method(tx, name, args, kwargs)
elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
elif name in (
"__getitem__",
"to_tuple",
"__setitem__",
"__setattr__",
"__post_init__",
):
# for user overridden method
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn, source=source),
[self] + list(args),
kwargs,
)
elif fn is getattr(collections.OrderedDict, name, None):
return super().call_method(tx, name, args, kwargs)
unimplemented("custom dict: call_method unimplemented name=%s", name)
unimplemented(f"custom dict: call_method unimplemented name={name}")
def var_getattr(self, tx, name: str) -> "VariableTracker":
    """Resolve an attribute read: stored items first, then dataclass
    field defaults, then normal ConstDictVariable attribute handling.
    """
    name_vt = ConstantVariable.create(name)
    if name_vt in self:
        return self.call_method(tx, "__getitem__", [name_vt], {})
    # BUG FIX: removed a stray discarded `super().var_getattr(tx, name)`
    # statement (diff-merge residue) that ran the base lookup for side
    # effects before the default handling below.
    if dataclasses.is_dataclass(self.user_cls):
        # Fields absent from items (e.g. skipped None values) resolve to
        # their dataclass-declared default.
        defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
        if name in defaults:
            assert variables.ConstantVariable.is_literal(defaults[name])
            return variables.ConstantVariable.create(defaults[name])
    return super().var_getattr(tx, name)
call_hasattr = _call_hasattr_customobj

View File

@ -171,6 +171,12 @@ class SuperVariable(VariableTracker):
return super(variables.CustomizedDictVariable, self.objvar).call_method(
tx, "__setitem__", args, kwargs
)
elif inner_fn is collections.OrderedDict.__getitem__ and isinstance(
self.objvar, variables.CustomizedDictVariable
):
return super(variables.CustomizedDictVariable, self.objvar).call_method(
tx, "__getitem__", args, kwargs
)
elif is_standard_setattr(inner_fn) and isinstance(
self.objvar, UserDefinedObjectVariable
):

View File

@ -396,9 +396,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.CustomizedDictVariable.create(
self.value, args, kwargs, options
)
elif variables.DataClassVariable.is_matching_cls(self.value):
options = {"mutable_local": MutableLocal()}
return variables.DataClassVariable.create(self.value, args, kwargs, options)
elif (
variables.RestrictedListSubclassVariable.is_matching_cls(self.value)
and self.source

View File

@ -3247,13 +3247,13 @@ void install_tensor_aliasing_guard(
void install_no_tensor_aliasing_guard(
const py::list& guard_managers,
py::list tensor_names,
const py::list& tensor_names,
py::object verbose_code_parts) {
// Adds a guard that checks none of tensors alias. This is a an example of
// relational guard. There is one guard object that is shared between multiple
// guard managers.
std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>(
std::move(tensor_names), std::move(verbose_code_parts));
tensor_names, std::move(verbose_code_parts));
// Register the resetter on the toor guard mananger, so that it can reset
// the newly added relational guard when the guard eval fails.
@ -4006,7 +4006,15 @@ PyObject* torch_c_dynamo_guards_init() {
DictSubclassGuardManager,
DictGuardManager,
std::unique_ptr<DictSubclassGuardManager>>(
py_m, "DictSubclassGuardManager"); // NOLINT
py_m, "DictSubclassGuardManager") // NOLINT
.def(
"add_no_hasattr_guard",
[](DictSubclassGuardManager& self,
py::object attr_name,
py::object verbose_code_parts) -> void {
self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>(
std::move(attr_name), std::move(verbose_code_parts)));
});
py_m.def("install_tensor_aliasing_guard", install_tensor_aliasing_guard);
py_m.def(