]> granicus.if.org Git - libevent/commitdiff
slight refactoring
authorNiels Provos <provos@gmail.com>
Thu, 3 Apr 2008 03:33:07 +0000 (03:33 +0000)
committerNiels Provos <provos@gmail.com>
Thu, 3 Apr 2008 03:33:07 +0000 (03:33 +0000)
svn:r700

event_rpcgen.py

index a11af31ea50fc79a7a431bb508e980eb9f0af945..8f29edbd07b8345cbf5a661b51ef14aae9907e5e 100755 (executable)
@@ -52,6 +52,12 @@ class Struct:
         name = "%s_%s" % (self._name, entry.Name())
         return name.upper()
 
+class StructCCode(Struct):
+    """ Knows how to generate C code for a struct """
+    
+    def __init__(self, name):
+        Struct.__init__(self, name)
+        
     def PrintIndented(self, file, ident, code):
         """Takes an array, add indentation to each entry and prints it."""
         for entry in code:
@@ -328,22 +334,6 @@ class Entry:
     def GetInitializer(self):
         assert 0, "Entry does not provide initializer"
 
-    def GetTranslation(self, extradict = {}):
-        mapping = {
-            "parent_name" : self._struct.Name(),
-            "name" : self._name,
-            "ctype" : self._ctype,
-            "refname" : self._refname,
-            "optpointer" : self._optpointer and "*" or "",
-            "optreference" : self._optpointer and "&" or "",
-            "optaddarg" :
-            self._optaddarg and ", const %s value" % self._ctype or ""
-            }
-        for (k, v) in extradict.items():
-            mapping[k] = v
-
-        return mapping
-    
     def SetStruct(self, struct):
         self._struct = struct
 
@@ -375,6 +365,39 @@ class Entry:
     def MakeOptional(self):
         self._optional = 1
 
+    def Verify(self):
+        if self.Array() and not self._can_be_array:
+            print >>sys.stderr, (
+                'Entry "%s" cannot be created as an array '
+                'around line %d' ) % (self._name, self.LineCount())
+            sys.exit(1)
+        if not self._struct:
+            print >>sys.stderr, (
+                'Entry "%s" does not know which struct it belongs to '
+                'around line %d' ) % (self._name, self.LineCount())
+            sys.exit(1)
+        if self._optional and self._array:
+            print >>sys.stderr,  ( 'Entry "%s" has illegal combination of '
+                                   'optional and array around line %d' ) % (
+                self._name, self.LineCount() )
+            sys.exit(1)
+
+    def GetTranslation(self, extradict = {}):
+        mapping = {
+            "parent_name" : self._struct.Name(),
+            "name" : self._name,
+            "ctype" : self._ctype,
+            "refname" : self._refname,
+            "optpointer" : self._optpointer and "*" or "",
+            "optreference" : self._optpointer and "&" or "",
+            "optaddarg" :
+            self._optaddarg and ", const %s value" % self._ctype or ""
+            }
+        for (k, v) in extradict.items():
+            mapping[k] = v
+
+        return mapping
+    
     def GetVarName(self, var):
         return '%(var)s->%(name)s_data' % self.GetTranslation({ 'var' : var })
 
@@ -451,23 +474,6 @@ class Entry:
         code = code % self.GetTranslation()
         return code.split('\n')
 
-    def Verify(self):
-        if self.Array() and not self._can_be_array:
-            print >>sys.stderr, (
-                'Entry "%s" cannot be created as an array '
-                'around line %d' ) % (self._name, self.LineCount())
-            sys.exit(1)
-        if not self._struct:
-            print >>sys.stderr, (
-                'Entry "%s" does not know which struct it belongs to '
-                'around line %d' ) % (self._name, self.LineCount())
-            sys.exit(1)
-        if self._optional and self._array:
-            print >>sys.stderr,  ( 'Entry "%s" has illegal combination of '
-                                   'optional and array around line %d' ) % (
-                self._name, self.LineCount() )
-            sys.exit(1)
-
 class EntryBytes(Entry):
     def __init__(self, type, name, tag, length):
         # Init base class
@@ -1261,7 +1267,7 @@ def NormalizeLine(line):
 
     return line
 
-def ProcessOneEntry(newstruct, entry):
+def ProcessOneEntry(factory, newstruct, entry):
     optional = 0
     array = 0
     entry_type = ''
@@ -1327,19 +1333,19 @@ def ProcessOneEntry(newstruct, entry):
     # Create the right entry
     if entry_type == 'bytes':
         if fixed_length:
-            newentry = EntryBytes(entry_type, name, tag, fixed_length)
+            newentry = factory.EntryBytes(entry_type, name, tag, fixed_length)
         else:
-            newentry = EntryVarBytes(entry_type, name, tag)
+            newentry = factory.EntryVarBytes(entry_type, name, tag)
     elif entry_type == 'int' and not fixed_length:
-        newentry = EntryInt(entry_type, name, tag)
+        newentry = factory.EntryInt(entry_type, name, tag)
     elif entry_type == 'string' and not fixed_length:
-        newentry = EntryString(entry_type, name, tag)
+        newentry = factory.EntryString(entry_type, name, tag)
     else:
         res = re.match(r'^struct\[(%s)\]$' % _STRUCT_RE,
                        entry_type, re.IGNORECASE)
         if res:
             # References another struct defined in our file
-            newentry = EntryStruct(entry_type, name, tag, res.group(1))
+            newentry = factory.EntryStruct(entry_type, name, tag, res.group(1))
         else:
             print >>sys.stderr, 'Bad type: "%s" in "%s"' % (entry_type, entry)
             sys.exit(1)
@@ -1369,11 +1375,11 @@ def ProcessOneEntry(newstruct, entry):
 
     return structs
 
-def ProcessStruct(data):
+def ProcessStruct(factory, data):
     tokens = data.split(' ')
 
     # First three tokens are: 'struct' 'name' '{'
-    newstruct = Struct(tokens[1])
+    newstruct = factory.Struct(tokens[1])
 
     inside = ' '.join(tokens[3:-1])
 
@@ -1387,7 +1393,7 @@ def ProcessStruct(data):
             continue
 
         # It's possible that new structs get defined in here
-        structs.extend(ProcessOneEntry(newstruct, entry))
+        structs.extend(ProcessOneEntry(factory, newstruct, entry))
 
     structs.append(newstruct)
     return structs
@@ -1472,7 +1478,7 @@ def GetNextStruct(file):
     return data
         
 
-def Parse(file):
+def Parse(factory, file):
     """
     Parses the input file and returns C code and corresponding header file.
     """
@@ -1486,93 +1492,114 @@ def Parse(file):
         if not data:
             break
 
-        entities.extend(ProcessStruct(data))
+        entities.extend(ProcessStruct(factory, data))
 
     return entities
 
-def GuardName(name):
-    name = '_'.join(name.split('.'))
-    name = '_'.join(name.split('/'))
-    guard = '_'+name.upper()+'_'
-
-    return guard
-
-def HeaderPreamble(name):
-    guard = GuardName(name)
-    pre = (
-        '/*\n'
-        ' * Automatically generated from %s\n'
-        ' */\n\n'
-        '#ifndef %s\n'
-        '#define %s\n\n' ) % (
-        name, guard, guard)
-
-    # insert stdint.h - let's hope everyone has it
-    pre += (
-        '#include <event-config.h>\n'
-        '#ifdef _EVENT_HAVE_STDINT_H\n'
-        '#include <stdint.h>\n'
-        '#endif\n' )
-
-    for statement in headerdirect:
-        pre += '%s\n' % statement
-    if headerdirect:
-        pre += '\n'
-
-    pre += (
-        '#define EVTAG_HAS(msg, member) ((msg)->member##_set == 1)\n'
-        '#define EVTAG_ASSIGN(msg, member, args...) '
-        '(*(msg)->base->member##_assign)(msg, ## args)\n'
-        '#define EVTAG_GET(msg, member, args...) '
-        '(*(msg)->base->member##_get)(msg, ## args)\n'
-        '#define EVTAG_ADD(msg, member, args...) '
-        '(*(msg)->base->member##_add)(msg, ## args)\n'
-        '#define EVTAG_LEN(msg, member) ((msg)->member##_length)\n'
-        )
-
-    return pre
-     
-
-def HeaderPostamble(name):
-    guard = GuardName(name)
-    return '#endif  /* %s */' % guard
-
-def BodyPreamble(name):
-    global _NAME
-    global _VERSION
+class CCodeGenerator:
+    def __init__(self):
+        pass
     
-    header_file = '.'.join(name.split('.')[:-1]) + '.gen.h'
-
-    pre = ( '/*\n'
+    def GuardName(self, name):
+        name = '_'.join(name.split('.'))
+        name = '_'.join(name.split('/'))
+        guard = '_' + name.upper() + '_'
+    
+        return guard
+    
+    def HeaderPreamble(self, name):
+        guard = self.GuardName(name)
+        pre = (
+            '/*\n'
             ' * Automatically generated from %s\n'
-            ' * by %s/%s.  DO NOT EDIT THIS FILE.\n'
-            ' */\n\n' ) % (name, _NAME, _VERSION)
-    pre += ( '#include <sys/types.h>\n'
-             '#include <sys/time.h>\n'
-             '#include <stdlib.h>\n'
-             '#include <string.h>\n'
-             '#include <assert.h>\n'
-             '#include <event.h>\n\n' )
-
-    for statement in cppdirect:
-        pre += '%s\n' % statement
+            ' */\n\n'
+            '#ifndef %s\n'
+            '#define %s\n\n' ) % (
+            name, guard, guard)
+    
+        # insert stdint.h - let's hope everyone has it
+        pre += (
+            '#include <event-config.h>\n'
+            '#ifdef _EVENT_HAVE_STDINT_H\n'
+            '#include <stdint.h>\n'
+            '#endif\n' )
     
-    pre += '\n#include "%s"\n\n' % header_file
+        for statement in headerdirect:
+            pre += '%s\n' % statement
+        if headerdirect:
+            pre += '\n'
+    
+        pre += (
+            '#define EVTAG_HAS(msg, member) ((msg)->member##_set == 1)\n'
+            '#define EVTAG_ASSIGN(msg, member, args...) '
+            '(*(msg)->base->member##_assign)(msg, ## args)\n'
+            '#define EVTAG_GET(msg, member, args...) '
+            '(*(msg)->base->member##_get)(msg, ## args)\n'
+            '#define EVTAG_ADD(msg, member, args...) '
+            '(*(msg)->base->member##_add)(msg, ## args)\n'
+            '#define EVTAG_LEN(msg, member) ((msg)->member##_length)\n'
+            )
+    
+        return pre
+    
+    def HeaderPostamble(self, name):
+        guard = self.GuardName(name)
+        return '#endif  /* %s */' % guard
+    
+    def BodyPreamble(self, name):
+        global _NAME
+        global _VERSION
+        
+        header_file = '.'.join(name.split('.')[:-1]) + '.gen.h'
+    
+        pre = ( '/*\n'
+                ' * Automatically generated from %s\n'
+                ' * by %s/%s.  DO NOT EDIT THIS FILE.\n'
+                ' */\n\n' ) % (name, _NAME, _VERSION)
+        pre += ( '#include <sys/types.h>\n'
+                 '#include <sys/time.h>\n'
+                 '#include <stdlib.h>\n'
+                 '#include <string.h>\n'
+                 '#include <assert.h>\n'
+                 '#include <event.h>\n\n' )
+    
+        for statement in cppdirect:
+            pre += '%s\n' % statement
+        
+        pre += '\n#include "%s"\n\n' % header_file
+    
+        pre += 'void event_err(int eval, const char *fmt, ...);\n'
+        pre += 'void event_warn(const char *fmt, ...);\n'
+        pre += 'void event_errx(int eval, const char *fmt, ...);\n'
+        pre += 'void event_warnx(const char *fmt, ...);\n\n'
+    
+        return pre
 
-    pre += 'void event_err(int eval, const char *fmt, ...);\n'
-    pre += 'void event_warn(const char *fmt, ...);\n'
-    pre += 'void event_errx(int eval, const char *fmt, ...);\n'
-    pre += 'void event_warnx(const char *fmt, ...);\n\n'
+    def HeaderFilename(self, filename):
+        return '.'.join(filename.split('.')[:-1]) + '.gen.h'        
 
-    return pre
+    def CodeFilename(self, filename):
+        return '.'.join(filename.split('.')[:-1]) + '.gen.c'        
 
-def main(argv):
-    if len(argv) < 2 or not argv[1]:
-        print >>sys.stderr, 'Need RPC description file as first argument.'
-        sys.exit(1)
+    def Struct(self, name):
+        return StructCCode(name)
+
+    def EntryBytes(self, entry_type, name, tag, fixed_length):
+        return EntryBytes(entry_type, name, tag, fixed_length)
+
+    def EntryVarBytes(self, entry_type, name, tag):
+        return EntryVarBytes(entry_type, name, tag)
+
+    def EntryInt(self, entry_type, name, tag):
+        return EntryInt(entry_type, name, tag)
+
+    def EntryString(self, entry_type, name, tag):
+        return EntryString(entry_type, name, tag)
 
-    filename = argv[1]
+    def EntryStruct(self, entry_type, name, tag, struct_name):
+        return EntryStruct(entry_type, name, tag, struct_name)
 
+def Generate(factory, filename):
     ext = filename.split('.')[-1]
     if ext != 'rpc':
         print >>sys.stderr, 'Unrecognized file extension: %s' % ext
@@ -1581,15 +1608,15 @@ def main(argv):
     print >>sys.stderr, 'Reading \"%s\"' % filename
 
     fp = open(filename, 'r')
-    entities = Parse(fp)
+    entities = Parse(factory, fp)
     fp.close()
 
-    header_file = '.'.join(filename.split('.')[:-1]) + '.gen.h'
-    impl_file = '.'.join(filename.split('.')[:-1]) + '.gen.c'
+    header_file = factory.HeaderFilename(filename)
+    impl_file = factory.CodeFilename(filename)
 
     print >>sys.stderr, '... creating "%s"' % header_file
     header_fp = open(header_file, 'w')
-    print >>header_fp, HeaderPreamble(filename)
+    print >>header_fp, factory.HeaderPreamble(filename)
 
     # Create forward declarations: allows other structs to reference
     # each other
@@ -1600,15 +1627,22 @@ def main(argv):
     for entry in entities:
         entry.PrintTags(header_fp)
         entry.PrintDeclaration(header_fp)
-    print >>header_fp, HeaderPostamble(filename)
+    print >>header_fp, factory.HeaderPostamble(filename)
     header_fp.close()
 
     print >>sys.stderr, '... creating "%s"' % impl_file
     impl_fp = open(impl_file, 'w')
-    print >>impl_fp, BodyPreamble(filename)
+    print >>impl_fp, factory.BodyPreamble(filename)
     for entry in entities:
         entry.PrintCode(impl_fp)
     impl_fp.close()
+    
+def main(argv):
+    if len(argv) < 2 or not argv[1]:
+        print >>sys.stderr, 'Need RPC description file as first argument.'
+        sys.exit(1)
+
+    Generate(CCodeGenerator(), argv[1])
 
 if __name__ == '__main__':
     main(sys.argv)