Make `wildspeed` a modern Postgres extension.
[wildspeed.git] / wildspeed.c
1 #include "postgres.h"
2
3 #include "catalog/pg_type.h"
4 #include "mb/pg_wchar.h"
5 #include "utils/array.h"
6
/*
 * MARK_SIGN marks the end of the original string inside a permuted key and
 * has to be handled like a regular character. The best candidate is the
 * zero byte, which is accepted in any locale used by Postgres. A zero byte
 * cannot be displayed, though, so for output it is replaced by a clearly
 * visible character (MARK_SIGN_SHOW). MARK_SIGN_SHOW itself cannot serve as
 * the mark because it is allowed to occur inside a string.
 */

#define MARK_SIGN		'\0'
#define MARK_SIGN_SHOW	'$'


/* Placement flags of a constant part of a wildcard query */

/* part should be at the beginning of the string */
#define WC_BEGIN		0x01
/* part should be in the middle of the string */
#define WC_MIDDLE		0x02
/* part should be at the end of the string */
#define WC_END			0x04
23
24 PG_MODULE_MAGIC;
25
26 static text*
27 appendStrToText( text *src, char *str, int32 len, int32 maxlen )
28 {
29         int32   curlen;
30
31         if (src == NULL )
32         {
33                 Assert( maxlen >= 0 );
34                 src = (text*)palloc( VARHDRSZ + sizeof(char*) * maxlen );
35                 SET_VARSIZE(src, 0 + VARHDRSZ);
36         }
37
38         curlen = VARSIZE(src) - VARHDRSZ;
39
40         if (len>0)
41                 memcpy( VARDATA(src) + curlen, str, len );
42
43         SET_VARSIZE(src, curlen + len + VARHDRSZ);
44
45         return src;
46 }
47
48 static text*
49 appendMarkToText( text *src, int32 maxlen )
50 {
51         char sign = MARK_SIGN;
52
53         return appendStrToText( src, &sign, 1, maxlen );
54 }
55
56 static text*
57 setFlagOfText( char flag, int32 maxlen )
58 {
59         char flagstruct[2];
60
61         Assert( maxlen > 0 );
62         /*
63          * Mark text by setting first byte to MARK_SIGN to indicate
64          * that text has flags. It's a safe for non empty string, 
65          * because first character can not be a MARK_SIGN (see
66          * gin_extract_permuted() )
67          */
68
69         flagstruct[0] = MARK_SIGN;
70         flagstruct[1] = flag;
71         
72         return appendStrToText(NULL, flagstruct, 2, maxlen );
73 }
74
75 PG_FUNCTION_INFO_V1(gin_extract_permuted);
76 Datum           gin_extract_permuted(PG_FUNCTION_ARGS);
77 Datum
78 gin_extract_permuted(PG_FUNCTION_ARGS)
79 {
80         text    *src = PG_GETARG_TEXT_P(0);
81         int32   *nentries = (int32 *) PG_GETARG_POINTER(1);
82         Datum   *entries = NULL;
83         int32   srclen = pg_mbstrlen_with_len(VARDATA(src), VARSIZE(src) - VARHDRSZ);
84
85         *nentries = srclen;
86
87         if ( srclen == 0 )
88         {
89                 /*
90                  * Empty string is encoded by alone MARK_SIGN character
91                  */
92                 *nentries = 1;
93                 entries = (Datum*) palloc(sizeof(Datum));
94                 entries[0] = PointerGetDatum( appendMarkToText( NULL, 1 ) );
95         }
96         else
97         {
98                 text    *dst;
99                 int32   i, 
100                                 offset=0; /* offset to current position in src in bytes */ 
101                 int32   nbytes = VARSIZE(src) - VARHDRSZ;
102                 char    *srcptr = VARDATA(src);
103
104                 /*
105                  * Permutation: hello will be permuted to hello$, ello$h, llo$he, lo$hel, o$hell.
106                  * So, number of entries is equial to number of characters (not a bytes)
107                  */
108
109                 entries = (Datum*)palloc(sizeof(char*) * nbytes );
110                 for(i=0; i<srclen;i++) {
111                 
112                         /*
113                          * Copy first part. For llo$he it will be 'llo'
114                          */
115                         dst = appendStrToText( NULL, srcptr + offset, nbytes - offset, nbytes + 1 ); 
116
117                         /*
118                          * Set mark sign ($)
119                          */
120                         dst = appendMarkToText( dst, -1 );
121
122                         /*
123                          * Copy rest of string (in example above 'he')
124                          */
125                         dst = appendStrToText( dst, srcptr, offset, -1 );
126
127                         entries[i] = PointerGetDatum(dst);
128                         offset += pg_mblen( srcptr + offset );
129                 }
130         }
131
132         PG_FREE_IF_COPY(src,0);
133         PG_RETURN_POINTER(entries);
134 }
135
136 static int 
137 wildcmp_internal(text *a, text *b, bool partialMatch)
138 {
139         int32   cmp;
140         int             lena,
141                         lenb;
142         char    *ptra = VARDATA(a),
143                         *ptrb = VARDATA(b);
144         char    flag = 0;
145
146         lena = VARSIZE(a) - VARHDRSZ;
147         lenb = VARSIZE(b) - VARHDRSZ;
148
149         /*
150          * sets correct pointers and lengths in case of flags
151          * presence
152          */
153         if ( lena > 2 && *ptra == MARK_SIGN )
154         {
155                 flag = *(ptra+1);
156                 ptra+=2;
157                 lena-=2;
158
159                 if ( lenb > 2 && *ptrb == MARK_SIGN )
160                 {
161                         /*
162                          * If they have different flags then they can not be equal, this 
163                          * place works only during check of equality of keys
164                          * to search
165                          */
166                         if ( flag != *(ptrb+1) )
167                                 return 1;
168                         ptrb+=2;
169                         lenb-=2;
170
171                         /* b can not be a product of gin_extract_wildcard for partial match mode */
172                         Assert( partialMatch == false );
173                 }
174         } 
175         else  if ( lenb > 2 && *ptrb == MARK_SIGN )
176         {
177                 /* b can not be a product of gin_extract_wildcard for partial match mode */
178                 Assert( partialMatch == false );
179
180                 ptrb+=2;
181                 lenb-=2;
182         }
183
184         if ( lena == 0 )
185         {
186                 if ( partialMatch )
187                         cmp = 0; /* full scan for partialMatch*/
188                 else
189                         cmp = (lenb>0) ? -1 : 0;
190         }
191         else
192         {
193                 /*
194                  * We couldn't use strcmp because of MARK_SIGN
195                  */
196                 cmp = memcmp(ptra, ptrb, Min(lena, lenb));
197
198                 if ( partialMatch )
199                 {
200                         if ( cmp == 0 )
201                         {
202                                 if ( lena > lenb )
203                                 {
204                                         /*
205                                          * b argument is not beginning with argument a
206                                          */
207                                         cmp = 1;
208                                 }
209                                 else if ( flag > 0 && lenb>lena /* be safe */ )
210                                 { /* there is some flags to check */
211                                         char    actualFlag;
212
213                                         if ( ptrb[ lenb - 1 ] == MARK_SIGN )
214                                                 actualFlag = WC_BEGIN;  
215                                         else if ( ptrb[ lena ] == MARK_SIGN )
216                                                 actualFlag = WC_END;
217                                         else
218                                                 actualFlag = WC_MIDDLE;
219
220                                         if ( (flag & actualFlag) == 0 )
221                                         {
222                                                 /* 
223                                                  * Prefix are matched but this prefix s not placed as needed.
224                                                  * so we should give a smoke signal to GIN that we don't want
225                                                  * this match, but wish to continue scan 
226                                                  */
227                                                 cmp = -1;
228                                         }
229                                 }
230                         } 
231                         else if (cmp < 0)
232                         {
233                                 cmp = 1; /* prevent continue scan */
234                         }
235                 } 
236                 else if ( (cmp == 0) && (lena != lenb) )
237                 {
238                         cmp = (lena < lenb) ? -1 : 1;
239                 }
240         }
241
242         return cmp;
243 }
244
245 PG_FUNCTION_INFO_V1(wildcmp);
246 Datum       wildcmp(PG_FUNCTION_ARGS);
247 Datum
248 wildcmp(PG_FUNCTION_ARGS)
249 {
250         text    *a = PG_GETARG_TEXT_P(0);
251         text    *b = PG_GETARG_TEXT_P(1);
252         int32   cmp;
253
254         cmp = wildcmp_internal(a, b, false);
255
256         PG_FREE_IF_COPY(a,0);
257         PG_FREE_IF_COPY(b,1);
258         PG_RETURN_INT32( cmp ); 
259 }
260
261 PG_FUNCTION_INFO_V1(wildcmp_prefix);
262 Datum       wildcmp_prefix(PG_FUNCTION_ARGS);
263 Datum
264 wildcmp_prefix(PG_FUNCTION_ARGS)
265 {
266         text    *a = PG_GETARG_TEXT_P(0);
267         text    *b = PG_GETARG_TEXT_P(1);
268 #ifdef NOT_USED
269         StrategyNumber  strategy = PG_GETARG_UINT16(2);
270 #endif
271         int32   cmp;
272
273         cmp = wildcmp_internal(a, b, true);
274
275         PG_FREE_IF_COPY(a,0);
276         PG_FREE_IF_COPY(b,1);
277         PG_RETURN_INT32( cmp ); 
278 }
279
280 #ifdef OPTIMIZE_WILDCARD_QUERY
281
282 typedef struct 
283 {
284         Datum   entry;
285         int32   len;
286         char    flag;
287 } OptItem;
288
289
290 /*
291  * Function drops most short search word to speedup 
292  * index search by preventing use word which gives
293  * a lot of matches
294  */
295 static void 
296 optimize_wildcard_search( Datum *entries, int32 *nentries )
297 {
298         int32   maxlen=0;
299         OptItem *items;
300         int             i, nitems = *nentries;
301         char    *ptr,*p;
302
303         items = (OptItem*)palloc( sizeof(OptItem) * (*nentries) );
304         for(i=0;i<nitems;i++)
305         {
306                 items[i].entry = entries[i];
307                 items[i].len = VARSIZE(entries[i]) - VARHDRSZ;
308                 ptr = VARDATA(entries[i]);
309
310                 if ( items[i].len > 2 && *ptr == MARK_SIGN )
311                 {
312                         items[i].len-=2;
313                         items[i].flag = *(ptr+1);
314                 }
315                 else
316                 {
317                         items[i].flag = 0;
318                         if ( items[i].len > 1 && (p=strchr(ptr, MARK_SIGN)) != NULL )
319                         {
320                                 if ( p == ptr + items[i].len -1 )
321                                         items[i].flag = WC_BEGIN;
322                                 else 
323                                         items[i].flag = WC_BEGIN | WC_END;
324                         }
325                 }
326
327                 if ( items[i].len > maxlen )
328                         maxlen = items[i].len;
329         }
330         
331         *nentries=0;
332
333         for(i=0;i<nitems;i++)
334         {
335                 if ( (items[i].flag & WC_BEGIN) && (items[i].flag & WC_END) )
336                 {       /* X$Y use always */
337                         entries[ *nentries ] = items[i].entry;
338                         (*nentries)++;
339                 }
340                 else if ( (items[i].flag & WC_MIDDLE) == 0 )
341                 { 
342                         /* 
343                          * for begin-only or end-only word we set more low limit than for 
344                          * other variants
345                          */
346                         if ( 3*items[i].len > maxlen )
347                         {
348                                 entries[ *nentries ] = items[i].entry;
349                                 (*nentries)++;
350                         }
351                 }
352                 else if ( 2*items[i].len > maxlen )
353                 {       
354                         /* 
355                          * use only items with biggest length 
356                          */
357                         entries[ *nentries ] = items[i].entry;
358                         (*nentries)++;
359                 }
360         }
361
362         Assert( *nentries>0 );
363
364 }
365 #endif
366
367 typedef struct 
368 {
369         bool    iswildcard;
370         int32   len;
371         char    *ptr;
372 } WildItem;
373
374 PG_FUNCTION_INFO_V1(gin_extract_wildcard);
375 Datum           gin_extract_wildcard(PG_FUNCTION_ARGS);
376 Datum
377 gin_extract_wildcard(PG_FUNCTION_ARGS)
378 {
379         text                    *q = PG_GETARG_TEXT_P(0);
380         int32                   lenq = VARSIZE(q) - VARHDRSZ;
381         int32                   *nentries = (int32 *) PG_GETARG_POINTER(1);
382 #ifdef NOT_USED
383         StrategyNumber  strategy = PG_GETARG_UINT16(2);
384 #endif
385         bool                    *partialmatch, 
386                                         **ptr_partialmatch = (bool**) PG_GETARG_POINTER(3);
387         Pointer                 **extra = (Pointer**)PG_GETARG_POINTER(4);
388         Datum                   *entries = NULL;
389         char                    *qptr = VARDATA(q);
390         int                             clen,
391                                         splitqlen = 0,
392                                         i;
393         WildItem                *items;
394         text                    *entry;
395         bool                    needRecheck = false;
396
397         *nentries = 0;
398
399         if ( lenq == 0 )
400         {
401                 partialmatch = *ptr_partialmatch = (bool*)palloc0(sizeof(bool));
402                 *nentries = 1;
403                 entries = (Datum*) palloc(sizeof(Datum));
404                 entries[0] = PointerGetDatum( appendMarkToText( NULL, 1 ) );
405
406                 PG_RETURN_POINTER(entries);
407         }
408
409         partialmatch = *ptr_partialmatch = (bool*)palloc0(sizeof(bool) * lenq);
410         entries = (Datum*) palloc(sizeof(Datum) * lenq);
411         items=(WildItem*) palloc0( sizeof(WildItem) * lenq );
412
413
414         /*
415          * Parse expression to the list of constant parts and
416          * wildcards
417          */
418         while( qptr - VARDATA(q) < lenq )
419         {
420                 clen = pg_mblen(qptr);
421
422                 if ( clen==1 && (*qptr == '_' || *qptr == '%' ) )
423                 {
424                         if ( splitqlen == 0 )
425                         {
426                                 items[ splitqlen ].iswildcard = true;
427                                 splitqlen++;
428                         } 
429                         else if ( items[ splitqlen-1 ].iswildcard == false )
430                         {
431                                 items[ splitqlen-1 ].len = qptr - items[ splitqlen-1 ].ptr;
432                                 items[ splitqlen ].iswildcard = true;
433                                 splitqlen++;
434                         }
435                         /*
436                          * ignore wildcard, because we don't make difference beetween
437                          * %, _ or a combination of its
438                          */
439                 }
440                 else
441                 {
442                         if ( splitqlen == 0 || items[ splitqlen-1 ].iswildcard == true )
443                         {
444                                 items[ splitqlen ].ptr = qptr;
445                                 splitqlen++;
446                         }
447                 }
448                 qptr += clen;
449         }
450
451         Assert( splitqlen >= 1 );
452         if ( items[ splitqlen-1 ].iswildcard == false )
453                 items[ splitqlen-1 ].len = qptr - items[ splitqlen-1 ].ptr;
454
455         if ( items[ 0 ].iswildcard == false )
456         {
457                 /* X... */
458                 if ( splitqlen == 1 )
459                 {
460                         /*   X => X$, exact match */
461                         *nentries = 1;
462                         entry = appendStrToText(NULL, items[ 0 ].ptr, items[ 0 ].len, lenq+1);
463                         entry = appendMarkToText( entry, -1 );
464                         entries[0] = PointerGetDatum( entry );
465                 } 
466                 else if ( items[ splitqlen-1 ].iswildcard == false ) 
467                 {
468                         /*   X * [X1 * [] ] ] Y => Y$X* [ + X1* [] ] */
469
470                         *nentries = 1;
471                         entry = appendStrToText(NULL, items[ splitqlen-1 ].ptr, items[ splitqlen-1 ].len, lenq+1);
472                         entry = appendMarkToText( entry, -1 );
473                         entry = appendStrToText(entry, items[ 0 ].ptr, items[ 0 ].len, -1);
474                         partialmatch[0] = true;
475                         entries[0] = PointerGetDatum( entry );
476
477                         for(i=1; i<splitqlen-1; i++)
478                         {
479                                 if ( items[ i ].iswildcard )
480                                         continue;
481                                 entry = setFlagOfText( WC_MIDDLE, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ ); 
482                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1 );
483                                 partialmatch[ *nentries ] = true;
484                                 entries[ *nentries ] =  PointerGetDatum( entry );
485                                 (*nentries)++;
486                         }
487
488                         if ( splitqlen > 3 /* X1 may be inside X OR Y */ )
489                                 needRecheck = true;
490                 }
491                 else
492                 {
493                         /*   X * [ X1 * [] ]  => X*$ [ + X1* [] ] */
494                 
495                         entry = setFlagOfText( WC_BEGIN, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
496                         entry = appendStrToText(entry, items[ 0 ].ptr, items[ 0 ].len, -1);
497                         *nentries = 1;
498                         partialmatch[ 0 ] = true;
499                         entries[0] = PointerGetDatum( entry );
500
501                         for(i=2; i<splitqlen-1; i++)
502                         {
503                                 if ( items[ i ].iswildcard )
504                                         continue;
505                                 entry = setFlagOfText( (i==splitqlen-2) ? (WC_MIDDLE | WC_END) : WC_MIDDLE, 
506                                                                                 lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
507                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1);
508                                 partialmatch[ *nentries ] = true;
509                                 entries[ *nentries ] =  PointerGetDatum( entry );
510                                 (*nentries)++;
511                         }
512                         if ( splitqlen > 2 /* we don't remeber an order of Xn */ )
513                                 needRecheck = true;
514                 }
515         } 
516         else
517         {
518                 /* *...  */
519
520                 if ( splitqlen == 1 )
521                 {
522                         /* any word => full scan */
523                         *nentries = 1;
524                         entry = appendStrToText(NULL, "", 0, lenq+1);
525                         partialmatch[0] = true;
526                         entries[0] = PointerGetDatum( entry );
527                 }
528                 else if ( items[ splitqlen-1 ].iswildcard == false )
529                 {
530                         /*     * [ X1 * [] ] X  => X$* [ + X1* [] ]  */
531                         *nentries = 1;
532                         entry = appendStrToText(NULL, items[ splitqlen-1 ].ptr, items[ splitqlen-1 ].len, lenq+1);
533                         entry = appendMarkToText( entry, -1 );
534                         partialmatch[0] = true;
535                         entries[0] = PointerGetDatum( entry );
536
537                         for(i=1; i<splitqlen-1; i++)
538                         {
539                                 if ( items[ i ].iswildcard )
540                                         continue;
541                                 entry = setFlagOfText( (i==1) ? (WC_MIDDLE | WC_BEGIN) : WC_MIDDLE, 
542                                                                                 lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
543                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1);
544                                 partialmatch[ *nentries ] = true;
545                                 entries[ *nentries ] =  PointerGetDatum( entry );
546                                 (*nentries)++;
547                         }
548                         if ( splitqlen > 2 /* X1 may be inside X */ )
549                                 needRecheck = true;
550                 }
551                 else
552                 {
553                         /* * X [ * X1 [] ] * => X* [ + X1* [] ] */
554                         for(i=1; i<splitqlen-1; i++)
555                         {
556                                 if ( items[ i ].iswildcard )
557                                         continue;
558
559                                 if ( splitqlen > 3 )
560                                 {
561                                         if ( i==1 )
562                                                 entry = setFlagOfText( WC_MIDDLE | WC_BEGIN, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
563                                         else if ( i == splitqlen-2 )
564                                                 entry = setFlagOfText( WC_MIDDLE | WC_END, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
565                                         else
566                                                 entry = setFlagOfText( WC_MIDDLE, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ ); 
567                                 }
568                                 else
569                                         entry = NULL;
570                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, lenq+1);
571                                 partialmatch[ *nentries ] = true;
572                                 entries[ *nentries ] =  PointerGetDatum( entry );
573                                 (*nentries)++;
574
575                         if ( splitqlen > 3 /* we don't remeber an order of Xn */ )
576                                 needRecheck = true;
577                         }
578                 }
579         }
580
581         PG_FREE_IF_COPY(q,0);
582
583 #ifdef OPTIMIZE_WILDCARD_QUERY
584         if ( *nentries > 1 )
585         {
586                 int saven = *nentries;
587
588                 optimize_wildcard_search( entries, nentries );
589
590                 if ( saven != *nentries )
591                         needRecheck = true;
592         }
593 #endif
594
595         if (needRecheck)
596         {
597                 /*
598                  * Non empty extra signals to consistentFn about
599                  * rechecking of result
600                  */
601                 *extra = palloc0(sizeof(Pointer) * *nentries);
602         }
603
604         PG_RETURN_POINTER(entries);
605 }
606
607
608 PG_FUNCTION_INFO_V1(gin_consistent_wildcard);
609 Datum       gin_consistent_wildcard(PG_FUNCTION_ARGS);
610 Datum
611 gin_consistent_wildcard(PG_FUNCTION_ARGS)
612 {
613         bool            *check = (bool *) PG_GETARG_POINTER(0);
614 #ifdef NOT_USED
615         StrategyNumber strategy = PG_GETARG_UINT16(1);
616         text            *query = PG_GETARG_TEXT_P(2);
617 #endif
618         int                     nentries = PG_GETARG_INT32(3);
619         Pointer         *extra = (Pointer *) PG_GETARG_POINTER(4);
620         bool        *recheck = (bool *) PG_GETARG_POINTER(5);
621         bool        res = true;
622         int         i;
623
624         for (i = 0; res && i < nentries; i++)
625                 if (check[i] == false)
626                         res = false;
627
628         *recheck = (extra == NULL) ? false : true;
629
630         PG_RETURN_BOOL(res);
631 }
632
633 /*
634  * Mostly debug fuction
635  */
636 PG_FUNCTION_INFO_V1(permute);
637 Datum       permute(PG_FUNCTION_ARGS);
638 Datum
639 permute(PG_FUNCTION_ARGS)
640 {
641         Datum           src = PG_GETARG_DATUM(0);
642         int32           nentries = 0;
643         Datum           *entries;
644         ArrayType       *res;
645         int             i;
646
647         /*
648          * Get permuted values by gin_extract_permuted()
649          */
650         entries = (Datum*) DatumGetPointer(DirectFunctionCall2(
651                                         gin_extract_permuted, src, PointerGetDatum(&nentries)
652                         ));
653
654         /*
655          * We need to replace MARK_SIGN to MARK_SIGN_SHOW.
656          * See comments above near definition of MARK_SIGN and MARK_SIGN_SHOW.
657          */
658         if ( nentries == 1 && VARSIZE(entries[0]) == VARHDRSZ + 1)
659         {
660                 *(VARDATA(entries[0])) = MARK_SIGN_SHOW;                
661         }
662         else
663         {
664                 int32   offset = 0; /* offset of MARK_SIGN */
665                 char    *ptr;
666
667                 /*
668                  * We scan array from the end because it allows simple calculation
669                  * of MARK_SIGN position: on every iteration it's moved one 
670                  * character to the end.
671                  */
672                 for(i=nentries-1;i>=0;i--) 
673                 {
674                         ptr = VARDATA(entries[i]);
675
676                         offset += pg_mblen(ptr);
677                         Assert( *(ptr + offset) == MARK_SIGN );
678                         *(ptr + offset) = MARK_SIGN_SHOW;
679                 }
680         }
681
682         res = construct_array(
683                                         entries,
684                                         nentries,
685                                         TEXTOID,
686                                         -1,
687                                         false,
688                                         'i'
689                         );
690
691         PG_RETURN_POINTER(res);
692 }