From e813e0e16852c080259cd0813e1a82ecb2625aea Mon Sep 17 00:00:00 2001 From: John Naylor Date: Sat, 20 Aug 2022 21:14:01 -0700 Subject: [PATCH] Add optimized functions for linear search within byte arrays In similar vein to b6ef167564, add pg_lfind8() and pg_lfind8_le() to search for bytes equal or less-than-or-equal to a given byte, respectively. To abstract away platform details, add helper functions and typedefs to simd.h. John Naylor and Nathan Bossart, per suggestion from Andres Freund Discussion: https://www.postgresql.org/message-id/CAFBsxsGzaaGLF%3DNuq61iRXTyspbO9rOjhSqFN%3DV6ozzmta5mXg%40mail.gmail.com --- src/include/port/pg_lfind.h | 68 ++++++- src/include/port/simd.h | 168 +++++++++++++++++- .../test_lfind/expected/test_lfind.out | 18 +- .../modules/test_lfind/sql/test_lfind.sql | 4 +- .../modules/test_lfind/test_lfind--1.0.sql | 10 +- src/test/modules/test_lfind/test_lfind.c | 100 ++++++++++- 6 files changed, 358 insertions(+), 10 deletions(-) diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h index fb125977b2e..a4e13dffec0 100644 --- a/src/include/port/pg_lfind.h +++ b/src/include/port/pg_lfind.h @@ -1,7 +1,8 @@ /*------------------------------------------------------------------------- * * pg_lfind.h - * Optimized linear search routines. + * Optimized linear search routines using SIMD intrinsics where + * available. * * Copyright (c) 2022, PostgreSQL Global Development Group * @@ -15,6 +16,70 @@ #include "port/simd.h" +/* + * pg_lfind8 + * + * Return true if there is an element in 'base' that equals 'key', otherwise + * return false. + */ +static inline bool +pg_lfind8(uint8 key, uint8 *base, uint32 nelem) +{ + uint32 i; + + /* round down to multiple of vector length */ + uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1); + Vector8 chunk; + + for (i = 0; i < tail_idx; i += sizeof(Vector8)) + { + vector8_load(&chunk, &base[i]); + if (vector8_has(chunk, key)) + return true; + } + + /* Process the remaining elements one at a time. 
*/ + for (; i < nelem; i++) + { + if (key == base[i]) + return true; + } + + return false; +} + +/* + * pg_lfind8_le + * + * Return true if there is an element in 'base' that is less than or equal to + * 'key', otherwise return false. + */ +static inline bool +pg_lfind8_le(uint8 key, uint8 *base, uint32 nelem) +{ + uint32 i; + + /* round down to multiple of vector length */ + uint32 tail_idx = nelem & ~(sizeof(Vector8) - 1); + Vector8 chunk; + + for (i = 0; i < tail_idx; i += sizeof(Vector8)) + { + vector8_load(&chunk, &base[i]); + if (vector8_has_le(chunk, key)) + return true; + } + + /* Process the remaining elements one at a time. */ + for (; i < nelem; i++) + { + if (base[i] <= key) + return true; + } + + return false; +} + /* * pg_lfind32 * @@ -26,7 +91,6 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem) { uint32 i = 0; - /* Use SIMD intrinsics where available. */ #ifdef USE_SSE2 /* diff --git a/src/include/port/simd.h b/src/include/port/simd.h index a571e79f574..61e4362258b 100644 --- a/src/include/port/simd.h +++ b/src/include/port/simd.h @@ -8,11 +8,17 @@ * * src/include/port/simd.h * + * NOTES + * - VectorN in this file refers to a register where the element operands + * are N bits wide. The vector width is platform-specific, so users that care + * about that will need to inspect "sizeof(VectorN)". + * *------------------------------------------------------------------------- */ #ifndef SIMD_H #define SIMD_H +#if (defined(__x86_64__) || defined(_M_AMD64)) /* * SSE2 instructions are part of the spec for the 64-bit x86 ISA. We assume * that compilers targeting this architecture understand SSE2 intrinsics. @@ -22,9 +28,169 @@ * will allow the use of intrinsics that haven't been enabled at compile * time. 
*/ -#if (defined(__x86_64__) || defined(_M_AMD64)) #include <emmintrin.h> #define USE_SSE2 +typedef __m128i Vector8; + +#else +/* + * If no SIMD instructions are available, we can in some cases emulate vector + * operations using bitwise operations on unsigned integers. + */ +#define USE_NO_SIMD +typedef uint64 Vector8; #endif + +/* load/store operations */ +static inline void vector8_load(Vector8 *v, const uint8 *s); + +/* assignment operations */ +static inline Vector8 vector8_broadcast(const uint8 c); + +/* element-wise comparisons to a scalar */ +static inline bool vector8_has(const Vector8 v, const uint8 c); +static inline bool vector8_has_zero(const Vector8 v); +static inline bool vector8_has_le(const Vector8 v, const uint8 c); + + +/* + * Load a chunk of memory into the given vector. + */ +static inline void +vector8_load(Vector8 *v, const uint8 *s) +{ +#if defined(USE_SSE2) + *v = _mm_loadu_si128((const __m128i *) s); +#else + memcpy(v, s, sizeof(Vector8)); +#endif +} + + +/* + * Create a vector with all elements set to the same value. + */ +static inline Vector8 +vector8_broadcast(const uint8 c) +{ +#if defined(USE_SSE2) + return _mm_set1_epi8(c); +#else + return ~UINT64CONST(0) / 0xFF * c; +#endif +} + +/* + * Return true if any elements in the vector are equal to the given scalar. 
+ */ +static inline bool +vector8_has(const Vector8 v, const uint8 c) +{ + bool result; + + /* pre-compute the result for assert checking */ +#ifdef USE_ASSERT_CHECKING + bool assert_result = false; + + for (int i = 0; i < sizeof(Vector8); i++) + { + if (((const uint8 *) &v)[i] == c) + { + assert_result = true; + break; + } + } +#endif /* USE_ASSERT_CHECKING */ + +#if defined(USE_NO_SIMD) + /* any bytes in v equal to c will evaluate to zero via XOR */ + result = vector8_has_zero(v ^ vector8_broadcast(c)); +#elif defined(USE_SSE2) + result = _mm_movemask_epi8(_mm_cmpeq_epi8(v, vector8_broadcast(c))); +#endif + + Assert(assert_result == result); + return result; +} + +/* + * Convenience function equivalent to vector8_has(v, 0) + */ +static inline bool +vector8_has_zero(const Vector8 v) +{ +#if defined(USE_NO_SIMD) + /* + * We cannot call vector8_has() here, because that would lead to a circular + * definition. + */ + return vector8_has_le(v, 0); +#elif defined(USE_SSE2) + return vector8_has(v, 0); +#endif +} + +/* + * Return true if any elements in the vector are less than or equal to the + * given scalar. + */ +static inline bool +vector8_has_le(const Vector8 v, const uint8 c) +{ + bool result = false; +#if defined(USE_SSE2) + __m128i sub; +#endif + + /* pre-compute the result for assert checking */ +#ifdef USE_ASSERT_CHECKING + bool assert_result = false; + + for (int i = 0; i < sizeof(Vector8); i++) + { + if (((const uint8 *) &v)[i] <= c) + { + assert_result = true; + break; + } + } +#endif /* USE_ASSERT_CHECKING */ + +#if defined(USE_NO_SIMD) + + /* + * To find bytes <= c, we can use bitwise operations to find bytes < c+1, + * but it only works if c+1 <= 128 and if the highest bit in v is not set. 
+ * Adapted from + * https://graphics.stanford.edu/~seander/bithacks.html#HasLessInWord + */ + if ((int64) v >= 0 && c < 0x80) + result = (v - vector8_broadcast(c + 1)) & ~v & vector8_broadcast(0x80); + else + { + /* one byte at a time */ + for (int i = 0; i < sizeof(Vector8); i++) + { + if (((const uint8 *) &v)[i] <= c) + { + result = true; + break; + } + } + } +#elif defined(USE_SSE2) + + /* + * Use saturating subtraction to find bytes <= c, which will present as + * NUL bytes in 'sub'. + */ + sub = _mm_subs_epu8(v, vector8_broadcast(c)); + result = vector8_has_zero(sub); +#endif + + Assert(assert_result == result); + return result; +} + #endif /* SIMD_H */ diff --git a/src/test/modules/test_lfind/expected/test_lfind.out b/src/test/modules/test_lfind/expected/test_lfind.out index 222c8fd7fff..1d4b14e7032 100644 --- a/src/test/modules/test_lfind/expected/test_lfind.out +++ b/src/test/modules/test_lfind/expected/test_lfind.out @@ -4,9 +4,21 @@ CREATE EXTENSION test_lfind; -- the operations complete without crashing or hanging and that none of their -- internal sanity tests fail. -- -SELECT test_lfind(); - test_lfind ------------- +SELECT test_lfind8(); + test_lfind8 +------------- + +(1 row) + +SELECT test_lfind8_le(); + test_lfind8_le +---------------- + +(1 row) + +SELECT test_lfind32(); + test_lfind32 +-------------- (1 row) diff --git a/src/test/modules/test_lfind/sql/test_lfind.sql b/src/test/modules/test_lfind/sql/test_lfind.sql index 899f1dd49bf..766c640831f 100644 --- a/src/test/modules/test_lfind/sql/test_lfind.sql +++ b/src/test/modules/test_lfind/sql/test_lfind.sql @@ -5,4 +5,6 @@ CREATE EXTENSION test_lfind; -- the operations complete without crashing or hanging and that none of their -- internal sanity tests fail. 
-- -SELECT test_lfind(); +SELECT test_lfind8(); +SELECT test_lfind8_le(); +SELECT test_lfind32(); diff --git a/src/test/modules/test_lfind/test_lfind--1.0.sql b/src/test/modules/test_lfind/test_lfind--1.0.sql index d82ab0567ef..81801926ae8 100644 --- a/src/test/modules/test_lfind/test_lfind--1.0.sql +++ b/src/test/modules/test_lfind/test_lfind--1.0.sql @@ -3,6 +3,14 @@ -- complain if script is sourced in psql, rather than via CREATE EXTENSION \echo Use "CREATE EXTENSION test_lfind" to load this file. \quit -CREATE FUNCTION test_lfind() +CREATE FUNCTION test_lfind32() + RETURNS pg_catalog.void + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE FUNCTION test_lfind8() + RETURNS pg_catalog.void + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE FUNCTION test_lfind8_le() RETURNS pg_catalog.void AS 'MODULE_PATHNAME' LANGUAGE C; diff --git a/src/test/modules/test_lfind/test_lfind.c b/src/test/modules/test_lfind/test_lfind.c index a000746fb83..82673d54c6e 100644 --- a/src/test/modules/test_lfind/test_lfind.c +++ b/src/test/modules/test_lfind/test_lfind.c @@ -16,12 +16,108 @@ #include "fmgr.h" #include "port/pg_lfind.h" +/* + * Convenience macros for testing both vector and scalar operations. 
The 2x + * factor is to make sure iteration works + */ +#define LEN_NO_TAIL(vectortype) (2 * sizeof(vectortype)) +#define LEN_WITH_TAIL(vectortype) (LEN_NO_TAIL(vectortype) + 3) + PG_MODULE_MAGIC; -PG_FUNCTION_INFO_V1(test_lfind); +/* workhorse for test_lfind8 */ +static void +test_lfind8_internal(uint8 key) +{ + uint8 charbuf[LEN_WITH_TAIL(Vector8)]; + const int len_no_tail = LEN_NO_TAIL(Vector8); + const int len_with_tail = LEN_WITH_TAIL(Vector8); + memset(charbuf, 0xFF, len_with_tail); + /* search tail to test one-byte-at-a-time path */ + charbuf[len_with_tail - 1] = key; + if (key > 0x00 && pg_lfind8(key - 1, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8() found nonexistent element '0x%x'", key - 1); + if (key < 0xFF && !pg_lfind8(key, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8() did not find existing element '0x%x'", key); + if (key < 0xFE && pg_lfind8(key + 1, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8() found nonexistent element '0x%x'", key + 1); + + memset(charbuf, 0xFF, len_with_tail); + /* search with vector operations */ + charbuf[len_no_tail - 1] = key; + if (key > 0x00 && pg_lfind8(key - 1, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8() found nonexistent element '0x%x'", key - 1); + if (key < 0xFF && !pg_lfind8(key, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8() did not find existing element '0x%x'", key); + if (key < 0xFE && pg_lfind8(key + 1, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8() found nonexistent element '0x%x'", key + 1); +} + +PG_FUNCTION_INFO_V1(test_lfind8); Datum -test_lfind(PG_FUNCTION_ARGS) +test_lfind8(PG_FUNCTION_ARGS) +{ + test_lfind8_internal(0); + test_lfind8_internal(1); + test_lfind8_internal(0x7F); + test_lfind8_internal(0x80); + test_lfind8_internal(0x81); + test_lfind8_internal(0xFD); + test_lfind8_internal(0xFE); + test_lfind8_internal(0xFF); + + PG_RETURN_VOID(); +} + +/* workhorse for test_lfind8_le */ +static void +test_lfind8_le_internal(uint8 key) +{ + uint8 charbuf[LEN_WITH_TAIL(Vector8)]; 
+ const int len_no_tail = LEN_NO_TAIL(Vector8); + const int len_with_tail = LEN_WITH_TAIL(Vector8); + + memset(charbuf, 0xFF, len_with_tail); + /* search tail to test one-byte-at-a-time path */ + charbuf[len_with_tail - 1] = key; + if (key > 0x00 && pg_lfind8_le(key - 1, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8_le() found nonexistent element <= '0x%x'", key - 1); + if (key < 0xFF && !pg_lfind8_le(key, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8_le() did not find existing element <= '0x%x'", key); + if (key < 0xFE && !pg_lfind8_le(key + 1, charbuf, len_with_tail)) + elog(ERROR, "pg_lfind8_le() did not find existing element <= '0x%x'", key + 1); + + memset(charbuf, 0xFF, len_with_tail); + /* search with vector operations */ + charbuf[len_no_tail - 1] = key; + if (key > 0x00 && pg_lfind8_le(key - 1, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8_le() found nonexistent element <= '0x%x'", key - 1); + if (key < 0xFF && !pg_lfind8_le(key, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8_le() did not find existing element <= '0x%x'", key); + if (key < 0xFE && !pg_lfind8_le(key + 1, charbuf, len_no_tail)) + elog(ERROR, "pg_lfind8_le() did not find existing element <= '0x%x'", key + 1); +} + +PG_FUNCTION_INFO_V1(test_lfind8_le); +Datum +test_lfind8_le(PG_FUNCTION_ARGS) +{ + test_lfind8_le_internal(0); + test_lfind8_le_internal(1); + test_lfind8_le_internal(0x7F); + test_lfind8_le_internal(0x80); + test_lfind8_le_internal(0x81); + test_lfind8_le_internal(0xFD); + test_lfind8_le_internal(0xFE); + test_lfind8_le_internal(0xFF); + + PG_RETURN_VOID(); +} + +PG_FUNCTION_INFO_V1(test_lfind32); +Datum +test_lfind32(PG_FUNCTION_ARGS) { #define TEST_ARRAY_SIZE 135 uint32 test_array[TEST_ARRAY_SIZE] = {0};