]> granicus.if.org Git - libevent/commitdiff
support for arrays on structs.
authorNiels Provos <provos@gmail.com>
Sun, 28 Aug 2005 23:48:16 +0000 (23:48 +0000)
committerNiels Provos <provos@gmail.com>
Sun, 28 Aug 2005 23:48:16 +0000 (23:48 +0000)
svn:r178

event_rpcgen.py
test/regress.c

index 7af014f5ca58ff0515980f33b0910f886cddb01f..50235d6cd653bf81bebe3bfe165f3af1812c527c 100755 (executable)
@@ -45,6 +45,8 @@ class Struct:
         return self._name
 
     def EntryTagName(self, entry):
+        """Creates the name inside an enumeration for distinguishing data
+        types."""
         name = "%s_%s" % (self._name, entry.Name())
         return name.upper()
 
@@ -54,6 +56,7 @@ class Struct:
             print >>file, '%s%s' % (ident, entry)
 
     def PrintTags(self, file):
+        """Prints the tag definitions for a structure."""
         print >>file, '/* Tag definition for %s */' % self._name
         print >>file, 'enum {'
         for entry in self._entries:
@@ -74,6 +77,9 @@ class Struct:
                 entry.AssignDeclaration('(*%s_assign)' % entry.Name()))
             dcl.extend(
                 entry.GetDeclaration('(*%s_get)' % entry.Name()))
+            if entry.Array():
+                dcl.extend(
+                    entry.AddDeclaration('(*%s_add)' % entry.Name()))
             self.PrintIdented(file, '  ', dcl)
         print >>file, ''
         for entry in self._entries:
@@ -101,6 +107,9 @@ class Struct:
                 entry.AssignFuncName()))
             self.PrintIdented(file, '', entry.GetDeclaration(
                 entry.GetFuncName()))
+            if entry.Array():
+                self.PrintIdented(file, '', entry.AddDeclaration(
+                    entry.AddFuncName()))
 
         print >>file, '/* --- %s done --- */\n' % self._name
 
@@ -126,6 +135,12 @@ class Struct:
         print >>file, ('  return (tmp);\n'
                        '}\n')
 
+        # Adding
+        for entry in self._entries:
+            if entry.Array():
+                self.PrintIdented(file, '', entry.CodeAdd())
+            print >>file, ''
+            
         # Assigning
         for entry in self._entries:
             self.PrintIdented(file, '', entry.CodeAssign())
@@ -191,10 +206,13 @@ class Struct:
                        '    switch (tag) {\n'
                        )
         for entry in self._entries:
-            print >>file, ('      case %s:\n' % self.EntryTagName(entry) +
-                           '        if (tmp->%s_set)\n'
-                           '          return (-1);'
-                           ) % (entry.Name())
+            print >>file, '      case %s:\n' % self.EntryTagName(entry)
+            if not entry.Array():
+                print >>file, (
+                    '        if (tmp->%s_set)\n'
+                    '          return (-1);'
+                    ) % (entry.Name())
+
             self.PrintIdented(
                 file, '        ',
                 entry.CodeUnmarshal('evbuf',
@@ -276,6 +294,7 @@ class Entry:
         self._tag = int(tag)
         self._ctype = type
         self._optional = 0
+        self._can_be_array = 0
         self._array = 0
         self._line_count = -1
         self._struct = None
@@ -290,6 +309,9 @@ class Entry:
     def SetLineCount(self, number):
         self._line_count = number
 
+    def Array(self):
+        return self._array
+
     def Optional(self):
         return self._optional
 
@@ -302,8 +324,8 @@ class Entry:
     def Type(self):
         return self._type
 
-    def MakeArray(self):
-        self._array = 1
+    def MakeArray(self, yes=1):
+        self._array = yes
         
     def MakeOptional(self):
         self._optional = 1
@@ -332,6 +354,9 @@ class Entry:
     def AssignFuncName(self):
         return '%s_%s_assign' % (self._struct.Name(), self._name)
     
+    def AddFuncName(self):
+        return '%s_%s_add' % (self._struct.Name(), self._name)
+    
     def AssignDeclaration(self, funcname):
         code = [ 'int %s(struct %s *, const %s);' % (
             funcname, self._struct.Name(), self._ctype ) ]
@@ -372,13 +397,23 @@ class Entry:
                  '%s->%s_get = %s_%s_get;' % (
             name, self._name, self._struct.Name(), self._name ),
         ]
+        if self.Array():
+            code.append(
+                '%s->%s_add = %s_%s_add;' % (
+                name, self._name, self._struct.Name(), self._name ) )
         return code
 
     def Verify(self):
+        if self.Array() and not self._can_be_array:
+            print >>sys.stderr, (
+                'Entry "%s" cannot be created as an array '
+                'around line %d' ) % (self._name, self.LineCount())
+            sys.exit(1)
         if not self._struct:
             print >>sys.stderr, (
                 'Entry "%s" does not know which struct it belongs to '
                 'around line %d' ) % (self._name, self.LineCount())
+            sys.exit(1)
         if self._optional and self._array:
             print >>sys.stderr,  ( 'Entry "%s" has illegal combination of '
                                    'optional and array around line %d' ) % (
@@ -569,6 +604,7 @@ class EntryStruct(Entry):
         # Init base class
         Entry.__init__(self, type, name, tag)
 
+        self._can_be_array = 1
         self._refname = refname
         self._ctype = 'struct %s' % refname
 
@@ -807,6 +843,199 @@ class EntryVarBytes(Entry):
 
         return dcl
 
+class EntryArray(Entry):
+    def __init__(self, entry):
+        # Init base class
+        Entry.__init__(self, entry._type, entry._name, entry._tag)
+
+        self._entry = entry
+        self._refname = entry._refname
+        self._ctype = 'struct %s' % self._refname
+
+    def GetDeclaration(self, funcname):
+        """Allows direct access to elements of the array."""
+        code = [ 'int %s(struct %s *, int, %s **);' % (
+            funcname, self._struct.Name(), self._ctype ) ]
+        return code
+        
+    def AssignDeclaration(self, funcname):
+        code = [ 'int %s(struct %s *, int, const %s *);' % (
+            funcname, self._struct.Name(), self._ctype ) ]
+        return code
+        
+    def AddDeclaration(self, funcname):
+        code = [ '%s *%s(struct %s *);' % (
+            self._ctype, funcname, self._struct.Name() ) ]
+        return code
+        
+    def CodeGet(self):
+        name = self._name
+        code = [ 'int',
+                 '%s_%s_get(struct %s *msg, int offset, %s **value)' % (
+            self._struct.Name(), name,
+            self._struct.Name(), self._ctype),
+                 '{',
+                 '  if (msg->%s_set != 1)' % name,
+                 '    return (-1);',
+                 '  if (offset >= msg->%s_length)' % name,
+                 '    return (-1);',
+                 '  *value = msg->%s_data[offset];' % name,
+                 '  return (0);',
+                 '}' ]
+        return code
+        
+    def CodeAssign(self):
+        name = self._name
+        code = [ 'int',
+                 '%s_%s_assign(struct %s *msg, int off, const %s *value)' % (
+            self._struct.Name(), name,
+            self._struct.Name(), self._ctype),
+                 '{',
+                 '  struct evbuffer *tmp = NULL;',
+                 '  if (msg->%s_set != 1)' % name,
+                 '    return (-1);',
+                 '  if (off >= msg->%s_length)' % name,
+                 '    return (-1);',
+                 '',
+                 '  %s_clear(msg->%s_data[off]);' % (self._refname, name),
+                 '  if ((tmp = evbuffer_new()) == NULL) {',
+                 '    event_warn("%s: evbuffer_new()", __func__);',
+                 '    goto error;',
+                 '  }',
+                 '  %s_marshal(tmp, value); ' % self._refname,
+                 '  if (%s_unmarshal(msg->%s_data[off], tmp) == -1) {' % (
+            self._refname, name ),
+                 '    event_warnx("%%s: %s_unmarshal", __func__);' % (
+            self._refname),
+                 '    goto error;',
+                 '  }',
+                 '  evbuffer_free(tmp);',
+                 '  return (0);',
+                 ' error:',
+                 '  if (tmp != NULL)',
+                 '    evbuffer_free(tmp);',
+                 '  %s_clear(msg->%s_data[off]);' % (self._refname, name),
+                 '  return (-1);',
+                 '}' ]
+        return code
+        
+    def CodeAdd(self):
+        name = self._name
+        code = [
+            '%s *' % self._ctype, 
+            '%s_%s_add(struct %s *msg)' % (
+            self._struct.Name(), name, self._struct.Name()),
+            '{',
+            '  msg->%s_length++;' % name,
+            '  msg->%s_data = (struct %s**)realloc(msg->%s_data, '
+            '  msg->%s_length * sizeof(struct %s*));' % (
+            name, self._refname, name, name, self._refname ),
+            '  if (msg->%s_data == NULL)' % name,
+            '    return (NULL);',
+            '  msg->%s_data[msg->%s_length - 1] = %s_new();' % (
+            name, name, self._refname),
+            '  if (msg->%s_data[msg->%s_length - 1] == NULL) {' % (name, name),
+            '    msg->%s_length--; ' % name,
+            '    return (NULL);',
+            '  }',
+            '  msg->%s_set = 1;' % name,
+            '  return (msg->%s_data[msg->%s_length - 1]);' % (name, name),
+            '}'
+            ]
+        return code
+        
+    def CodeComplete(self, structname):
+        code = []
+        if self.Optional():
+            code.append( 'if (%s->%s_set)'  % (structname, self.Name()))
+
+        code.extend(['{',
+                     '  int i;',
+                     '  for (i = 0; i < %s->%s_length; ++i) {' % (
+                structname, self.Name()),
+                     '    if (%s_complete(%s->%s_data[i]) == -1)' % (
+                self._refname, structname, self.Name()),
+                     '      return (-1);',
+                     '  }',
+                     '}'
+                     ])
+
+        return code
+    
+    def CodeUnmarshal(self, buf, tag_name, var_name):
+        code = ['if (%s_%s_add(%s) == NULL)' % (
+            self._struct.Name(), self._name, var_name),
+                '  return (-1);',
+                'if (evtag_unmarshal_%s(%s, %s, '
+                '%s->%s_data[%s->%s_length - 1]) == -1) {' % (
+            self._refname, buf, tag_name, var_name, self._name,
+            var_name, self._name),
+                '  %s->%s_length--; ' % (var_name, self._name),
+                '  event_warnx("%%s: failed to unmarshal %s", __func__);' % (
+            self._name ),
+                '  return (-1);',
+                '}'
+                ]
+        return code
+
+    def CodeMarshal(self, buf, tag_name, var_name):
+        code = ['{',
+                '  int i;',
+                '  for (i = 0; i < %s->%s_length; ++i) {' % (
+            var_name, self._name),
+                '    evtag_marshal_%s(%s, %s, %s->%s_data[i]);' % (
+            self._refname, buf, tag_name, var_name, self._name),
+                '  }',
+                '}'
+                ]
+        return code
+
+    def CodeClear(self, structname):
+        code = [ 'if (%s->%s_set == 1) {' % (structname, self.Name()),
+                 '  int i;',
+                 '  for (i = 0; i < %s->%s_length; ++i) {' % (
+            structname, self.Name()),
+                 '    %s_free(%s->%s_data[i]);' % (
+            self._refname, structname, self.Name()),
+                 '  }',
+                 '  free(%s->%s_data);' % (structname, self.Name()),
+                 '  %s->%s_data = NULL;' % (structname, self.Name()),
+                 '  %s->%s_set = 0;' % (structname, self.Name()),
+                 '  %s->%s_length = 0;' % (structname, self.Name()),
+                 '}'
+                 ]
+
+        return code
+        
+    def CodeNew(self, name):
+        code  = ['%s->%s_data = NULL;' % (name, self._name),
+                 '%s->%s_length = 0;' % (name, self._name)]
+        code.extend(Entry.CodeNew(self, name))
+        return code
+
+    def CodeFree(self, name):
+        code  = ['if (%s->%s_data != NULL) {' % (name, self._name),
+                 '  int i;',
+                 '  for (i = 0; i < %s->%s_length; ++i) {' % (
+            name, self._name),
+                 '    %s_free(%s->%s_data[i]); ' % (
+            self._refname, name, self._name),
+                 '    %s->%s_data[i] = NULL;' % (name, self._name),
+                 '  }',
+                 '  free(%s->%s_data);' % (name, self._name),
+                 '  %s->%s_data = NULL;' % (name, self._name),
+                 '  %s->%s_length = 0;' % (name, self._name),
+                 '}'
+                 ]
+
+        return code
+
+    def Declaration(self):
+        dcl  = ['struct %s **%s_data;' % (self._refname, self._name),
+                'int %s_length;' % self._name]
+
+        return dcl
+
 def NormalizeLine(line):
     global leading
     global trailing
@@ -901,6 +1130,8 @@ def ProcessOneEntry(newstruct, entry):
         else:
             print >>sys.stderr, 'Bad type: "%s" in "%s"' % (type, entry)
             sys.exit(1)
+
+    structs = []
         
     if optional:
         newentry.MakeOptional()
@@ -911,8 +1142,20 @@ def ProcessOneEntry(newstruct, entry):
     newentry.SetLineCount(line_count)
     newentry.Verify()
 
+    if array:
+        # We need to encapsulate this entry into a struct
+        newname = newentry.Name()+ '_array'
+
+        # Now borgify the new entry.
+        newentry = EntryArray(newentry)
+        newentry.SetStruct(newstruct)
+        newentry.SetLineCount(line_count)
+        newentry.MakeArray()
+
     newstruct.AddEntry(newentry)
 
+    return structs
+
 def ProcessStruct(data):
     tokens = data.split(' ')
 
@@ -923,14 +1166,18 @@ def ProcessStruct(data):
 
     tokens = inside.split(';')
 
+    structs = []
+
     for entry in tokens:
         entry = NormalizeLine(entry)
         if not entry:
             continue
-        
-        ProcessOneEntry(newstruct, entry)
 
-    return newstruct
+        # It's possible that new structs get defined in here
+        structs.extend(ProcessOneEntry(newstruct, entry))
+
+    structs.append(newstruct)
+    return structs
 
 def GetNextStruct(file):
     global line_count
@@ -1017,7 +1264,7 @@ def Parse(file):
         if not data:
             break
 
-        entities.append(ProcessStruct(data))
+        entities.extend(ProcessStruct(data))
 
     return entities
 
@@ -1038,11 +1285,13 @@ def HeaderPreamble(name):
         '#define %s\n\n' ) % (
         name, guard, guard)
     pre += (
-        '#define EVTAG_HAS(msg, member) ((msg)->member ## _set == 1)\n'
+        '#define EVTAG_HAS(msg, member) ((msg)->member##_set == 1)\n'
         '#define EVTAG_ASSIGN(msg, member, args...) '
         '(*(msg)->member##_assign)(msg, ## args)\n'
         '#define EVTAG_GET(msg, member, args...) '
         '(*(msg)->member##_get)(msg, ## args)\n'
+        '#define EVTAG_ADD(msg, member) (*(msg)->member##_add)(msg)\n'
+        '#define EVTAG_LEN(msg, member) ((msg)->member##_length)\n'
         )
 
     return pre
@@ -1066,6 +1315,7 @@ def BodyPreamble(name):
              '#include <sys/time.h>\n'
              '#include <stdlib.h>\n'
              '#include <string.h>\n'
+             '#include <assert.h>\n'
              '#include <event.h>\n\n' )
 
     for include in cppdirect:
index 52b7e6f31188e5ab0fec202e199256a4adcadb7f..026597cba90f778f6095034739b1aaba7282961b 100644 (file)
@@ -738,7 +738,9 @@ rpc_test(void)
 {
        struct msg *msg, *msg2;
        struct kill *kill;
+       struct run *run;
        struct evbuffer *tmp = evbuffer_new();
+       int i;
 
        fprintf(stdout, "Testing RPC: ");
 
@@ -754,6 +756,15 @@ rpc_test(void)
        EVTAG_ASSIGN(kill, weapon, "feather");
        EVTAG_ASSIGN(kill, action, "tickle");
 
+       for (i = 0; i < 3; ++i) {
+               run = EVTAG_ADD(msg, run);
+               if (run == NULL) {
+                       fprintf(stderr, "Failed to add run message.\n");
+                       exit(1);
+               }
+               EVTAG_ASSIGN(run, how, "very fast");
+       }
+
        if (msg_complete(msg) == -1) {
                fprintf(stderr, "Failed to make complete message.\n");
                exit(1);
@@ -774,6 +785,11 @@ rpc_test(void)
                exit(1);
        }
 
+       if (EVTAG_LEN(msg2, run) != 3) {
+               fprintf(stderr, "Wrong number of run messages.\n");
+               exit(1);
+       }
+
        msg_free(msg);
        msg_free(msg2);