]> granicus.if.org Git - postgresql/blob - contrib/tsearch2/rank.c
Fix some portability issues (reliance on gcc-isms).
[postgresql] / contrib / tsearch2 / rank.c
/*
 * Relevance ranking functions for tsearch2
 * Teodor Sigaev <teodor@sigaev.ru>
 */
5 #include "postgres.h"
6 #include <math.h>
7
8 #include "access/gist.h"
9 #include "access/itup.h"
10 #include "utils/builtins.h"
11 #include "fmgr.h"
12 #include "funcapi.h"
13 #include "storage/bufpage.h"
14 #include "executor/spi.h"
15 #include "commands/trigger.h"
16 #include "nodes/pg_list.h"
17 #include "catalog/namespace.h"
18
19 #include "utils/array.h"
20
21 #include "tsvector.h"
22 #include "query.h"
23 #include "common.h"
24
25 PG_FUNCTION_INFO_V1(rank);
26 Datum           rank(PG_FUNCTION_ARGS);
27
28 PG_FUNCTION_INFO_V1(rank_def);
29 Datum           rank_def(PG_FUNCTION_ARGS);
30
31 PG_FUNCTION_INFO_V1(rank_cd);
32 Datum           rank_cd(PG_FUNCTION_ARGS);
33
34 PG_FUNCTION_INFO_V1(rank_cd_def);
35 Datum           rank_cd_def(PG_FUNCTION_ARGS);
36
37 PG_FUNCTION_INFO_V1(get_covers);
38 Datum           get_covers(PG_FUNCTION_ARGS);
39
40 static float weights[] = {0.1, 0.2, 0.4, 1.0};
41
42 #define wpos(wep)       ( w[ ((WordEntryPos*)(wep))->weight ] )
43
44 #define DEF_NORM_METHOD 0
45
46 /*
47  * Returns a weight of a word collocation
48  */
49 static float4
50 word_distance(int4 w)
51 {
52         if (w > 100)
53                 return 1e-30;
54
55         return 1.0 / (1.005 + 0.05 * exp(((float4) w) / 1.5 - 2));
56 }
57
58 static int
59 cnt_length(tsvector * t)
60 {
61         WordEntry  *ptr = ARRPTR(t),
62                            *end = (WordEntry *) STRPTR(t);
63         int                     len = 0,
64                                 clen;
65
66         while (ptr < end)
67         {
68                 if ((clen = POSDATALEN(t, ptr)) == 0)
69                         len += 1;
70                 else
71                         len += clen;
72                 ptr++;
73         }
74
75         return len;
76 }
77
78 static int4
79 WordECompareITEM(char *eval, char *qval, WordEntry * ptr, ITEM * item)
80 {
81         if (ptr->len == item->length)
82                 return strncmp(
83                                            eval + ptr->pos,
84                                            qval + item->distance,
85                                            item->length);
86
87         return (ptr->len > item->length) ? 1 : -1;
88 }
89
90 static WordEntry *
91 find_wordentry(tsvector * t, QUERYTYPE * q, ITEM * item)
92 {
93         WordEntry  *StopLow = ARRPTR(t);
94         WordEntry  *StopHigh = (WordEntry *) STRPTR(t);
95         WordEntry  *StopMiddle;
96         int                     difference;
97
98         /* Loop invariant: StopLow <= item < StopHigh */
99
100         while (StopLow < StopHigh)
101         {
102                 StopMiddle = StopLow + (StopHigh - StopLow) / 2;
103                 difference = WordECompareITEM(STRPTR(t), GETOPERAND(q), StopMiddle, item);
104                 if (difference == 0)
105                         return StopMiddle;
106                 else if (difference < 0)
107                         StopLow = StopMiddle + 1;
108                 else
109                         StopHigh = StopMiddle;
110         }
111
112         return NULL;
113 }
114
115 static WordEntryPos POSNULL[] = {
116         {0, 0},
117         {0, MAXENTRYPOS - 1}
118 };
119
120 static float
121 calc_rank_and(float *w, tsvector * t, QUERYTYPE * q)
122 {
123         uint16    **pos = (uint16 **) palloc(sizeof(uint16 *) * q->size);
124         int                     i,
125                                 k,
126                                 l,
127                                 p;
128         WordEntry  *entry;
129         WordEntryPos *post,
130                            *ct;
131         int4            dimt,
132                                 lenct,
133                                 dist;
134         float           res = -1.0;
135         ITEM       *item = GETQUERY(q);
136
137         memset(pos, 0, sizeof(uint16 **) * q->size);
138         *(uint16 *) POSNULL = lengthof(POSNULL) - 1;
139
140         for (i = 0; i < q->size; i++)
141         {
142
143                 if (item[i].type != VAL)
144                         continue;
145
146                 entry = find_wordentry(t, q, &(item[i]));
147                 if (!entry)
148                         continue;
149
150                 if (entry->haspos)
151                         pos[i] = (uint16 *) _POSDATAPTR(t, entry);
152                 else
153                         pos[i] = (uint16 *) POSNULL;
154
155
156                 dimt = *(uint16 *) (pos[i]);
157                 post = (WordEntryPos *) (pos[i] + 1);
158                 for (k = 0; k < i; k++)
159                 {
160                         if (!pos[k])
161                                 continue;
162                         lenct = *(uint16 *) (pos[k]);
163                         ct = (WordEntryPos *) (pos[k] + 1);
164                         for (l = 0; l < dimt; l++)
165                         {
166                                 for (p = 0; p < lenct; p++)
167                                 {
168                                         dist = abs(post[l].pos - ct[p].pos);
169                                         if (dist || (dist == 0 && (pos[i] == (uint16 *) POSNULL || pos[k] == (uint16 *) POSNULL)))
170                                         {
171                                                 float           curw;
172
173                                                 if (!dist)
174                                                         dist = MAXENTRYPOS;
175                                                 curw = sqrt(wpos(&(post[l])) * wpos(&(ct[p])) * word_distance(dist));
176                                                 res = (res < 0) ? curw : 1.0 - (1.0 - res) * (1.0 - curw);
177                                         }
178                                 }
179                         }
180                 }
181         }
182         pfree(pos);
183         return res;
184 }
185
186 static float
187 calc_rank_or(float *w, tsvector * t, QUERYTYPE * q)
188 {
189         WordEntry  *entry;
190         WordEntryPos *post;
191         int4            dimt,
192                                 j,
193                                 i;
194         float           res = -1.0;
195         ITEM       *item = GETQUERY(q);
196
197         *(uint16 *) POSNULL = lengthof(POSNULL) - 1;
198
199         for (i = 0; i < q->size; i++)
200         {
201                 if (item[i].type != VAL)
202                         continue;
203
204                 entry = find_wordentry(t, q, &(item[i]));
205                 if (!entry)
206                         continue;
207
208                 if (entry->haspos)
209                 {
210                         dimt = POSDATALEN(t, entry);
211                         post = POSDATAPTR(t, entry);
212                 }
213                 else
214                 {
215                         dimt = *(uint16 *) POSNULL;
216                         post = POSNULL + 1;
217                 }
218
219                 for (j = 0; j < dimt; j++)
220                 {
221                         if (res < 0)
222                                 res = wpos(&(post[j]));
223                         else
224                                 res = 1.0 - (1.0 - res) * (1.0 - wpos(&(post[j])));
225                 }
226         }
227         return res;
228 }
229
230 static float
231 calc_rank(float *w, tsvector * t, QUERYTYPE * q, int4 method)
232 {
233         ITEM       *item = GETQUERY(q);
234         float           res = 0.0;
235
236         if (!t->size || !q->size)
237                 return 0.0;
238
239         res = (item->type != VAL && item->val == (int4) '&') ?
240                 calc_rank_and(w, t, q) : calc_rank_or(w, t, q);
241
242         if (res < 0)
243                 res = 1e-20;
244
245         switch (method)
246         {
247                 case 0:
248                         break;
249                 case 1:
250                         res /= log((float) cnt_length(t));
251                         break;
252                 case 2:
253                         res /= (float) cnt_length(t);
254                         break;
255                 default:
256                         /* internal error */
257                         elog(ERROR, "unrecognized normalization method: %d", method);
258         }
259
260         return res;
261 }
262
263 Datum
264 rank(PG_FUNCTION_ARGS)
265 {
266         ArrayType  *win = (ArrayType *) PG_DETOAST_DATUM(PG_GETARG_DATUM(0));
267         tsvector   *txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));
268         QUERYTYPE  *query = (QUERYTYPE *) PG_DETOAST_DATUM(PG_GETARG_DATUM(2));
269         int                     method = DEF_NORM_METHOD;
270         float           res = 0.0;
271         float           ws[lengthof(weights)];
272         int                     i;
273
274         if (ARR_NDIM(win) != 1)
275                 ereport(ERROR,
276                                 (errcode(ERRCODE_ARRAY_SUBSCRIPT_ERROR),
277                                  errmsg("array of weight must be one-dimensional")));
278
279         if (ARRNELEMS(win) < lengthof(weights))
280                 ereport(ERROR,
281                                 (errcode(ERRCODE_ARRAY_SUBSCRIPT_ERROR),
282                                  errmsg("array of weight is too short")));
283
284         for (i = 0; i < lengthof(weights); i++)
285         {
286                 ws[i] = (((float4 *) ARR_DATA_PTR(win))[i] >= 0) ? ((float4 *) ARR_DATA_PTR(win))[i] : weights[i];
287                 if (ws[i] > 1.0)
288                         ereport(ERROR,
289                                         (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
290                                          errmsg("weight out of range")));
291         }
292
293         if (PG_NARGS() == 4)
294                 method = PG_GETARG_INT32(3);
295
296         res = calc_rank(ws, txt, query, method);
297
298         PG_FREE_IF_COPY(win, 0);
299         PG_FREE_IF_COPY(txt, 1);
300         PG_FREE_IF_COPY(query, 2);
301         PG_RETURN_FLOAT4(res);
302 }
303
304 Datum
305 rank_def(PG_FUNCTION_ARGS)
306 {
307         tsvector   *txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(0));
308         QUERYTYPE  *query = (QUERYTYPE *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));
309         float           res = 0.0;
310         int                     method = DEF_NORM_METHOD;
311
312         if (PG_NARGS() == 3)
313                 method = PG_GETARG_INT32(2);
314
315         res = calc_rank(weights, txt, query, method);
316
317         PG_FREE_IF_COPY(txt, 0);
318         PG_FREE_IF_COPY(query, 1);
319         PG_RETURN_FLOAT4(res);
320 }
321
322
323 typedef struct
324 {
325         ITEM       *item;
326         int32           pos;
327 }       DocRepresentation;
328
329 static int
330 compareDocR(const void *a, const void *b)
331 {
332         if (((DocRepresentation *) a)->pos == ((DocRepresentation *) b)->pos)
333                 return 1;
334         return (((DocRepresentation *) a)->pos > ((DocRepresentation *) b)->pos) ? 1 : -1;
335 }
336
337
338 typedef struct
339 {
340         DocRepresentation *doc;
341         int                     len;
342 }       ChkDocR;
343
344 static bool
345 checkcondition_DR(void *checkval, ITEM * val)
346 {
347         DocRepresentation *ptr = ((ChkDocR *) checkval)->doc;
348
349         while (ptr - ((ChkDocR *) checkval)->doc < ((ChkDocR *) checkval)->len)
350         {
351                 if (val == ptr->item)
352                         return true;
353                 ptr++;
354         }
355
356         return false;
357 }
358
359
360 static bool
361 Cover(DocRepresentation * doc, int len, QUERYTYPE * query, int *pos, int *p, int *q)
362 {
363         int                     i;
364         DocRepresentation *ptr,
365                            *f = (DocRepresentation *) 0xffffffff;
366         ITEM       *item = GETQUERY(query);
367         int                     lastpos = *pos;
368         int                     oldq = *q;
369
370         *p = 0x7fffffff;
371         *q = 0;
372
373         for (i = 0; i < query->size; i++)
374         {
375                 if (item->type != VAL)
376                 {
377                         item++;
378                         continue;
379                 }
380                 ptr = doc + *pos;
381
382                 while (ptr - doc < len)
383                 {
384                         if (ptr->item == item)
385                         {
386                                 if (ptr->pos > *q)
387                                 {
388                                         *q = ptr->pos;
389                                         lastpos = ptr - doc;
390                                 }
391                                 break;
392                         }
393                         ptr++;
394                 }
395
396                 item++;
397         }
398
399         if (*q == 0)
400                 return false;
401
402         if (*q == oldq)
403         {                                                       /* already check this pos */
404                 (*pos)++;
405                 return Cover(doc, len, query, pos, p, q);
406         }
407
408         item = GETQUERY(query);
409         for (i = 0; i < query->size; i++)
410         {
411                 if (item->type != VAL)
412                 {
413                         item++;
414                         continue;
415                 }
416                 ptr = doc + lastpos;
417
418                 while (ptr >= doc + *pos)
419                 {
420                         if (ptr->item == item)
421                         {
422                                 if (ptr->pos < *p)
423                                 {
424                                         *p = ptr->pos;
425                                         f = ptr;
426                                 }
427                                 break;
428                         }
429                         ptr--;
430                 }
431                 item++;
432         }
433
434         if (*p <= *q)
435         {
436                 ChkDocR         ch;
437
438                 ch.doc = f;
439                 ch.len = (doc + lastpos) - f + 1;
440                 *pos = f - doc + 1;
441                 if (TS_execute(GETQUERY(query), &ch, false, checkcondition_DR))
442                 {
443                         /*
444                          * elog(NOTICE,"OP:%d NP:%d P:%d Q:%d", *pos, lastpos, *p,
445                          * *q);
446                          */
447                         return true;
448                 }
449                 else
450                         return Cover(doc, len, query, pos, p, q);
451         }
452
453         return false;
454 }
455
456 static DocRepresentation *
457 get_docrep(tsvector * txt, QUERYTYPE * query, int *doclen)
458 {
459         ITEM       *item = GETQUERY(query);
460         WordEntry  *entry;
461         WordEntryPos *post;
462         int4            dimt,
463                                 j,
464                                 i;
465         int                     len = query->size * 4,
466                                 cur = 0;
467         DocRepresentation *doc;
468
469         *(uint16 *) POSNULL = lengthof(POSNULL) - 1;
470         doc = (DocRepresentation *) palloc(sizeof(DocRepresentation) * len);
471         for (i = 0; i < query->size; i++)
472         {
473                 if (item[i].type != VAL)
474                         continue;
475
476                 entry = find_wordentry(txt, query, &(item[i]));
477                 if (!entry)
478                         continue;
479
480                 if (entry->haspos)
481                 {
482                         dimt = POSDATALEN(txt, entry);
483                         post = POSDATAPTR(txt, entry);
484                 }
485                 else
486                 {
487                         dimt = *(uint16 *) POSNULL;
488                         post = POSNULL + 1;
489                 }
490
491                 while (cur + dimt >= len)
492                 {
493                         len *= 2;
494                         doc = (DocRepresentation *) repalloc(doc, sizeof(DocRepresentation) * len);
495                 }
496
497                 for (j = 0; j < dimt; j++)
498                 {
499                         doc[cur].item = &(item[i]);
500                         doc[cur].pos = post[j].pos;
501                         cur++;
502                 }
503         }
504
505         *doclen = cur;
506
507         if (cur > 0)
508         {
509                 if (cur > 1)
510                         qsort((void *) doc, cur, sizeof(DocRepresentation), compareDocR);
511                 return doc;
512         }
513
514         pfree(doc);
515         return NULL;
516 }
517
518
519 Datum
520 rank_cd(PG_FUNCTION_ARGS)
521 {
522         int                     K = PG_GETARG_INT32(0);
523         tsvector   *txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));
524         QUERYTYPE  *query = (QUERYTYPE *) PG_DETOAST_DATUM(PG_GETARG_DATUM(2));
525         int                     method = DEF_NORM_METHOD;
526         DocRepresentation *doc;
527         float           res = 0.0;
528         int                     p = 0,
529                                 q = 0,
530                                 len,
531                                 cur;
532
533         doc = get_docrep(txt, query, &len);
534         if (!doc)
535         {
536                 PG_FREE_IF_COPY(txt, 1);
537                 PG_FREE_IF_COPY(query, 2);
538                 PG_RETURN_FLOAT4(0.0);
539         }
540
541         cur = 0;
542         if (K <= 0)
543                 K = 4;
544         while (Cover(doc, len, query, &cur, &p, &q))
545                 res += (q - p + 1 > K) ? ((float) K) / ((float) (q - p + 1)) : 1.0;
546
547         if (PG_NARGS() == 4)
548                 method = PG_GETARG_INT32(3);
549
550         switch (method)
551         {
552                 case 0:
553                         break;
554                 case 1:
555                         res /= log((float) cnt_length(txt));
556                         break;
557                 case 2:
558                         res /= (float) cnt_length(txt);
559                         break;
560                 default:
561                         /* internal error */
562                         elog(ERROR, "unrecognized normalization method: %d", method);
563         }
564
565         pfree(doc);
566         PG_FREE_IF_COPY(txt, 1);
567         PG_FREE_IF_COPY(query, 2);
568
569         PG_RETURN_FLOAT4(res);
570 }
571
572
573 Datum
574 rank_cd_def(PG_FUNCTION_ARGS)
575 {
576         PG_RETURN_DATUM(DirectFunctionCall4(
577                                                                                 rank_cd,
578                                                                                 Int32GetDatum(-1),
579                                                                                 PG_GETARG_DATUM(0),
580                                                                                 PG_GETARG_DATUM(1),
581                                                                                 (PG_NARGS() == 3) ? PG_GETARG_DATUM(2) : Int32GetDatum(DEF_NORM_METHOD)
582                                                                                 ));
583 }
584
585 /**************debug*************/
586
587 typedef struct
588 {
589         char       *w;
590         int2            len;
591         int2            pos;
592         int2            start;
593         int2            finish;
594 }       DocWord;
595
596 static int
597 compareDocWord(const void *a, const void *b)
598 {
599         if (((DocWord *) a)->pos == ((DocWord *) b)->pos)
600                 return 1;
601         return (((DocWord *) a)->pos > ((DocWord *) b)->pos) ? 1 : -1;
602 }
603
604
605 Datum
606 get_covers(PG_FUNCTION_ARGS)
607 {
608         tsvector   *txt = (tsvector *) PG_DETOAST_DATUM(PG_GETARG_DATUM(0));
609         QUERYTYPE  *query = (QUERYTYPE *) PG_DETOAST_DATUM(PG_GETARG_DATUM(1));
610         WordEntry  *pptr = ARRPTR(txt);
611         int                     i,
612                                 dlen = 0,
613                                 j,
614                                 cur = 0,
615                                 len = 0,
616                                 rlen;
617         DocWord    *dw,
618                            *dwptr;
619         text       *out;
620         char       *cptr;
621         DocRepresentation *doc;
622         int                     pos = 0,
623                                 p,
624                                 q,
625                                 olddwpos = 0;
626         int                     ncover = 1;
627
628         doc = get_docrep(txt, query, &rlen);
629
630         if (!doc)
631         {
632                 out = palloc(VARHDRSZ);
633                 VARATT_SIZEP(out) = VARHDRSZ;
634                 PG_FREE_IF_COPY(txt, 0);
635                 PG_FREE_IF_COPY(query, 1);
636                 PG_RETURN_POINTER(out);
637         }
638
639         for (i = 0; i < txt->size; i++)
640         {
641                 if (!pptr[i].haspos)
642                         ereport(ERROR,
643                                         (errcode(ERRCODE_SYNTAX_ERROR),
644                                          errmsg("no pos info")));
645                 dlen += POSDATALEN(txt, &(pptr[i]));
646         }
647
648         dwptr = dw = palloc(sizeof(DocWord) * dlen);
649         memset(dw, 0, sizeof(DocWord) * dlen);
650
651         for (i = 0; i < txt->size; i++)
652         {
653                 WordEntryPos *posdata = POSDATAPTR(txt, &(pptr[i]));
654
655                 for (j = 0; j < POSDATALEN(txt, &(pptr[i])); j++)
656                 {
657                         dw[cur].w = STRPTR(txt) + pptr[i].pos;
658                         dw[cur].len = pptr[i].len;
659                         dw[cur].pos = posdata[j].pos;
660                         cur++;
661                 }
662                 len += (pptr[i].len + 1) * (int) POSDATALEN(txt, &(pptr[i]));
663         }
664         qsort((void *) dw, dlen, sizeof(DocWord), compareDocWord);
665
666         while (Cover(doc, rlen, query, &pos, &p, &q))
667         {
668                 dwptr = dw + olddwpos;
669                 while (dwptr->pos < p && dwptr - dw < dlen)
670                         dwptr++;
671                 olddwpos = dwptr - dw;
672                 dwptr->start = ncover;
673                 while (dwptr->pos < q + 1 && dwptr - dw < dlen)
674                         dwptr++;
675                 (dwptr - 1)->finish = ncover;
676                 len += 4 /* {}+two spaces */ + 2 * 16 /* numbers */ ;
677                 ncover++;
678         }
679
680         out = palloc(VARHDRSZ + len);
681         cptr = ((char *) out) + VARHDRSZ;
682         dwptr = dw;
683
684         while (dwptr - dw < dlen)
685         {
686                 if (dwptr->start)
687                 {
688                         sprintf(cptr, "{%d ", dwptr->start);
689                         cptr = strchr(cptr, '\0');
690                 }
691                 memcpy(cptr, dwptr->w, dwptr->len);
692                 cptr += dwptr->len;
693                 *cptr = ' ';
694                 cptr++;
695                 if (dwptr->finish)
696                 {
697                         sprintf(cptr, "}%d ", dwptr->finish);
698                         cptr = strchr(cptr, '\0');
699                 }
700                 dwptr++;
701         }
702
703         VARATT_SIZEP(out) = cptr - ((char *) out);
704
705         pfree(dw);
706         pfree(doc);
707
708         PG_FREE_IF_COPY(txt, 0);
709         PG_FREE_IF_COPY(query, 1);
710         PG_RETURN_POINTER(out);
711 }