mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-28 18:54:57 +08:00
Add __sub__ function for schema.Struct
Summary:
This is for the ease of removing the common fields of a struct from another.
For example,
s1 = Struct(
('a', Scalar()),
('b', Scalar()),
)
s2 = Struct(('a', Scalar()))
s1 - s2 == Struct(('b', Scalar()))
More examples are provided in the code comments.
Differential Revision: D5299277
fbshipit-source-id: 7008586ffdc8e24e1eccc8757da70330c4d90370
This commit is contained in:
committed by
Facebook Github Bot
parent
8260002941
commit
e2bd3cfc8b
@ -174,6 +174,46 @@ class TestDB(unittest.TestCase):
|
||||
self.assertIn("b", sv.fields)
|
||||
self.assertEqual(0, len(sv.b.fields))
|
||||
|
||||
def testStructSubstraction(self):
|
||||
s1 = schema.Struct(
|
||||
('a', schema.Scalar()),
|
||||
('b', schema.Scalar()),
|
||||
('c', schema.Scalar()),
|
||||
)
|
||||
s2 = schema.Struct(
|
||||
('b', schema.Scalar())
|
||||
)
|
||||
s = s1 - s2
|
||||
self.assertEqual(['a', 'c'], s.field_names())
|
||||
|
||||
s3 = schema.Struct(
|
||||
('a', schema.Scalar())
|
||||
)
|
||||
s = s1 - s3
|
||||
self.assertEqual(['b', 'c'], s.field_names())
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
s1 - schema.Scalar()
|
||||
|
||||
def testStructNestedSubstraction(self):
|
||||
s1 = schema.Struct(
|
||||
('a', schema.Scalar()),
|
||||
('b', schema.Struct(
|
||||
('c', schema.Scalar()),
|
||||
('d', schema.Scalar()),
|
||||
('e', schema.Scalar()),
|
||||
('f', schema.Scalar()),
|
||||
)),
|
||||
)
|
||||
s2 = schema.Struct(
|
||||
('b', schema.Struct(
|
||||
('d', schema.Scalar()),
|
||||
('e', schema.Scalar()),
|
||||
)),
|
||||
)
|
||||
s = s1 - s2
|
||||
self.assertEqual(['a', 'b:c', 'b:f'], s.field_names())
|
||||
|
||||
def testStructAddition(self):
|
||||
s1 = schema.Struct(
|
||||
('a', schema.Scalar())
|
||||
|
||||
Reference in New Issue
Block a user