mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Reland "[pytorch][PR] Support dataclasses in TorchScript" take 2 (#74353)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74353

Repatched `d00de0d43598522b8f6ab2de553b6aaf6768faa5` by Nora Belrose (norabelrose), with the following changes:
* Register a fake source for the generated methods in `linecache` so that `inspect.getsource` succeeds.
* This patching is triggered only when the given dataclass has been passed to `torch.jit.script`; effectively, this makes the feature opt-in.

## Original Summary:

Fixes #72901.

Since we can't get access to the source code for synthesized magic methods on dataclasses, we have to synthesize our own versions. `torch/jit/_dataclass_impls.py` has the code that does this.

What's supported:
* Synthesized `__init__`, `__eq__`, and the comparison magic methods when `order=True` is set on the dataclass decorator
* Default values for fields
* `__post_init__`, including using `InitVar` fields inside of `__post_init__`, on Python 3.8+
* Overriding `__eq__` or any of the comparison magic methods to provide your own implementation

What's not supported:
* Default factory initializers for fields
* Frozen dataclasses
* `InitVar` on Python 3.7
* `__repr__` and `__hash__` (these are actually implemented, but the TorchScript interpreter won't call them)
* Using the `!=` operator on dataclasses inside TorchScript. This is because TorchScript requires that you implement `__ne__` to use this operator, whereas in regular Python the `!=` operator resolves to the negation of whatever is returned by `__eq__` if there's no `__ne__`. Dataclasses deliberately don't synthesize an `__ne__` method for this reason. I've been toying with different ways to fix this, but `!=` is not working in this PR at the moment.
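The `!=` caveat above reflects plain-Python behavior that the TorchScript interpreter does not replicate: dataclasses synthesize `__eq__` but deliberately no `__ne__`, and CPython's default `object.__ne__` falls back to negating `__eq__`. A minimal stdlib-only sketch (the `Point` class is an illustrative example, not code from this PR):

```python
from dataclasses import dataclass

@dataclass(order=True)
class Point:
    x: int
    y: int

# No __ne__ is synthesized on the class itself...
assert "__ne__" not in Point.__dict__
# ...yet != still works in plain Python, because the default object.__ne__
# negates the result of __eq__. TorchScript has no such fallback, which is
# why != on dataclasses fails to compile there.
assert Point(1, 2) != Point(1, 3)
# order=True synthesizes __lt__/__le__/__gt__/__ge__, which this PR supports.
assert Point(1, 2) < Point(1, 3)
```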
Test Plan: unittest

Also ran the previously failing test:
```
buck test mode/dev-nosan //fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests -- --exact 'fblearner/flow/projects/fluent2/definition/transformers/contrib/faim/test:tests - test_mixmatch_multiclass (fblearner.flow.projects.fluent2.definition.transformers.contrib.faim.test.faim_mixmatch_test.TestFaimTransformerMixMatch)'
```
It passes.

Differential Revision: D35206262

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74889
Approved by: https://github.com/zhxchen17
```diff
@@ -48,13 +48,20 @@ def normalize_source_lines(sourcelines: List[str]) -> List[str]:
         return text[text.startswith(prefix) and len(prefix):]
 
     # Find the line and line number containing the function definition
     idx = None
     for i, l in enumerate(sourcelines):
         if l.lstrip().startswith("def"):
             idx = i
             break
-    fn_def = sourcelines[idx]
+
+    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
+    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
+    # `parse_def()`, but we might want to handle this case in the future.
+    if idx is None:
+        return sourcelines
 
     # Get a string representing the amount of leading whitespace
+    fn_def = sourcelines[idx]
     whitespace = fn_def.split("def")[0]
 
     # Add this leading whitespace to all lines before and after the `def`
```
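The `remove_prefix` one-liner in the hunk above leans on `and` short-circuiting plus Python's bool-as-int coercion in slice indices. A small stdlib-only sketch of the same trick, extracted as a standalone function for illustration:

```python
def remove_prefix(text: str, prefix: str) -> str:
    # If text starts with prefix, `True and len(prefix)` evaluates to
    # len(prefix), so the slice drops the prefix. Otherwise the expression
    # short-circuits to False, and text[False:] == text[0:] is a no-op.
    return text[text.startswith(prefix) and len(prefix):]

print(remove_prefix("    return x", "    "))  # "return x"
print(remove_prefix("return x", "    "))      # unchanged: "return x"
```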