Add __sub__ function for schema.Struct

Summary:
This is for the ease of removing the common fields of a struct from another.
For example,
  s1 = Struct(
      ('a', Scalar()),
      ('b', Scalar()),
  )
  s2 = Struct(('a', Scalar()))
  s1 - s2 == Struct(('b', Scalar()))

More examples are provided in the code comments.

Differential Revision: D5299277

fbshipit-source-id: 7008586ffdc8e24e1eccc8757da70330c4d90370
This commit is contained in:
Brian Lan
2017-06-28 11:03:05 -07:00
committed by Facebook Github Bot
parent 8260002941
commit e2bd3cfc8b
2 changed files with 110 additions and 0 deletions

View File

@ -174,6 +174,46 @@ class TestDB(unittest.TestCase):
self.assertIn("b", sv.fields)
self.assertEqual(0, len(sv.b.fields))
def testStructSubstraction(self):
s1 = schema.Struct(
('a', schema.Scalar()),
('b', schema.Scalar()),
('c', schema.Scalar()),
)
s2 = schema.Struct(
('b', schema.Scalar())
)
s = s1 - s2
self.assertEqual(['a', 'c'], s.field_names())
s3 = schema.Struct(
('a', schema.Scalar())
)
s = s1 - s3
self.assertEqual(['b', 'c'], s.field_names())
with self.assertRaises(TypeError):
s1 - schema.Scalar()
def testStructNestedSubstraction(self):
s1 = schema.Struct(
('a', schema.Scalar()),
('b', schema.Struct(
('c', schema.Scalar()),
('d', schema.Scalar()),
('e', schema.Scalar()),
('f', schema.Scalar()),
)),
)
s2 = schema.Struct(
('b', schema.Struct(
('d', schema.Scalar()),
('e', schema.Scalar()),
)),
)
s = s1 - s2
self.assertEqual(['a', 'b:c', 'b:f'], s.field_names())
def testStructAddition(self):
s1 = schema.Struct(
('a', schema.Scalar())