Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353

Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose). With following changes:
* Register fake source of generated methods in linecache so that inspect.get_source will succeed.
* this patching is only triggered if the given dataclass passed to torch.jit.script previously. Effectively we make this feature opt-in.

## Original Summary:
Fixes #72901.

Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions. torch/jit/_dataclass_impls.py has the code that does this.

What's supported

Synthesized __init__, __eq__, and the comparison magic methods when order=True is set on the dataclass decorator
Default values for fields
__post_init__, including using InitVar fields inside of __post_init__, on Python 3.8+
Overriding __eq__ or any of the comparison magic methods to provide your own implementation
What's not supported

Default factory initializers for fields
Frozen dataclasses
InitVar on Python 3.7
__repr__ and __hash__ (these are actually implemented, but the TorchScript interpreter won't call them)
Using the != operator on dataclasses inside TorchScript; this is because TorchScript requires that you implement __ne__ to use this operator, whereas in regular Python the != operator will resolve to the negation of whatever is returned by __eq__ if there's no __ne__. Dataclasses don't actually synthesize an __ne__ method for this reason. I've been toying with different ways to fix this but != is not working in this PR at the moment.

Test Plan:
unittest

Also run previously failed test:
```
buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)'
```
passes

Differential Revision: D35206262

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889
Approved by: https://github.com/zhxchen17
This commit is contained in:
Han Qi
2022-03-31 00:20:48 +00:00
committed by PyTorch MergeBot
parent eb3fa2b0f1
commit 5547741960
9 changed files with 403 additions and 42 deletions

View File

@ -48,13 +48,20 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
return text[text.startswith(prefix) and len(prefix):]
# Find the line and line number containing the function definition
idx = None
for i, l in enumerate(sourcelines):
if l.lstrip().startswith("def"):
idx = i
break
fn_def = sourcelines[idx]
# This will happen when the function is a lambda- we won't find "def" anywhere in the source
# lines in that case. Currently trying to JIT compile a lambda will throw an error up in
# `parse_def()`, but we might want to handle this case in the future.
if idx is None:
return sourcelines
# Get a string representing the amount of leading whitespace
fn_def = sourcelines[idx]
whitespace = fn_def.split("def")[0]
# Add this leading whitespace to all lines before and after the `def`