]> granicus.if.org Git - postgresql/commitdiff
Replace the plain-memory ordered array with a binary tree in the ts_stat() function.
author Teodor Sigaev <teodor@sigaev.ru>
Mon, 17 Nov 2008 12:17:09 +0000 (12:17 +0000)
committer Teodor Sigaev <teodor@sigaev.ru>
Mon, 17 Nov 2008 12:17:09 +0000 (12:17 +0000)
Performance improves by anywhere from 50% to 10^3 times, depending on the data.

src/backend/utils/adt/tsvector_op.c

index bc342839d99d174d90bd32ee31d0a35391818da1..5cb4f4f1d9db5c9182a35b5145b32912d69585f3 100644 (file)
@@ -7,7 +7,7 @@
  *
  *
  * IDENTIFICATION
- *       $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.17 2008/11/10 21:49:16 alvherre Exp $
+ *       $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.18 2008/11/17 12:17:09 teodor Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -34,34 +34,33 @@ typedef struct
        char       *operand;
 } CHKVAL;
 
-typedef struct
-{
-       uint32          cur;
-       TSVector        stat;
-} StatStorage;
 
-typedef struct
+typedef struct StatEntry
 {
-       uint32          len;
-       uint32          pos;
-       uint32          ndoc;
+       uint32          ndoc; /* starts at 1 and grows by one per matching insert;
+                                                reset to zero once the tree walk has returned the entry */
        uint32          nentry;
+       struct StatEntry *left;         /* child followed when compareStatWord() < 0 */
+       struct StatEntry *right;        /* child followed when compareStatWord() > 0 */
+       uint32          lenlexeme;      /* number of bytes stored in lexeme[] */
+       char            lexeme[1];      /* lexeme bytes, no terminating NUL; allocated lenlexeme long */
 } StatEntry;
 
+#define STATENTRYHDRSZ (offsetof(StatEntry, lexeme))
+
 typedef struct
 {
-       int32           vl_len_;                /* varlena header (do not touch directly!) */
-       int4            size;
        int4            weight;
-       char            data[1];
-} tsstat;
 
-#define STATHDRSIZE (sizeof(int4) * 4)
-#define CALCSTATSIZE(x, lenstr) ( (x) * sizeof(StatEntry) + STATHDRSIZE + (lenstr) )
-#define STATPTR(x)     ( (StatEntry*) ( (char*)(x) + STATHDRSIZE ) )
-#define STATSTRPTR(x)  ( (char*)(x) + STATHDRSIZE + ( sizeof(StatEntry) * ((TSVector)(x))->size ) )
-#define STATSTRSIZE(x) ( VARSIZE((TSVector)(x)) - STATHDRSIZE - ( sizeof(StatEntry) * ((TSVector)(x))->size ) )
+       uint32          maxdepth;
+       
+       StatEntry       **stack;
+       uint32          stackpos;
 
+       StatEntry*      root;
+} TSVectorStat;
+
+#define STATHDRSIZE (offsetof(TSVectorStat, data))
 
 static Datum tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column);
 
@@ -801,92 +800,95 @@ check_weight(TSVector txt, WordEntry *wptr, int8 weight)
        return num;
 }
 
-#define compareStatWord(a,e,s,t) \
-       tsCompareString(STATSTRPTR(s) + (a)->pos, (a)->len,     \
+#define compareStatWord(a,e,t)                                                         \
+       tsCompareString((a)->lexeme, (a)->lenlexeme,            \
                                        STRPTR(t) + (e)->pos, (e)->len,         \
                                        false)
 
-typedef struct WordEntryMark
+static void
+insertStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, uint32 off)
 {
-       WordEntry       *newentry;
-       StatEntry       *pos;
-} WordEntryMark;
+       WordEntry       *we = ARRPTR(txt) + off;
+       StatEntry       *node = stat->root, 
+                               *pnode=NULL;
+       int                     n,
+                               res;
+       uint32          depth=1;
+
+       if (stat->weight == 0) 
+               n = (we->haspos) ? POSDATALEN(txt, we) : 1;
+       else
+               n = (we->haspos) ? check_weight(txt, we, stat->weight) : 0;
 
-static tsstat *
-formstat(tsstat *stat, TSVector txt, List *entries)
-{
-       tsstat             *newstat;
-       uint32                  totallen,
-                                       nentry,
-                                       len = list_length(entries);
-       uint32                  slen = 0;
-       WordEntry          *ptr;
-       char               *curptr;
-       StatEntry          *sptr,
-                                  *nptr;
-       ListCell           *entry;
-       StatEntry          *PosSE = STATPTR(stat),
-                                  *prevPosSE;
-       WordEntryMark  *mark;
-
-       foreach( entry, entries )
-       {
-               mark = (WordEntryMark*)lfirst(entry);
-               slen += mark->newentry->len;
-       }
+       if ( n == 0 )
+               return; /* nothing to insert */
 
-       nentry = stat->size + len;
-       slen += STATSTRSIZE(stat);
-       totallen = CALCSTATSIZE(nentry, slen);
-       newstat = palloc(totallen);
-       SET_VARSIZE(newstat, totallen);
-       newstat->weight = stat->weight;
-       newstat->size = nentry;
+       while( node ) 
+       {
+               res = compareStatWord(node, we, txt);
 
-       memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat));
-       curptr = STATSTRPTR(newstat) + STATSTRSIZE(stat);
+               if (res == 0)
+               {
+                       break;
+               }
+               else
+               {
+                       pnode = node;
+                       node = ( res < 0 ) ? node->left : node->right;
+               }
+               depth++;
+       }
 
-       sptr = STATPTR(stat);
-       nptr = STATPTR(newstat);
+       if (depth > stat->maxdepth)
+               stat->maxdepth = depth;
 
-       foreach(entry, entries)
+       if (node == NULL)
        {
-               prevPosSE = PosSE;
-
-               mark = (WordEntryMark*)lfirst(entry);
-               ptr  = mark->newentry;
-               PosSE = mark->pos;
-
-               /*
-                * Copy missed entries 
-                */
-               if ( PosSE > prevPosSE )
+               node = MemoryContextAlloc(persistentContext, STATENTRYHDRSZ + we->len );
+               node->left = node->right = NULL;
+               node->ndoc = 1;
+               node->nentry = n;
+               node->lenlexeme = we->len;
+               memcpy(node->lexeme, STRPTR(txt) + we->pos, node->lenlexeme);
+
+               if ( pnode==NULL )
                {
-                       memcpy( nptr, prevPosSE, sizeof(StatEntry) * (PosSE-prevPosSE) );
-                       nptr += PosSE-prevPosSE;
+                       stat->root = node;
                }
-
-               /*
-                * Copy new entry
-                */
-               if (ptr->haspos)
-                       nptr->nentry = (stat->weight) ? check_weight(txt, ptr, stat->weight) : POSDATALEN(txt, ptr);
                else
-                       nptr->nentry = 1;
-               nptr->ndoc = 1;
-               nptr->len = ptr->len;
-               memcpy(curptr, STRPTR(txt) + ptr->pos, nptr->len);
-               nptr->pos = curptr - STATSTRPTR(newstat);
-               curptr += nptr->len;
-               nptr++;
-
-               pfree(mark);
+               {
+                       if (res < 0)
+                               pnode->left = node;
+                       else
+                               pnode->right = node;
+               }
+                       
        }
+       else
+       {
+               node->ndoc++;
+               node->nentry += n;
+       }
+}
 
-       if ( PosSE < (StatEntry *) STATSTRPTR(stat) )
-               memcpy(nptr, PosSE, sizeof(StatEntry) * (stat->size - (PosSE - STATPTR(stat))));
-
-       return newstat;
+static void    /* recursively hands entries of txt to insertStatEntry(), midpoint of each half-range first */
+chooseNextStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, 
+                       uint32 low, uint32 high, uint32 offset)
+{
+       uint32      pos;
+       uint32      middle = (low + high) >> 1;     /* splits the range into low..middle and middle+1..high */
+       /* pos is an index in the caller's range; pos - offset is the txt index, skipped when out of bounds */
+       pos = (low + middle) >> 1;      /* midpoint of the lower half */
+       if (low != middle && pos >= offset && pos - offset < txt->size)
+               insertStatEntry( persistentContext, stat, txt, pos - offset );
+       pos = (high + middle + 1) >> 1; /* midpoint of the upper half */
+       if (middle + 1 != high && pos >= offset && pos - offset < txt->size)
+               insertStatEntry( persistentContext, stat, txt, pos - offset );
+       /* descend into each half that can still be split (presumably this order keeps the tree balanced) */
+       if (low != middle)
+               chooseNextStatEntry(persistentContext, stat, txt, low, middle, offset);
+       if (high != middle + 1)
+               chooseNextStatEntry(persistentContext, stat, txt, middle + 1, high, offset);
 }
 
 /*
@@ -901,115 +903,69 @@ formstat(tsstat *stat, TSVector txt, List *entries)
  *     where vector_column is a tsvector-type column in vector_table.
  */
 
-static tsstat *
-ts_accum(tsstat *stat, Datum data)
+static TSVectorStat *
+ts_accum(MemoryContext persistentContext, TSVectorStat *stat, Datum data)
 {
-       tsstat     *newstat;
-       TSVector        txt = DatumGetTSVector(data);
-       StatEntry  *sptr;
-       WordEntry  *wptr;
-       int                     n = 0;
-       List       *newentries=NIL;
-       StatEntry  *StopLow;
+       TSVector                txt = DatumGetTSVector(data);
+       uint32                  i,
+                                       nbit = 0,
+                                       offset;
 
        if (stat == NULL)
-       {                                                       /* Init in first */
-               stat = palloc(STATHDRSIZE);
-               SET_VARSIZE(stat, STATHDRSIZE);
-               stat->size = 0;
-               stat->weight = 0;
+       {       /* Init in first */
+               stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
+               stat->maxdepth = 1;
        }
 
        /* simple check of correctness */
        if (txt == NULL || txt->size == 0)
        {
-               if (txt != (TSVector) DatumGetPointer(data))
+               if (txt && txt != (TSVector) DatumGetPointer(data))
                        pfree(txt);
                return stat;
        }
 
-       sptr = STATPTR(stat);
-       wptr = ARRPTR(txt);
-       StopLow = STATPTR(stat);
-
-       while (wptr - ARRPTR(txt) < txt->size)
-       {
-               StatEntry  *StopHigh = (StatEntry *) STATSTRPTR(stat);
-               int                     cmp;
-
-               /*
-                * We do not set StopLow to begin of array because tsvector is ordered 
-                * with the sames rule, so we can search from last stopped position
-                */
-
-               while (StopLow < StopHigh)
-               {
-                       sptr = StopLow + (StopHigh - StopLow) / 2;
-                       cmp = compareStatWord(sptr, wptr, stat, txt);
-                       if (cmp == 0)
-                       {
-                               if (stat->weight == 0)
-                               {
-                                       sptr->ndoc++;
-                                       sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
-                               }
-                               else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
-                               {
-                                       sptr->ndoc++;
-                                       sptr->nentry += n;
-                               }
-                               break;
-                       }
-                       else if (cmp < 0)
-                               StopLow = sptr + 1;
-                       else
-                               StopHigh = sptr;
-               }
-
-               if (StopLow >= StopHigh)
-               {                                       /* not found */
-                       if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
-                       {
-                               WordEntryMark *mark = (WordEntryMark*)palloc(sizeof(WordEntryMark));
+       i = txt->size - 1;
+       for (; i > 0; i >>= 1)
+               nbit++;
 
-                               mark->newentry = wptr;
-                               mark->pos = StopLow;
-                               newentries = lappend( newentries, mark );
+       nbit = 1 << nbit;
+       offset = (nbit - txt->size) / 2;
 
-                       }
-               }
-               wptr++;
-       }
+       insertStatEntry( persistentContext, stat, txt, (nbit >> 1) - offset );
+       chooseNextStatEntry(persistentContext, stat, txt, 0, nbit, offset);
 
-       if (list_length(newentries) == 0)
-       {                                                       /* no new words */
-               if (txt != (TSVector) DatumGetPointer(data))
-                       pfree(txt);
-               return stat;
-       }
-
-       newstat = formstat(stat, txt, newentries);
-       list_free(newentries);
-
-       if (txt != (TSVector) DatumGetPointer(data))
-               pfree(txt);
-       return newstat;
+       return stat;
 }
 
 static void
 ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
-                                  tsstat *stat)
+                                  TSVectorStat *stat)
 {
-       TupleDesc       tupdesc;
-       MemoryContext oldcontext;
-       StatStorage *st;
+       TupleDesc               tupdesc;
+       MemoryContext   oldcontext;
+       StatEntry               *node;
+
+       funcctx->user_fctx = (void *) stat;
 
        oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
-       st = palloc(sizeof(StatStorage));
-       st->cur = 0;
-       st->stat = palloc(VARSIZE(stat));
-       memcpy(st->stat, stat, VARSIZE(stat));
-       funcctx->user_fctx = (void *) st;
+
+       stat->stack = palloc0(sizeof(StatEntry *) * (stat->maxdepth + 1));
+       stat->stackpos = 0; 
+
+       node = stat->root;
+       /* find leftmost value */
+       for (;;)
+       {
+               stat->stack[ stat->stackpos ] = node;
+               if (node->left)
+               {
+                       stat->stackpos++;
+                       node = node->left;
+               }
+               else
+                       break;
+       }
 
        tupdesc = CreateTemplateTupleDesc(3, false);
        TupleDescInitEntry(tupdesc, (AttrNumber) 1, "word",
@@ -1024,26 +980,72 @@ ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
        MemoryContextSwitchTo(oldcontext);
 }
 
+static StatEntry *     /* next entry of the stack-driven tree walk, or NULL when nothing is left */
+walkStatEntryTree(TSVectorStat *stat) 
+{
+       StatEntry       *node = stat->stack[ stat->stackpos ];
+       /* an empty stack slot means there is nothing to return */
+       if ( node == NULL )
+               return NULL;
+       /* ndoc is zeroed by the caller after an entry is output, so nonzero means "not returned yet" */
+       if ( node->ndoc != 0 )
+       {
+               /* the left subtree has already been walked, so return this entry itself */
+               return node;
+       }
+       else if (node->right && node->right != stat->stack[stat->stackpos + 1])
+       {
+               /* slot above does not hold the right child, so it is unvisited: descend into it */
+               stat->stackpos++;
+               node = node->right;
+
+               /* push the path down to the leftmost entry of that subtree */
+               for (;;)
+               {
+                       stat->stack[stat->stackpos] = node;
+                       if (node->left)
+                       {
+                               stat->stackpos++;
+                               node = node->left;
+                       }
+                       else
+                               break;
+               }
+       }
+       else
+       {
+               /* left subtree, this entry and right subtree are all done: pop to the parent */
+               if (stat->stackpos == 0)
+                       return NULL;
+
+               stat->stackpos--;
+               return walkStatEntryTree(stat);
+       }
+       /* reached only from the right-subtree branch: node is that subtree's leftmost entry */
+       return node;
+}
 
 static Datum
 ts_process_call(FuncCallContext *funcctx)
 {
-       StatStorage *st;
+       TSVectorStat    *st;
+       StatEntry               *entry;
+
+       st = (TSVectorStat *) funcctx->user_fctx;
 
-       st = (StatStorage *) funcctx->user_fctx;
+       entry = walkStatEntryTree(st);
 
-       if (st->cur < st->stat->size)
+       if (entry != NULL)
        {
                Datum           result;
                char       *values[3];
                char            ndoc[16];
                char            nentry[16];
-               StatEntry  *entry = STATPTR(st->stat) + st->cur;
                HeapTuple       tuple;
 
-               values[0] = palloc(entry->len + 1);
-               memcpy(values[0], STATSTRPTR(st->stat) + entry->pos, entry->len);
-               (values[0])[entry->len] = '\0';
+               values[0] = palloc(entry->lenlexeme + 1);
+               memcpy(values[0], entry->lexeme, entry->lenlexeme);
+               (values[0])[entry->lenlexeme] = '\0';
                sprintf(ndoc, "%d", entry->ndoc);
                values[1] = ndoc;
                sprintf(nentry, "%d", entry->nentry);
@@ -1053,25 +1055,22 @@ ts_process_call(FuncCallContext *funcctx)
                result = HeapTupleGetDatum(tuple);
 
                pfree(values[0]);
-               st->cur++;
+
+               /* mark entry as already visited */
+               entry->ndoc = 0;
+
                return result;
        }
-       else
-       {
-               pfree(st->stat);
-               pfree(st);
-       }
 
        return (Datum) 0;
 }
 
-static tsstat *
-ts_stat_sql(text *txt, text *ws)
+static TSVectorStat *
+ts_stat_sql(MemoryContext persistentContext, text *txt, text *ws)
 {
        char       *query = text_to_cstring(txt);
        int                     i;
-       tsstat     *newstat,
-                          *stat;
+       TSVectorStat *stat;
        bool            isnull;
        Portal          portal;
        SPIPlanPtr      plan;
@@ -1094,10 +1093,8 @@ ts_stat_sql(text *txt, text *ws)
                                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                                 errmsg("ts_stat query must return one tsvector column")));
 
-       stat = palloc(STATHDRSIZE);
-       SET_VARSIZE(stat, STATHDRSIZE);
-       stat->size = 0;
-       stat->weight = 0;
+       stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
+       stat->maxdepth = 1;
 
        if (ws)
        {
@@ -1141,12 +1138,7 @@ ts_stat_sql(text *txt, text *ws)
                        Datum           data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
 
                        if (!isnull)
-                       {
-                               newstat = ts_accum(stat, data);
-                               if (stat != newstat && stat)
-                                       pfree(stat);
-                               stat = newstat;
-                       }
+                               stat = ts_accum(persistentContext, stat, data);
                }
 
                SPI_freetuptable(SPI_tuptable);
@@ -1169,12 +1161,12 @@ ts_stat1(PG_FUNCTION_ARGS)
 
        if (SRF_IS_FIRSTCALL())
        {
-               tsstat     *stat;
+               TSVectorStat       *stat;
                text       *txt = PG_GETARG_TEXT_P(0);
 
                funcctx = SRF_FIRSTCALL_INIT();
                SPI_connect();
-               stat = ts_stat_sql(txt, NULL);
+               stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, NULL);
                PG_FREE_IF_COPY(txt, 0);
                ts_setup_firstcall(fcinfo, funcctx, stat);
                SPI_finish();
@@ -1194,13 +1186,13 @@ ts_stat2(PG_FUNCTION_ARGS)
 
        if (SRF_IS_FIRSTCALL())
        {
-               tsstat     *stat;
+               TSVectorStat       *stat;
                text       *txt = PG_GETARG_TEXT_P(0);
                text       *ws = PG_GETARG_TEXT_P(1);
 
                funcctx = SRF_FIRSTCALL_INIT();
                SPI_connect();
-               stat = ts_stat_sql(txt, ws);
+               stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, ws);
                PG_FREE_IF_COPY(txt, 0);
                PG_FREE_IF_COPY(ws, 1);
                ts_setup_firstcall(fcinfo, funcctx, stat);