]> granicus.if.org Git - postgresql/blob - src/backend/utils/adt/tsvector_op.c
Fix encoding issue when lc_monetary or lc_numeric are different encoding
[postgresql] / src / backend / utils / adt / tsvector_op.c
1 /*-------------------------------------------------------------------------
2  *
3  * tsvector_op.c
4  *        operations over tsvector
5  *
6  * Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group
7  *
8  *
9  * IDENTIFICATION
10  *        $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.26 2010/01/02 16:57:55 momjian Exp $
11  *
12  *-------------------------------------------------------------------------
13  */
14
15 #include "postgres.h"
16
17 #include "catalog/namespace.h"
18 #include "catalog/pg_type.h"
19 #include "commands/trigger.h"
20 #include "executor/spi.h"
21 #include "funcapi.h"
22 #include "mb/pg_wchar.h"
23 #include "miscadmin.h"
24 #include "tsearch/ts_type.h"
25 #include "tsearch/ts_utils.h"
26 #include "utils/builtins.h"
27 #include "utils/lsyscache.h"
28
29
/*
 * Search state handed to checkcondition_str() through the opaque "checkval"
 * argument of TS_execute(): the tsvector being searched plus the operand
 * storage of the query being evaluated.
 */
typedef struct
{
	WordEntry  *arrb;			/* first WordEntry of the tsvector */
	WordEntry  *arre;			/* one past the last WordEntry */
	char	   *values;			/* tsvector's string area, i.e. STRPTR() */
	char	   *operand;		/* query's operand area, i.e. GETOPERAND() */
} CHKVAL;
37
38
/*
 * One node of the binary tree of lexemes built by insertStatEntry().
 *
 * Nodes are allocated with STATENTRYHDRSZ + <lexeme length> bytes, so
 * "lexeme" really holds lenlexeme bytes; no terminating NUL is stored.
 */
typedef struct StatEntry
{
	uint32		ndoc;			/* zero indicates that we have already been
								 * here while walking through the tree */
	uint32		nentry;			/* running total of the per-insert counts */
	struct StatEntry *left;
	struct StatEntry *right;
	uint32		lenlexeme;		/* number of bytes stored in lexeme[] */
	char		lexeme[1];		/* variable length, see above */
} StatEntry;

/* Size of a StatEntry up to, but not including, the lexeme bytes */
#define STATENTRYHDRSZ	(offsetof(StatEntry, lexeme))
51
/*
 * Working state for collecting lexeme statistics into a StatEntry tree.
 */
typedef struct
{
	int4		weight;			/* weight bitmask tested by check_weight();
								 * zero means "count every position" */

	uint32		maxdepth;		/* deepest level reached by any insertion */

	/*
	 * NOTE(review): stack/stackpos are not referenced in the part of the
	 * file shown here; presumably they support walking the tree -- confirm.
	 */
	StatEntry **stack;
	uint32		stackpos;

	StatEntry  *root;			/* root of the lexeme tree, NULL when empty */
} TSVectorStat;

/*
 * NOTE(review): TSVectorStat as declared above has no member named "data",
 * so this macro cannot compile if it is ever expanded.  It looks like a
 * leftover from an earlier layout -- confirm it is unused before relying
 * on it.
 */
#define STATHDRSIZE (offsetof(TSVectorStat, data))
65
66 static Datum tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column);
67
68
69 /*
70  * Check if datatype is the specified type or equivalent to it.
71  *
72  * Note: we could just do getBaseType() unconditionally, but since that's
73  * a relatively expensive catalog lookup that most users won't need, we
74  * try the straight comparison first.
75  */
76 static bool
77 is_expected_type(Oid typid, Oid expected_type)
78 {
79         if (typid == expected_type)
80                 return true;
81         typid = getBaseType(typid);
82         if (typid == expected_type)
83                 return true;
84         return false;
85 }
86
87 /* Check if datatype is TEXT or binary-equivalent to it */
88 static bool
89 is_text_type(Oid typid)
90 {
91         /* varchar(n) and char(n) are binary-compatible with text */
92         if (typid == TEXTOID || typid == VARCHAROID || typid == BPCHAROID)
93                 return true;
94         /* Allow domains over these types, too */
95         typid = getBaseType(typid);
96         if (typid == TEXTOID || typid == VARCHAROID || typid == BPCHAROID)
97                 return true;
98         return false;
99 }
100
101
102 /*
103  * Order: haspos, len, word, for all positions (pos, weight)
104  */
105 static int
106 silly_cmp_tsvector(const TSVector a, const TSVector b)
107 {
108         if (VARSIZE(a) < VARSIZE(b))
109                 return -1;
110         else if (VARSIZE(a) > VARSIZE(b))
111                 return 1;
112         else if (a->size < b->size)
113                 return -1;
114         else if (a->size > b->size)
115                 return 1;
116         else
117         {
118                 WordEntry  *aptr = ARRPTR(a);
119                 WordEntry  *bptr = ARRPTR(b);
120                 int                     i = 0;
121                 int                     res;
122
123
124                 for (i = 0; i < a->size; i++)
125                 {
126                         if (aptr->haspos != bptr->haspos)
127                         {
128                                 return (aptr->haspos > bptr->haspos) ? -1 : 1;
129                         }
130                         else if ((res = tsCompareString(STRPTR(a) + aptr->pos, aptr->len, STRPTR(b) + bptr->pos, bptr->len, false)) != 0)
131                         {
132                                 return res;
133                         }
134                         else if (aptr->haspos)
135                         {
136                                 WordEntryPos *ap = POSDATAPTR(a, aptr);
137                                 WordEntryPos *bp = POSDATAPTR(b, bptr);
138                                 int                     j;
139
140                                 if (POSDATALEN(a, aptr) != POSDATALEN(b, bptr))
141                                         return (POSDATALEN(a, aptr) > POSDATALEN(b, bptr)) ? -1 : 1;
142
143                                 for (j = 0; j < POSDATALEN(a, aptr); j++)
144                                 {
145                                         if (WEP_GETPOS(*ap) != WEP_GETPOS(*bp))
146                                         {
147                                                 return (WEP_GETPOS(*ap) > WEP_GETPOS(*bp)) ? -1 : 1;
148                                         }
149                                         else if (WEP_GETWEIGHT(*ap) != WEP_GETWEIGHT(*bp))
150                                         {
151                                                 return (WEP_GETWEIGHT(*ap) > WEP_GETWEIGHT(*bp)) ? -1 : 1;
152                                         }
153                                         ap++, bp++;
154                                 }
155                         }
156
157                         aptr++;
158                         bptr++;
159                 }
160         }
161
162         return 0;
163 }
164
/*
 * Generate one SQL-callable comparison function named tsvector_<type>.
 *
 * The generated function fetches both tsvector arguments, compares them
 * with silly_cmp_tsvector(), releases any argument copies, and returns
 * "res action 0" through PG_RETURN_<ret>.
 *
 * The trailing extern declaration exists only to absorb the semicolon that
 * follows each use of the macro below.
 */
#define TSVECTORCMPFUNC( type, action, ret )			\
Datum													\
tsvector_##type(PG_FUNCTION_ARGS)						\
{														\
	TSVector	a = PG_GETARG_TSVECTOR(0);				\
	TSVector	b = PG_GETARG_TSVECTOR(1);				\
	int			res = silly_cmp_tsvector(a, b);			\
	PG_FREE_IF_COPY(a,0);								\
	PG_FREE_IF_COPY(b,1);								\
	PG_RETURN_##ret( res action 0 );					\
}	\
/* keep compiler quiet - no extra ; */					\
extern int no_such_variable

/* boolean comparison operators */
TSVECTORCMPFUNC(lt, <, BOOL);
TSVECTORCMPFUNC(le, <=, BOOL);
TSVECTORCMPFUNC(eq, ==, BOOL);
TSVECTORCMPFUNC(ge, >=, BOOL);
TSVECTORCMPFUNC(gt, >, BOOL);
TSVECTORCMPFUNC(ne, !=, BOOL);
/* three-way comparison: "res + 0" simply hands back the raw result */
TSVECTORCMPFUNC(cmp, +, INT32);
186
187 Datum
188 tsvector_strip(PG_FUNCTION_ARGS)
189 {
190         TSVector        in = PG_GETARG_TSVECTOR(0);
191         TSVector        out;
192         int                     i,
193                                 len = 0;
194         WordEntry  *arrin = ARRPTR(in),
195                            *arrout;
196         char       *cur;
197
198         for (i = 0; i < in->size; i++)
199                 len += arrin[i].len;
200
201         len = CALCDATASIZE(in->size, len);
202         out = (TSVector) palloc0(len);
203         SET_VARSIZE(out, len);
204         out->size = in->size;
205         arrout = ARRPTR(out);
206         cur = STRPTR(out);
207         for (i = 0; i < in->size; i++)
208         {
209                 memcpy(cur, STRPTR(in) + arrin[i].pos, arrin[i].len);
210                 arrout[i].haspos = 0;
211                 arrout[i].len = arrin[i].len;
212                 arrout[i].pos = cur - STRPTR(out);
213                 cur += arrout[i].len;
214         }
215
216         PG_FREE_IF_COPY(in, 0);
217         PG_RETURN_POINTER(out);
218 }
219
220 Datum
221 tsvector_length(PG_FUNCTION_ARGS)
222 {
223         TSVector        in = PG_GETARG_TSVECTOR(0);
224         int4            ret = in->size;
225
226         PG_FREE_IF_COPY(in, 0);
227         PG_RETURN_INT32(ret);
228 }
229
230 Datum
231 tsvector_setweight(PG_FUNCTION_ARGS)
232 {
233         TSVector        in = PG_GETARG_TSVECTOR(0);
234         char            cw = PG_GETARG_CHAR(1);
235         TSVector        out;
236         int                     i,
237                                 j;
238         WordEntry  *entry;
239         WordEntryPos *p;
240         int                     w = 0;
241
242         switch (cw)
243         {
244                 case 'A':
245                 case 'a':
246                         w = 3;
247                         break;
248                 case 'B':
249                 case 'b':
250                         w = 2;
251                         break;
252                 case 'C':
253                 case 'c':
254                         w = 1;
255                         break;
256                 case 'D':
257                 case 'd':
258                         w = 0;
259                         break;
260                 default:
261                         /* internal error */
262                         elog(ERROR, "unrecognized weight: %d", cw);
263         }
264
265         out = (TSVector) palloc(VARSIZE(in));
266         memcpy(out, in, VARSIZE(in));
267         entry = ARRPTR(out);
268         i = out->size;
269         while (i--)
270         {
271                 if ((j = POSDATALEN(out, entry)) != 0)
272                 {
273                         p = POSDATAPTR(out, entry);
274                         while (j--)
275                         {
276                                 WEP_SETWEIGHT(*p, w);
277                                 p++;
278                         }
279                 }
280                 entry++;
281         }
282
283         PG_FREE_IF_COPY(in, 0);
284         PG_RETURN_POINTER(out);
285 }
286
/*
 * Compare the lexemes of two WordEntrys with tsCompareString() (exact
 * match, not prefix match).  "pa" and "pb" are the string areas that the
 * entries' pos offsets are relative to.
 */
#define compareEntry(pa, a, pb, b) \
	tsCompareString((pa) + (a)->pos, (a)->len,	\
					(pb) + (b)->pos, (b)->len,	\
					false)
291
292 /*
293  * Add positions from src to dest after offsetting them by maxpos.
294  * Return the number added (might be less than expected due to overflow)
295  */
296 static int4
297 add_pos(TSVector src, WordEntry *srcptr,
298                 TSVector dest, WordEntry *destptr,
299                 int4 maxpos)
300 {
301         uint16     *clen = &_POSVECPTR(dest, destptr)->npos;
302         int                     i;
303         uint16          slen = POSDATALEN(src, srcptr),
304                                 startlen;
305         WordEntryPos *spos = POSDATAPTR(src, srcptr),
306                            *dpos = POSDATAPTR(dest, destptr);
307
308         if (!destptr->haspos)
309                 *clen = 0;
310
311         startlen = *clen;
312         for (i = 0;
313                  i < slen && *clen < MAXNUMPOS &&
314                  (*clen == 0 || WEP_GETPOS(dpos[*clen - 1]) != MAXENTRYPOS - 1);
315                  i++)
316         {
317                 WEP_SETWEIGHT(dpos[*clen], WEP_GETWEIGHT(spos[i]));
318                 WEP_SETPOS(dpos[*clen], LIMITPOS(WEP_GETPOS(spos[i]) + maxpos));
319                 (*clen)++;
320         }
321
322         if (*clen != startlen)
323                 destptr->haspos = 1;
324         return *clen - startlen;
325 }
326
327
328 Datum
329 tsvector_concat(PG_FUNCTION_ARGS)
330 {
331         TSVector        in1 = PG_GETARG_TSVECTOR(0);
332         TSVector        in2 = PG_GETARG_TSVECTOR(1);
333         TSVector        out;
334         WordEntry  *ptr;
335         WordEntry  *ptr1,
336                            *ptr2;
337         WordEntryPos *p;
338         int                     maxpos = 0,
339                                 i,
340                                 j,
341                                 i1,
342                                 i2,
343                                 dataoff;
344         char       *data,
345                            *data1,
346                            *data2;
347
348         ptr = ARRPTR(in1);
349         i = in1->size;
350         while (i--)
351         {
352                 if ((j = POSDATALEN(in1, ptr)) != 0)
353                 {
354                         p = POSDATAPTR(in1, ptr);
355                         while (j--)
356                         {
357                                 if (WEP_GETPOS(*p) > maxpos)
358                                         maxpos = WEP_GETPOS(*p);
359                                 p++;
360                         }
361                 }
362                 ptr++;
363         }
364
365         ptr1 = ARRPTR(in1);
366         ptr2 = ARRPTR(in2);
367         data1 = STRPTR(in1);
368         data2 = STRPTR(in2);
369         i1 = in1->size;
370         i2 = in2->size;
371         /* conservative estimate of space needed */
372         out = (TSVector) palloc0(VARSIZE(in1) + VARSIZE(in2));
373         SET_VARSIZE(out, VARSIZE(in1) + VARSIZE(in2));
374         out->size = in1->size + in2->size;
375         ptr = ARRPTR(out);
376         data = STRPTR(out);
377         dataoff = 0;
378         while (i1 && i2)
379         {
380                 int                     cmp = compareEntry(data1, ptr1, data2, ptr2);
381
382                 if (cmp < 0)
383                 {                                               /* in1 first */
384                         ptr->haspos = ptr1->haspos;
385                         ptr->len = ptr1->len;
386                         memcpy(data + dataoff, data1 + ptr1->pos, ptr1->len);
387                         ptr->pos = dataoff;
388                         dataoff += ptr1->len;
389                         if (ptr->haspos)
390                         {
391                                 dataoff = SHORTALIGN(dataoff);
392                                 memcpy(data + dataoff, _POSVECPTR(in1, ptr1), POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16));
393                                 dataoff += POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16);
394                         }
395
396                         ptr++;
397                         ptr1++;
398                         i1--;
399                 }
400                 else if (cmp > 0)
401                 {                                               /* in2 first */
402                         ptr->haspos = ptr2->haspos;
403                         ptr->len = ptr2->len;
404                         memcpy(data + dataoff, data2 + ptr2->pos, ptr2->len);
405                         ptr->pos = dataoff;
406                         dataoff += ptr2->len;
407                         if (ptr->haspos)
408                         {
409                                 int                     addlen = add_pos(in2, ptr2, out, ptr, maxpos);
410
411                                 if (addlen == 0)
412                                         ptr->haspos = 0;
413                                 else
414                                 {
415                                         dataoff = SHORTALIGN(dataoff);
416                                         dataoff += addlen * sizeof(WordEntryPos) + sizeof(uint16);
417                                 }
418                         }
419
420                         ptr++;
421                         ptr2++;
422                         i2--;
423                 }
424                 else
425                 {
426                         ptr->haspos = ptr1->haspos | ptr2->haspos;
427                         ptr->len = ptr1->len;
428                         memcpy(data + dataoff, data1 + ptr1->pos, ptr1->len);
429                         ptr->pos = dataoff;
430                         dataoff += ptr1->len;
431                         if (ptr->haspos)
432                         {
433                                 if (ptr1->haspos)
434                                 {
435                                         dataoff = SHORTALIGN(dataoff);
436                                         memcpy(data + dataoff, _POSVECPTR(in1, ptr1), POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16));
437                                         dataoff += POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16);
438                                         if (ptr2->haspos)
439                                                 dataoff += add_pos(in2, ptr2, out, ptr, maxpos) * sizeof(WordEntryPos);
440                                 }
441                                 else    /* must have ptr2->haspos */
442                                 {
443                                         int                     addlen = add_pos(in2, ptr2, out, ptr, maxpos);
444
445                                         if (addlen == 0)
446                                                 ptr->haspos = 0;
447                                         else
448                                         {
449                                                 dataoff = SHORTALIGN(dataoff);
450                                                 dataoff += addlen * sizeof(WordEntryPos) + sizeof(uint16);
451                                         }
452                                 }
453                         }
454
455                         ptr++;
456                         ptr1++;
457                         ptr2++;
458                         i1--;
459                         i2--;
460                 }
461         }
462
463         while (i1)
464         {
465                 ptr->haspos = ptr1->haspos;
466                 ptr->len = ptr1->len;
467                 memcpy(data + dataoff, data1 + ptr1->pos, ptr1->len);
468                 ptr->pos = dataoff;
469                 dataoff += ptr1->len;
470                 if (ptr->haspos)
471                 {
472                         dataoff = SHORTALIGN(dataoff);
473                         memcpy(data + dataoff, _POSVECPTR(in1, ptr1), POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16));
474                         dataoff += POSDATALEN(in1, ptr1) * sizeof(WordEntryPos) + sizeof(uint16);
475                 }
476
477                 ptr++;
478                 ptr1++;
479                 i1--;
480         }
481
482         while (i2)
483         {
484                 ptr->haspos = ptr2->haspos;
485                 ptr->len = ptr2->len;
486                 memcpy(data + dataoff, data2 + ptr2->pos, ptr2->len);
487                 ptr->pos = dataoff;
488                 dataoff += ptr2->len;
489                 if (ptr->haspos)
490                 {
491                         int                     addlen = add_pos(in2, ptr2, out, ptr, maxpos);
492
493                         if (addlen == 0)
494                                 ptr->haspos = 0;
495                         else
496                         {
497                                 dataoff = SHORTALIGN(dataoff);
498                                 dataoff += addlen * sizeof(WordEntryPos) + sizeof(uint16);
499                         }
500                 }
501
502                 ptr++;
503                 ptr2++;
504                 i2--;
505         }
506
507         /*
508          * Instead of checking each offset individually, we check for overflow of
509          * pos fields once at the end.
510          */
511         if (dataoff > MAXSTRPOS)
512                 ereport(ERROR,
513                                 (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
514                                  errmsg("string is too long for tsvector (%d bytes, max %d bytes)", dataoff, MAXSTRPOS)));
515
516         out->size = ptr - ARRPTR(out);
517         SET_VARSIZE(out, CALCDATASIZE(out->size, dataoff));
518         if (data != STRPTR(out))
519                 memmove(STRPTR(out), data, dataoff);
520
521         PG_FREE_IF_COPY(in1, 0);
522         PG_FREE_IF_COPY(in2, 1);
523         PG_RETURN_POINTER(out);
524 }
525
526 /*
527  * Compare two strings by tsvector rules.
528  * if isPrefix = true then it returns not-zero value if b has prefix a
529  */
530 int4
531 tsCompareString(char *a, int lena, char *b, int lenb, bool prefix)
532 {
533         int                     cmp;
534
535         if (lena == 0)
536         {
537                 if (prefix)
538                         cmp = 0;                        /* emtry string is equal to any if a prefix
539                                                                  * match */
540                 else
541                         cmp = (lenb > 0) ? -1 : 0;
542         }
543         else if (lenb == 0)
544         {
545                 cmp = (lena > 0) ? 1 : 0;
546         }
547         else
548         {
549                 cmp = memcmp(a, b, Min(lena, lenb));
550
551                 if (prefix)
552                 {
553                         if (cmp == 0 && lena > lenb)
554                         {
555                                 /*
556                                  * b argument is not beginning with argument a
557                                  */
558                                 cmp = 1;
559                         }
560                 }
561                 else if ((cmp == 0) && (lena != lenb))
562                 {
563                         cmp = (lena < lenb) ? -1 : 1;
564                 }
565         }
566
567         return cmp;
568 }
569
570 /*
571  * check weight info
572  */
573 static bool
574 checkclass_str(CHKVAL *chkval, WordEntry *val, QueryOperand *item)
575 {
576         WordEntryPosVector *posvec;
577         WordEntryPos *ptr;
578         uint16          len;
579
580         posvec = (WordEntryPosVector *)
581                 (chkval->values + SHORTALIGN(val->pos + val->len));
582
583         len = posvec->npos;
584         ptr = posvec->pos;
585
586         while (len--)
587         {
588                 if (item->weight & (1 << WEP_GETWEIGHT(*ptr)))
589                         return true;
590                 ptr++;
591         }
592         return false;
593 }
594
595 /*
596  * is there value 'val' in array or not ?
597  */
598 static bool
599 checkcondition_str(void *checkval, QueryOperand *val)
600 {
601         CHKVAL     *chkval = (CHKVAL *) checkval;
602         WordEntry  *StopLow = chkval->arrb;
603         WordEntry  *StopHigh = chkval->arre;
604         WordEntry  *StopMiddle = StopHigh;
605         int                     difference = -1;
606         bool            res = false;
607
608         /* Loop invariant: StopLow <= val < StopHigh */
609         while (StopLow < StopHigh)
610         {
611                 StopMiddle = StopLow + (StopHigh - StopLow) / 2;
612                 difference = tsCompareString(chkval->operand + val->distance, val->length,
613                                                    chkval->values + StopMiddle->pos, StopMiddle->len,
614                                                                          false);
615
616                 if (difference == 0)
617                 {
618                         res = (val->weight && StopMiddle->haspos) ?
619                                 checkclass_str(chkval, StopMiddle, val) : true;
620                         break;
621                 }
622                 else if (difference > 0)
623                         StopLow = StopMiddle + 1;
624                 else
625                         StopHigh = StopMiddle;
626         }
627
628         if (res == false && val->prefix == true)
629         {
630                 /*
631                  * there was a failed exact search, so we should scan further to find
632                  * a prefix match.
633                  */
634                 if (StopLow >= StopHigh)
635                         StopMiddle = StopHigh;
636
637                 while (res == false && StopMiddle < chkval->arre &&
638                            tsCompareString(chkval->operand + val->distance, val->length,
639                                                    chkval->values + StopMiddle->pos, StopMiddle->len,
640                                                            true) == 0)
641                 {
642                         res = (val->weight && StopMiddle->haspos) ?
643                                 checkclass_str(chkval, StopMiddle, val) : true;
644
645                         StopMiddle++;
646                 }
647         }
648
649         return res;
650 }
651
652 /*
653  * check for boolean condition.
654  *
655  * if calcnot is false, NOT expressions are always evaluated to be true. This is used in ranking.
656  * checkval can be used to pass information to the callback. TS_execute doesn't
657  * do anything with it.
658  * chkcond is a callback function used to evaluate each VAL node in the query.
659  *
660  */
661 bool
662 TS_execute(QueryItem *curitem, void *checkval, bool calcnot,
663                    bool (*chkcond) (void *checkval, QueryOperand *val))
664 {
665         /* since this function recurses, it could be driven to stack overflow */
666         check_stack_depth();
667
668         if (curitem->type == QI_VAL)
669                 return chkcond(checkval, (QueryOperand *) curitem);
670
671         switch (curitem->qoperator.oper)
672         {
673                 case OP_NOT:
674                         if (calcnot)
675                                 return !TS_execute(curitem + 1, checkval, calcnot, chkcond);
676                         else
677                                 return true;
678                 case OP_AND:
679                         if (TS_execute(curitem + curitem->qoperator.left, checkval, calcnot, chkcond))
680                                 return TS_execute(curitem + 1, checkval, calcnot, chkcond);
681                         else
682                                 return false;
683
684                 case OP_OR:
685                         if (TS_execute(curitem + curitem->qoperator.left, checkval, calcnot, chkcond))
686                                 return true;
687                         else
688                                 return TS_execute(curitem + 1, checkval, calcnot, chkcond);
689
690                 default:
691                         elog(ERROR, "unrecognized operator: %d", curitem->qoperator.oper);
692         }
693
694         /* not reachable, but keep compiler quiet */
695         return false;
696 }
697
698 /*
699  * boolean operations
700  */
701 Datum
702 ts_match_qv(PG_FUNCTION_ARGS)
703 {
704         PG_RETURN_DATUM(DirectFunctionCall2(ts_match_vq,
705                                                                                 PG_GETARG_DATUM(1),
706                                                                                 PG_GETARG_DATUM(0)));
707 }
708
709 Datum
710 ts_match_vq(PG_FUNCTION_ARGS)
711 {
712         TSVector        val = PG_GETARG_TSVECTOR(0);
713         TSQuery         query = PG_GETARG_TSQUERY(1);
714         CHKVAL          chkval;
715         bool            result;
716
717         if (!val->size || !query->size)
718         {
719                 PG_FREE_IF_COPY(val, 0);
720                 PG_FREE_IF_COPY(query, 1);
721                 PG_RETURN_BOOL(false);
722         }
723
724         chkval.arrb = ARRPTR(val);
725         chkval.arre = chkval.arrb + val->size;
726         chkval.values = STRPTR(val);
727         chkval.operand = GETOPERAND(query);
728         result = TS_execute(
729                                                 GETQUERY(query),
730                                                 &chkval,
731                                                 true,
732                                                 checkcondition_str
733                 );
734
735         PG_FREE_IF_COPY(val, 0);
736         PG_FREE_IF_COPY(query, 1);
737         PG_RETURN_BOOL(result);
738 }
739
740 Datum
741 ts_match_tt(PG_FUNCTION_ARGS)
742 {
743         TSVector        vector;
744         TSQuery         query;
745         bool            res;
746
747         vector = DatumGetTSVector(DirectFunctionCall1(to_tsvector,
748                                                                                                   PG_GETARG_DATUM(0)));
749         query = DatumGetTSQuery(DirectFunctionCall1(plainto_tsquery,
750                                                                                                 PG_GETARG_DATUM(1)));
751
752         res = DatumGetBool(DirectFunctionCall2(ts_match_vq,
753                                                                                    TSVectorGetDatum(vector),
754                                                                                    TSQueryGetDatum(query)));
755
756         pfree(vector);
757         pfree(query);
758
759         PG_RETURN_BOOL(res);
760 }
761
762 Datum
763 ts_match_tq(PG_FUNCTION_ARGS)
764 {
765         TSVector        vector;
766         TSQuery         query = PG_GETARG_TSQUERY(1);
767         bool            res;
768
769         vector = DatumGetTSVector(DirectFunctionCall1(to_tsvector,
770                                                                                                   PG_GETARG_DATUM(0)));
771
772         res = DatumGetBool(DirectFunctionCall2(ts_match_vq,
773                                                                                    TSVectorGetDatum(vector),
774                                                                                    TSQueryGetDatum(query)));
775
776         pfree(vector);
777         PG_FREE_IF_COPY(query, 1);
778
779         PG_RETURN_BOOL(res);
780 }
781
782 /*
783  * ts_stat statistic function support
784  */
785
786
787 /*
788  * Returns the number of positions in value 'wptr' within tsvector 'txt',
789  * that have a weight equal to one of the weights in 'weight' bitmask.
790  */
791 static int
792 check_weight(TSVector txt, WordEntry *wptr, int8 weight)
793 {
794         int                     len = POSDATALEN(txt, wptr);
795         int                     num = 0;
796         WordEntryPos *ptr = POSDATAPTR(txt, wptr);
797
798         while (len--)
799         {
800                 if (weight & (1 << WEP_GETWEIGHT(*ptr)))
801                         num++;
802                 ptr++;
803         }
804         return num;
805 }
806
/*
 * Compare the lexeme stored in StatEntry "a" with the lexeme of WordEntry
 * "e" belonging to tsvector "t", using tsCompareString() (exact match, not
 * prefix match).
 */
#define compareStatWord(a,e,t)							\
	tsCompareString((a)->lexeme, (a)->lenlexeme,		\
					STRPTR(t) + (e)->pos, (e)->len,		\
					false)
811
812 static void
813 insertStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, uint32 off)
814 {
815         WordEntry  *we = ARRPTR(txt) + off;
816         StatEntry  *node = stat->root,
817                            *pnode = NULL;
818         int                     n,
819                                 res = 0;
820         uint32          depth = 1;
821
822         if (stat->weight == 0)
823                 n = (we->haspos) ? POSDATALEN(txt, we) : 1;
824         else
825                 n = (we->haspos) ? check_weight(txt, we, stat->weight) : 0;
826
827         if (n == 0)
828                 return;                                 /* nothing to insert */
829
830         while (node)
831         {
832                 res = compareStatWord(node, we, txt);
833
834                 if (res == 0)
835                 {
836                         break;
837                 }
838                 else
839                 {
840                         pnode = node;
841                         node = (res < 0) ? node->left : node->right;
842                 }
843                 depth++;
844         }
845
846         if (depth > stat->maxdepth)
847                 stat->maxdepth = depth;
848
849         if (node == NULL)
850         {
851                 node = MemoryContextAlloc(persistentContext, STATENTRYHDRSZ + we->len);
852                 node->left = node->right = NULL;
853                 node->ndoc = 1;
854                 node->nentry = n;
855                 node->lenlexeme = we->len;
856                 memcpy(node->lexeme, STRPTR(txt) + we->pos, node->lenlexeme);
857
858                 if (pnode == NULL)
859                 {
860                         stat->root = node;
861                 }
862                 else
863                 {
864                         if (res < 0)
865                                 pnode->left = node;
866                         else
867                                 pnode->right = node;
868                 }
869
870         }
871         else
872         {
873                 node->ndoc++;
874                 node->nentry += n;
875         }
876 }
877
878 static void
879 chooseNextStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt,
880                                         uint32 low, uint32 high, uint32 offset)
881 {
882         uint32          pos;
883         uint32          middle = (low + high) >> 1;
884
885         pos = (low + middle) >> 1;
886         if (low != middle && pos >= offset && pos - offset < txt->size)
887                 insertStatEntry(persistentContext, stat, txt, pos - offset);
888         pos = (high + middle + 1) >> 1;
889         if (middle + 1 != high && pos >= offset && pos - offset < txt->size)
890                 insertStatEntry(persistentContext, stat, txt, pos - offset);
891
892         if (low != middle)
893                 chooseNextStatEntry(persistentContext, stat, txt, low, middle, offset);
894         if (high != middle + 1)
895                 chooseNextStatEntry(persistentContext, stat, txt, middle + 1, high, offset);
896 }
897
898 /*
899  * This is written like a custom aggregate function, because the
900  * original plan was to do just that. Unfortunately, an aggregate function
901  * can't return a set, so that plan was abandoned. If that limitation is
902  * lifted in the future, ts_stat could be a real aggregate function so that
903  * you could use it like this:
904  *
905  *       SELECT ts_stat(vector_column) FROM vector_table;
906  *
907  *      where vector_column is a tsvector-type column in vector_table.
908  */
909
910 static TSVectorStat *
911 ts_accum(MemoryContext persistentContext, TSVectorStat *stat, Datum data)
912 {
913         TSVector        txt = DatumGetTSVector(data);
914         uint32          i,
915                                 nbit = 0,
916                                 offset;
917
918         if (stat == NULL)
919         {                                                       /* Init in first */
920                 stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
921                 stat->maxdepth = 1;
922         }
923
924         /* simple check of correctness */
925         if (txt == NULL || txt->size == 0)
926         {
927                 if (txt && txt != (TSVector) DatumGetPointer(data))
928                         pfree(txt);
929                 return stat;
930         }
931
932         i = txt->size - 1;
933         for (; i > 0; i >>= 1)
934                 nbit++;
935
936         nbit = 1 << nbit;
937         offset = (nbit - txt->size) / 2;
938
939         insertStatEntry(persistentContext, stat, txt, (nbit >> 1) - offset);
940         chooseNextStatEntry(persistentContext, stat, txt, 0, nbit, offset);
941
942         return stat;
943 }
944
945 static void
946 ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
947                                    TSVectorStat *stat)
948 {
949         TupleDesc       tupdesc;
950         MemoryContext oldcontext;
951         StatEntry  *node;
952
953         funcctx->user_fctx = (void *) stat;
954
955         oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
956
957         stat->stack = palloc0(sizeof(StatEntry *) * (stat->maxdepth + 1));
958         stat->stackpos = 0;
959
960         node = stat->root;
961         /* find leftmost value */
962         if (node == NULL)
963                 stat->stack[stat->stackpos] = NULL;
964         else
965                 for (;;)
966                 {
967                         stat->stack[stat->stackpos] = node;
968                         if (node->left)
969                         {
970                                 stat->stackpos++;
971                                 node = node->left;
972                         }
973                         else
974                                 break;
975                 }
976         Assert(stat->stackpos <= stat->maxdepth);
977
978         tupdesc = CreateTemplateTupleDesc(3, false);
979         TupleDescInitEntry(tupdesc, (AttrNumber) 1, "word",
980                                            TEXTOID, -1, 0);
981         TupleDescInitEntry(tupdesc, (AttrNumber) 2, "ndoc",
982                                            INT4OID, -1, 0);
983         TupleDescInitEntry(tupdesc, (AttrNumber) 3, "nentry",
984                                            INT4OID, -1, 0);
985         funcctx->tuple_desc = BlessTupleDesc(tupdesc);
986         funcctx->attinmeta = TupleDescGetAttInMetadata(tupdesc);
987
988         MemoryContextSwitchTo(oldcontext);
989 }
990
991 static StatEntry *
992 walkStatEntryTree(TSVectorStat *stat)
993 {
994         StatEntry  *node = stat->stack[stat->stackpos];
995
996         if (node == NULL)
997                 return NULL;
998
999         if (node->ndoc != 0)
1000         {
1001                 /* return entry itself: we already was at left sublink */
1002                 return node;
1003         }
1004         else if (node->right && node->right != stat->stack[stat->stackpos + 1])
1005         {
1006                 /* go on right sublink */
1007                 stat->stackpos++;
1008                 node = node->right;
1009
1010                 /* find most-left value */
1011                 for (;;)
1012                 {
1013                         stat->stack[stat->stackpos] = node;
1014                         if (node->left)
1015                         {
1016                                 stat->stackpos++;
1017                                 node = node->left;
1018                         }
1019                         else
1020                                 break;
1021                 }
1022                 Assert(stat->stackpos <= stat->maxdepth);
1023         }
1024         else
1025         {
1026                 /* we already return all left subtree, itself and  right subtree */
1027                 if (stat->stackpos == 0)
1028                         return NULL;
1029
1030                 stat->stackpos--;
1031                 return walkStatEntryTree(stat);
1032         }
1033
1034         return node;
1035 }
1036
1037 static Datum
1038 ts_process_call(FuncCallContext *funcctx)
1039 {
1040         TSVectorStat *st;
1041         StatEntry  *entry;
1042
1043         st = (TSVectorStat *) funcctx->user_fctx;
1044
1045         entry = walkStatEntryTree(st);
1046
1047         if (entry != NULL)
1048         {
1049                 Datum           result;
1050                 char       *values[3];
1051                 char            ndoc[16];
1052                 char            nentry[16];
1053                 HeapTuple       tuple;
1054
1055                 values[0] = palloc(entry->lenlexeme + 1);
1056                 memcpy(values[0], entry->lexeme, entry->lenlexeme);
1057                 (values[0])[entry->lenlexeme] = '\0';
1058                 sprintf(ndoc, "%d", entry->ndoc);
1059                 values[1] = ndoc;
1060                 sprintf(nentry, "%d", entry->nentry);
1061                 values[2] = nentry;
1062
1063                 tuple = BuildTupleFromCStrings(funcctx->attinmeta, values);
1064                 result = HeapTupleGetDatum(tuple);
1065
1066                 pfree(values[0]);
1067
1068                 /* mark entry as already visited */
1069                 entry->ndoc = 0;
1070
1071                 return result;
1072         }
1073
1074         return (Datum) 0;
1075 }
1076
1077 static TSVectorStat *
1078 ts_stat_sql(MemoryContext persistentContext, text *txt, text *ws)
1079 {
1080         char       *query = text_to_cstring(txt);
1081         int                     i;
1082         TSVectorStat *stat;
1083         bool            isnull;
1084         Portal          portal;
1085         SPIPlanPtr      plan;
1086
1087         if ((plan = SPI_prepare(query, 0, NULL)) == NULL)
1088                 /* internal error */
1089                 elog(ERROR, "SPI_prepare(\"%s\") failed", query);
1090
1091         if ((portal = SPI_cursor_open(NULL, plan, NULL, NULL, true)) == NULL)
1092                 /* internal error */
1093                 elog(ERROR, "SPI_cursor_open(\"%s\") failed", query);
1094
1095         SPI_cursor_fetch(portal, true, 100);
1096
1097         if (SPI_tuptable == NULL ||
1098                 SPI_tuptable->tupdesc->natts != 1 ||
1099                 !is_expected_type(SPI_gettypeid(SPI_tuptable->tupdesc, 1),
1100                                                   TSVECTOROID))
1101                 ereport(ERROR,
1102                                 (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1103                                  errmsg("ts_stat query must return one tsvector column")));
1104
1105         stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
1106         stat->maxdepth = 1;
1107
1108         if (ws)
1109         {
1110                 char       *buf;
1111
1112                 buf = VARDATA(ws);
1113                 while (buf - VARDATA(ws) < VARSIZE(ws) - VARHDRSZ)
1114                 {
1115                         if (pg_mblen(buf) == 1)
1116                         {
1117                                 switch (*buf)
1118                                 {
1119                                         case 'A':
1120                                         case 'a':
1121                                                 stat->weight |= 1 << 3;
1122                                                 break;
1123                                         case 'B':
1124                                         case 'b':
1125                                                 stat->weight |= 1 << 2;
1126                                                 break;
1127                                         case 'C':
1128                                         case 'c':
1129                                                 stat->weight |= 1 << 1;
1130                                                 break;
1131                                         case 'D':
1132                                         case 'd':
1133                                                 stat->weight |= 1;
1134                                                 break;
1135                                         default:
1136                                                 stat->weight |= 0;
1137                                 }
1138                         }
1139                         buf += pg_mblen(buf);
1140                 }
1141         }
1142
1143         while (SPI_processed > 0)
1144         {
1145                 for (i = 0; i < SPI_processed; i++)
1146                 {
1147                         Datum           data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
1148
1149                         if (!isnull)
1150                                 stat = ts_accum(persistentContext, stat, data);
1151                 }
1152
1153                 SPI_freetuptable(SPI_tuptable);
1154                 SPI_cursor_fetch(portal, true, 100);
1155         }
1156
1157         SPI_freetuptable(SPI_tuptable);
1158         SPI_cursor_close(portal);
1159         SPI_freeplan(plan);
1160         pfree(query);
1161
1162         return stat;
1163 }
1164
1165 Datum
1166 ts_stat1(PG_FUNCTION_ARGS)
1167 {
1168         FuncCallContext *funcctx;
1169         Datum           result;
1170
1171         if (SRF_IS_FIRSTCALL())
1172         {
1173                 TSVectorStat *stat;
1174                 text       *txt = PG_GETARG_TEXT_P(0);
1175
1176                 funcctx = SRF_FIRSTCALL_INIT();
1177                 SPI_connect();
1178                 stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, NULL);
1179                 PG_FREE_IF_COPY(txt, 0);
1180                 ts_setup_firstcall(fcinfo, funcctx, stat);
1181                 SPI_finish();
1182         }
1183
1184         funcctx = SRF_PERCALL_SETUP();
1185         if ((result = ts_process_call(funcctx)) != (Datum) 0)
1186                 SRF_RETURN_NEXT(funcctx, result);
1187         SRF_RETURN_DONE(funcctx);
1188 }
1189
1190 Datum
1191 ts_stat2(PG_FUNCTION_ARGS)
1192 {
1193         FuncCallContext *funcctx;
1194         Datum           result;
1195
1196         if (SRF_IS_FIRSTCALL())
1197         {
1198                 TSVectorStat *stat;
1199                 text       *txt = PG_GETARG_TEXT_P(0);
1200                 text       *ws = PG_GETARG_TEXT_P(1);
1201
1202                 funcctx = SRF_FIRSTCALL_INIT();
1203                 SPI_connect();
1204                 stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, ws);
1205                 PG_FREE_IF_COPY(txt, 0);
1206                 PG_FREE_IF_COPY(ws, 1);
1207                 ts_setup_firstcall(fcinfo, funcctx, stat);
1208                 SPI_finish();
1209         }
1210
1211         funcctx = SRF_PERCALL_SETUP();
1212         if ((result = ts_process_call(funcctx)) != (Datum) 0)
1213                 SRF_RETURN_NEXT(funcctx, result);
1214         SRF_RETURN_DONE(funcctx);
1215 }
1216
1217
1218 /*
1219  * Triggers for automatic update of a tsvector column from text column(s)
1220  *
1221  * Trigger arguments are either
1222  *              name of tsvector col, name of tsconfig to use, name(s) of text col(s)
1223  *              name of tsvector col, name of regconfig col, name(s) of text col(s)
1224  * ie, tsconfig can either be specified by name, or indirectly as the
1225  * contents of a regconfig field in the row.  If the name is used, it must
1226  * be explicitly schema-qualified.
1227  */
1228 Datum
1229 tsvector_update_trigger_byid(PG_FUNCTION_ARGS)
1230 {
1231         return tsvector_update_trigger(fcinfo, false);
1232 }
1233
1234 Datum
1235 tsvector_update_trigger_bycolumn(PG_FUNCTION_ARGS)
1236 {
1237         return tsvector_update_trigger(fcinfo, true);
1238 }
1239
1240 static Datum
1241 tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column)
1242 {
1243         TriggerData *trigdata;
1244         Trigger    *trigger;
1245         Relation        rel;
1246         HeapTuple       rettuple = NULL;
1247         int                     tsvector_attr_num,
1248                                 i;
1249         ParsedText      prs;
1250         Datum           datum;
1251         bool            isnull;
1252         text       *txt;
1253         Oid                     cfgId;
1254
1255         /* Check call context */
1256         if (!CALLED_AS_TRIGGER(fcinfo))         /* internal error */
1257                 elog(ERROR, "tsvector_update_trigger: not fired by trigger manager");
1258
1259         trigdata = (TriggerData *) fcinfo->context;
1260         if (TRIGGER_FIRED_FOR_STATEMENT(trigdata->tg_event))
1261                 elog(ERROR, "tsvector_update_trigger: can't process STATEMENT events");
1262         if (TRIGGER_FIRED_AFTER(trigdata->tg_event))
1263                 elog(ERROR, "tsvector_update_trigger: must be fired BEFORE event");
1264
1265         if (TRIGGER_FIRED_BY_INSERT(trigdata->tg_event))
1266                 rettuple = trigdata->tg_trigtuple;
1267         else if (TRIGGER_FIRED_BY_UPDATE(trigdata->tg_event))
1268                 rettuple = trigdata->tg_newtuple;
1269         else
1270                 elog(ERROR, "tsvector_update_trigger: must be fired for INSERT or UPDATE");
1271
1272         trigger = trigdata->tg_trigger;
1273         rel = trigdata->tg_relation;
1274
1275         if (trigger->tgnargs < 3)
1276                 elog(ERROR, "tsvector_update_trigger: arguments must be tsvector_field, ts_config, text_field1, ...)");
1277
1278         /* Find the target tsvector column */
1279         tsvector_attr_num = SPI_fnumber(rel->rd_att, trigger->tgargs[0]);
1280         if (tsvector_attr_num == SPI_ERROR_NOATTRIBUTE)
1281                 ereport(ERROR,
1282                                 (errcode(ERRCODE_UNDEFINED_COLUMN),
1283                                  errmsg("tsvector column \"%s\" does not exist",
1284                                                 trigger->tgargs[0])));
1285         if (!is_expected_type(SPI_gettypeid(rel->rd_att, tsvector_attr_num),
1286                                                   TSVECTOROID))
1287                 ereport(ERROR,
1288                                 (errcode(ERRCODE_DATATYPE_MISMATCH),
1289                                  errmsg("column \"%s\" is not of tsvector type",
1290                                                 trigger->tgargs[0])));
1291
1292         /* Find the configuration to use */
1293         if (config_column)
1294         {
1295                 int                     config_attr_num;
1296
1297                 config_attr_num = SPI_fnumber(rel->rd_att, trigger->tgargs[1]);
1298                 if (config_attr_num == SPI_ERROR_NOATTRIBUTE)
1299                         ereport(ERROR,
1300                                         (errcode(ERRCODE_UNDEFINED_COLUMN),
1301                                          errmsg("configuration column \"%s\" does not exist",
1302                                                         trigger->tgargs[1])));
1303                 if (!is_expected_type(SPI_gettypeid(rel->rd_att, config_attr_num),
1304                                                           REGCONFIGOID))
1305                         ereport(ERROR,
1306                                         (errcode(ERRCODE_DATATYPE_MISMATCH),
1307                                          errmsg("column \"%s\" is not of regconfig type",
1308                                                         trigger->tgargs[1])));
1309
1310                 datum = SPI_getbinval(rettuple, rel->rd_att, config_attr_num, &isnull);
1311                 if (isnull)
1312                         ereport(ERROR,
1313                                         (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
1314                                          errmsg("configuration column \"%s\" must not be null",
1315                                                         trigger->tgargs[1])));
1316                 cfgId = DatumGetObjectId(datum);
1317         }
1318         else
1319         {
1320                 List       *names;
1321
1322                 names = stringToQualifiedNameList(trigger->tgargs[1]);
1323                 /* require a schema so that results are not search path dependent */
1324                 if (list_length(names) < 2)
1325                         ereport(ERROR,
1326                                         (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
1327                                          errmsg("text search configuration name \"%s\" must be schema-qualified",
1328                                                         trigger->tgargs[1])));
1329                 cfgId = TSConfigGetCfgid(names, false);
1330         }
1331
1332         /* initialize parse state */
1333         prs.lenwords = 32;
1334         prs.curwords = 0;
1335         prs.pos = 0;
1336         prs.words = (ParsedWord *) palloc(sizeof(ParsedWord) * prs.lenwords);
1337
1338         /* find all words in indexable column(s) */
1339         for (i = 2; i < trigger->tgnargs; i++)
1340         {
1341                 int                     numattr;
1342
1343                 numattr = SPI_fnumber(rel->rd_att, trigger->tgargs[i]);
1344                 if (numattr == SPI_ERROR_NOATTRIBUTE)
1345                         ereport(ERROR,
1346                                         (errcode(ERRCODE_UNDEFINED_COLUMN),
1347                                          errmsg("column \"%s\" does not exist",
1348                                                         trigger->tgargs[i])));
1349                 if (!is_text_type(SPI_gettypeid(rel->rd_att, numattr)))
1350                         ereport(ERROR,
1351                                         (errcode(ERRCODE_DATATYPE_MISMATCH),
1352                                          errmsg("column \"%s\" is not of a character type",
1353                                                         trigger->tgargs[i])));
1354
1355                 datum = SPI_getbinval(rettuple, rel->rd_att, numattr, &isnull);
1356                 if (isnull)
1357                         continue;
1358
1359                 txt = DatumGetTextP(datum);
1360
1361                 parsetext(cfgId, &prs, VARDATA(txt), VARSIZE(txt) - VARHDRSZ);
1362
1363                 if (txt != (text *) DatumGetPointer(datum))
1364                         pfree(txt);
1365         }
1366
1367         /* make tsvector value */
1368         if (prs.curwords)
1369         {
1370                 datum = PointerGetDatum(make_tsvector(&prs));
1371                 rettuple = SPI_modifytuple(rel, rettuple, 1, &tsvector_attr_num,
1372                                                                    &datum, NULL);
1373                 pfree(DatumGetPointer(datum));
1374         }
1375         else
1376         {
1377                 TSVector        out = palloc(CALCDATASIZE(0, 0));
1378
1379                 SET_VARSIZE(out, CALCDATASIZE(0, 0));
1380                 out->size = 0;
1381                 datum = PointerGetDatum(out);
1382                 rettuple = SPI_modifytuple(rel, rettuple, 1, &tsvector_attr_num,
1383                                                                    &datum, NULL);
1384                 pfree(prs.words);
1385         }
1386
1387         if (rettuple == NULL)           /* internal error */
1388                 elog(ERROR, "tsvector_update_trigger: %d returned by SPI_modifytuple",
1389                          SPI_result);
1390
1391         return PointerGetDatum(rettuple);
1392 }