""" Flag handling routines Copyright 2015 Cray Inc. All Rights Reserved """ ### TBD: The "Simple" in the addSimple* interfaces refers to a flag ### that's a single bit. It's meant to distinguish from flags that ### have multibit fields, such as the node/zone indices stuck in the ### high end of struct page.flags; or a field that's mostly a pointer ### but with some flags in the low bits. # ### To add cases like that will mean redoing most of the ### implementation, but all the current interfaces should be ok, with ### new interfaces added to let users define the non-simple flags. import uflookup class FlagSet: """A collection of flags and values, with routines for translating For decoding a flag int to a string, encoding a flag string to an int, and providing python identifiers for testing by name, e.g., jafs = FlagSet() # job_attach flagset jafs.addSimpleFlag("disable_affinity_apply") if job_attach.flags & jafs.disable_affinity_apply: ... The advantages over just using a dict include: * Define the values once, and get value->string, string->value, and python identifiers ns. and ns._shift as above. * The auto-incrementing _next_bit """ def __init__(self, mapping=None): """Create and initialize a FlagSet object Arguments: mapping: if specified, provides a mapping object, e.g. dict, that supplies the initial key(name)/value pairs. """ # Public dict of flag names to flag values (not the bit number) self.str_to_value = {} # Public dict of flag values to flag names self.value_to_str = {} self._next_bit = 0 # sorted_values is so that translating a value to a string # will report the strings in the same order every time. That # order is by numerically increasing value. self._sorted_values = [] self._sorted_strs = [] if mapping is not None: self.addMap(mapping) def addSimpleFlag(self, s, bit=None): """Add a single-bit flag. If bit is not specified, uses the bit one greater than the previously defined bit. 
If multiple flags are defined to use the same bit, value_to_str will remember only the first.""" if s in self.str_to_value.keys(): raise ValueError("Flag {0} already defined (value {1:x})".format( s, self.str_to_value[s])) if s + "_shift" in self.str_to_value.keys(): raise ValueError("Flag {0} conflicts with another " "flag ({1})".format(s, s + "_shift")) try: getattr(self, s) except AttributeError: pass else: raise ValueError("Value {0} already used by FlagSet object!". format(s)) try: getattr(self, s + "_shift") except AttributeError: pass else: raise valueError("{0}_shift already used by FlagSet object!". format(s)) if bit is None: bit = self._next_bit; self._next_bit = bit + 1 value = 1 << bit if value not in self.value_to_str: self.value_to_str[value] = s self.str_to_value[s] = value self._sorted_values = [] setattr(self, s, value) setattr(self, s+"_shift", bit) def addSimpleFlags(self, *l): """Adds a list of single-bit flags.""" map(self.addSimpleFlag, l) def addMap(self, mapping): """Add the key/value pairs from a mapping type""" for k, v in mapping.items(): self.addSimpleFlag(k, v) def _EnsureSorted(self): if self._sorted_values: return self._sorted_values = sorted(self.value_to_str.keys()) # self._sorted_strs = sorted(self.str_to_value.keys()) def flagsToStringList(self, flagint): """Translate a given flag int to a list of flag strings.""" self._EnsureSorted() strs = [] for v in self._sorted_values: if flagint & v != 0: strs.append(self.value_to_str[v]) flagint &= ~v if flagint != 0: strs.append("{0:#x}".format(flagint)) return strs def UFLookup(self, key, **kwargs): return uflookup.UFLookup(self.str_to_value, key, **kwargs) # TBD: interface to enable a script --dump-flag-translations argument? 
def join_flaglist(fl, sep = "|", empty = "0"):
    """Join a list of flag strings into one display string.

    Arguments:
        fl:    the flag strings to join.
        sep:   separator placed between the strings.
        empty: what to return when fl is empty.
    """
    return sep.join(fl) if fl else empty