Skip to content

[libc] Better implementation of Strlen for SVE targets #167624

@SchrodingerZhu

Description

@SchrodingerZhu

With #167259, we added the textbook version of strlen. However, as mentioned in that PR, the SVE version is poorly optimized.

I managed to get a version where SVE performs better than the system libc (which uses the AOR ASIMD strlen) on very long strings, but it is rather complicated and still falls behind in the middle range. I hope someone who is very familiar with SVE optimization can help me look into this:

// -*- C++ -*-
// Standalone SVE/NEON/libc strlen microbenchmark
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <tuple>
#include <vector>

#include <arm_neon.h>
#include <arm_sve.h>

#define LIBC_LIKELY(x) __builtin_expect(!!(x), 1)
#define LIBC_UNLIKELY(x) __builtin_expect(!!(x), 0)

// -----------------------------------------------------------------------------
// NEON implementation
// -----------------------------------------------------------------------------
namespace neon {
/// Returns a 64-bit mask for one 8-byte block: byte lane i of the result is
/// 0xFF when byte i of `block` is zero and 0x00 otherwise.
static inline uint64_t zero_byte_mask(uint8x8_t block) {
  // Reinterpret the per-byte compare result as a single 64-bit lane before
  // extracting it; vget_lane_u64 takes a uint64x1_t, not a uint8x8_t.
  return vget_lane_u64(vreinterpret_u64_u8(vceqz_u8(block)), 0);
}

/// Searches for the NUL terminator of `src` one 8-byte block at a time.
///
/// The first load is taken from `src` rounded down to an 8-byte boundary, so
/// up to 7 bytes in front of `src` are read and then discarded by shifting
/// them out of the mask. Every load is an aligned 8-byte load.
///
/// NOTE(review): locating the first zero byte with a count-trailing-zeros on
/// the mask assumes little-endian lane order -- confirm for big-endian
/// targets.
///
/// @tparam MAX_STEP number of additional 8-byte blocks examined after the
///         first one; 0 means "no limit" (scan until a terminator is found).
/// @param src pointer to the NUL-terminated string.
/// @param len receives the string length when the terminator is found; it is
///        left untouched otherwise.
/// @return true when the terminator was found within the scanned blocks,
///         false when the MAX_STEP budget ran out first.
template <size_t MAX_STEP>
static inline bool string_length_impl(const char *src, size_t &len) {
  using Vector __attribute__((may_alias)) = uint8x8_t;

  const uintptr_t misalign_bytes =
      reinterpret_cast<uintptr_t>(src) % sizeof(Vector);
  const Vector *block_ptr =
      reinterpret_cast<const Vector *>(src - misalign_bytes);

  // First (possibly partial) block: drop the lanes that precede `src`.
  uint64_t cmp = zero_byte_mask(*block_ptr);
  cmp >>= (misalign_bytes << 3);
  if (cmp) {
    // Each byte lane occupies 8 mask bits, hence the division by 8.
    len = __builtin_ctzll(cmp) >> 3;
    return true;
  }

  for (size_t i = 0; i < MAX_STEP || MAX_STEP == 0; ++i) {
    ++block_ptr;
    cmp = zero_byte_mask(*block_ptr);
    if (cmp) {
      // Distance from `src` to the start of the block holding the terminator.
      const size_t base = reinterpret_cast<uintptr_t>(block_ptr) -
                          reinterpret_cast<uintptr_t>(src);
      len = base + (__builtin_ctzll(cmp) >> 3);
      return true;
    }
  }
  return false;
}

/// Computes strlen(src) with the unbounded NEON scan.
static inline size_t string_length(const char *src) {
  size_t len = 0;
  // MAX_STEP == 0 never gives up, so the result flag carries no information.
  (void)string_length_impl<0>(src, len);
  return len;
}
} // namespace neon

// -----------------------------------------------------------------------------
// SVE implementation
// -----------------------------------------------------------------------------
namespace sve {
// Computes the length of the NUL-terminated string `src`.
//
// Strategy, in order:
//   1. A bounded NEON probe (neon::string_length_impl<8>) for short strings.
//   2. Up to four first-faulting SVE loads, which may safely run up to or
//      across a page boundary.
//   3. An unrolled loop of four plain SVE loads per iteration, used only
//      while more than four vectors remain before the next 4096-byte
//      boundary.
// Steps 2 and 3 alternate until the terminator is found.
//
// NOTE(review): when the NEON probe fails it does not report how far it got
// and `len` is still 0, so the SVE phase starts again from `src` and
// re-scans the prefix the probe already covered.
[[maybe_unused]] static inline size_t string_length(const char *src) {

  size_t len = 0;
  // Short-string fast path: the first 8-byte block plus at most 8 more.
  if (neon::string_length_impl<8>(src, len))
    return len;

  // Smallest page size assumed; staying inside one such page keeps the plain
  // (faulting) loads below within a single page.
  constexpr size_t MINIMAL_PAGE_SIZE = 4096;
  // Number of bytes in one SVE vector.
  const size_t vlen = svcntb();
  const uint8_t *ptr = reinterpret_cast<const uint8_t *>(src);
  for (;;) {
    // Near to page boundary, do loads with fault suppressed.
    for (size_t i = 0; i < 4; ++i) {
      // Reset the first-fault register before every first-faulting load;
      // the order setffr -> ldff1 -> rdffr must be preserved.
      svsetffr();
      svuint8_t data = svldff1(svptrue_b8(), &ptr[len]);
      // Lanes that were actually loaded.
      svbool_t no_fault_mask = svrdffr_z(svptrue_b8());
      svbool_t cmp_zero = svcmpeq_n_u8(svptrue_b8(), data, 0);
      // Only consider zero bytes in lanes that were really loaded.
      bool has_zero = svptest_any(no_fault_mask, cmp_zero);
      if (has_zero) {
        // Predicate covering the loaded lanes before the first zero byte;
        // its population count is the offset of the terminator.
        svbool_t before_zero = svbrkb_z(no_fault_mask, cmp_zero);
        return len + svcntp_b8(no_fault_mask, before_zero);
      }
      // Advance by however many bytes were loaded (may be less than vlen).
      len += svcntp_b8(svptrue_b8(), no_fault_mask);
    }

    // Bytes left before the next MINIMAL_PAGE_SIZE boundary, in the range
    // [1, MINIMAL_PAGE_SIZE].
    size_t remaining =
        MINIMAL_PAGE_SIZE -
        (reinterpret_cast<uintptr_t>(ptr) + len) % MINIMAL_PAGE_SIZE;

    // Bulk loop: four full vectors per iteration, all inside the current
    // page, so ordinary loads are used.
    while (remaining > 4 * vlen) {
      svuint8_t fst = svld1_u8(svptrue_b8(), &ptr[len]);
      svuint8_t snd = svld1_u8(svptrue_b8(), &ptr[len + vlen]);
      svuint8_t trd = svld1_u8(svptrue_b8(), &ptr[len + 2 * vlen]);
      svuint8_t fth = svld1_u8(svptrue_b8(), &ptr[len + 3 * vlen]);
      // `len` is advanced before the check; the returns below subtract the
      // distance back to the vector that holds the terminator.
      len += 4 * vlen;
      remaining -= 4 * vlen;
      // The lane-wise unsigned minimum of the four vectors has a zero lane
      // exactly when at least one of them has a zero byte in that lane, so
      // one compare covers all four.
      svuint8_t min1 = svmin_x(svptrue_b8(), fst, snd);
      svuint8_t min2 = svmin_x(svptrue_b8(), trd, fth);
      svuint8_t min = svmin_x(svptrue_b8(), min1, min2);
      svbool_t cmp_zero = svcmpeq_n_u8(svptrue_b8(), min, 0);
      if (!svptest_any(svptrue_b8(), cmp_zero))
        continue;
      // A terminator is in one of the four vectors; test them in address
      // order so the first zero byte wins.
      svbool_t fst_cmp_zero = svcmpeq_n_u8(svptrue_b8(), fst, 0);
      if (svptest_any(svptrue_b8(), fst_cmp_zero)) {
        svbool_t before_zero = svbrkb_z(svptrue_b8(), fst_cmp_zero);
        return len + svcntp_b8(svptrue_b8(), before_zero) - 4 * vlen;
      }
      svbool_t snd_cmp_zero = svcmpeq_n_u8(svptrue_b8(), snd, 0);
      if (svptest_any(svptrue_b8(), snd_cmp_zero)) {
        svbool_t before_zero = svbrkb_z(svptrue_b8(), snd_cmp_zero);
        return len + svcntp_b8(svptrue_b8(), before_zero) - 3 * vlen;
      }
      svbool_t trd_cmp_zero = svcmpeq_n_u8(svptrue_b8(), trd, 0);
      if (svptest_any(svptrue_b8(), trd_cmp_zero)) {
        svbool_t before_zero = svbrkb_z(svptrue_b8(), trd_cmp_zero);
        return len + svcntp_b8(svptrue_b8(), before_zero) - 2 * vlen;
      }
      svbool_t fth_cmp_zero = svcmpeq_n_u8(svptrue_b8(), fth, 0);
      if (svptest_any(svptrue_b8(), fth_cmp_zero)) {
        svbool_t before_zero = svbrkb_z(svptrue_b8(), fth_cmp_zero);
        return len + svcntp_b8(svptrue_b8(), before_zero) - vlen;
      }
      // The combined minimum had a zero lane, so one of the four checks
      // above must have returned.
      __builtin_unreachable();
    }
  }
}
} // namespace sve

// -----------------------------------------------------------------------------
// libc fallback
// -----------------------------------------------------------------------------
namespace syslibc {
/// Baseline implementation: forwards to the platform's std::strlen.
static inline size_t string_length(const char *s) {
  const size_t length = std::strlen(s);
  return length;
}
} // namespace syslibc

// -----------------------------------------------------------------------------
// Benchmark harness
// -----------------------------------------------------------------------------
/// One strlen candidate taking part in the correctness check and benchmark.
struct Impl {
  /// Signature shared by every candidate: C string in, length out.
  using Fn = size_t (*)(const char *);

  const char *name; ///< Label printed in the report.
  Fn fn;            ///< Entry point of the implementation.
};

static std::vector<Impl> get_impls() {
  return {{"libc", &syslibc::string_length},
          {"neon", &neon::string_length},
          {"sve", &sve::string_length}};
}

/// Outcome of benchmarking a single implementation at a single string length.
struct Result {
  std::string name;   ///< Label of the implementation that was measured.
  double ns_per_call; ///< Average wall-clock nanoseconds per call.
  double gib_per_s;   ///< Throughput in GiB (2^30 bytes) per second.
};

/// Reads std::chrono::steady_clock and returns the time since its epoch in
/// whole nanoseconds.
static inline uint64_t now_ns() {
  const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
  const auto as_ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch);
  return as_ns.count();
}

// correctness check
static bool run_correctness(const std::vector<Impl> &impls) {
  bool ok = true;
  std::vector<size_t> sizes = {
      0,   1,   3,   7,   8,   9,    15,   16,   31,   32,   63,    64,   127,
      128, 255, 256, 511, 512, 1023, 1024, 4096, 4777, 8192, 16383, 16384};
  for (size_t n : sizes) {
    std::unique_ptr<char[]> s(new char[n + 2]);
    std::fill(s.get(), s.get() + n, 'A');
    s[n] = 0;
    size_t ref = syslibc::string_length(s.get());
    for (auto &impl : impls) {
      size_t got = impl.fn(s.get());
      if (got != ref) {
        std::cerr << "FAIL " << impl.name << " len=" << n << " got=" << got
                  << " ref=" << ref << "\n";
        ok = false;
      }
    }
  }
  return ok;
}

/// Times `reps` back-to-back calls of `impl` on a string of `size` non-NUL
/// bytes.
///
/// @param impl implementation to measure.
/// @param size length of the test string (the terminator is added on top).
/// @param reps number of calls to time.
/// @return the implementation name, average nanoseconds per call, and
///         throughput in GiB/s.
static Result bench(const Impl &impl, size_t size, size_t reps) {
  // Test string: `size` copies of 'X' followed by the terminator.
  std::unique_ptr<char[]> text(new char[size + 1]);
  std::fill(text.get(), text.get() + size, 'X');
  text[size] = 0;

  // Accumulating into a volatile keeps the calls from being optimised away.
  volatile size_t sink = 0;
  const uint64_t start = now_ns();
  for (size_t rep = 0; rep < reps; ++rep)
    sink += impl.fn(text.get());
  const uint64_t stop = now_ns();
  (void)sink;

  const uint64_t elapsed_ns = stop - start;
  const double ns_call = double(elapsed_ns) / reps;
  const double total_bytes = double(size) * reps;
  const double seconds = elapsed_ns * 1e-9;
  const double gib_s = total_bytes / seconds / (1024.0 * 1024.0 * 1024.0);
  return {impl.name, ns_call, gib_s};
}

int main() {
  auto impls = get_impls();
  std::cout << "Implementations:";
  for (auto &i : impls)
    std::cout << " " << i.name;
  std::cout << "\n";

  if (!run_correctness(impls)) {
    std::cerr << "Correctness check failed!\n";
    return 1;
  }

  std::vector<size_t> sizes = {1,  2,  3,  4,   5,    6,    7,
                               8,  9,  10, 11,  12,   13,   14,
                               15, 16, 64, 256, 1024, 4096, 1 << 20};
  for (size_t s : sizes) {
    std::cout << "\n=== strlen(" << s << " bytes) ===\n";
    for (auto &impl : impls) {
      Result r = bench(impl, s, 1000000 / std::max<size_t>(1, s / 16));
      std::cout << impl.name << ": " << r.ns_per_call << " ns/call, "
                << r.gib_per_s << " GiB/s\n";
    }
  }
  return 0;
}

Currently, I get:

Implementations: libc neon sve

=== strlen(1 bytes) ===
libc: 1.54709 ns/call, 0.601984 GiB/s
neon: 1.55341 ns/call, 0.599535 GiB/s
sve: 1.55998 ns/call, 0.597008 GiB/s

=== strlen(2 bytes) ===
libc: 1.54949 ns/call, 1.2021 GiB/s
neon: 1.56232 ns/call, 1.19223 GiB/s
sve: 1.56403 ns/call, 1.19093 GiB/s

=== strlen(3 bytes) ===
libc: 1.5488 ns/call, 1.80396 GiB/s
neon: 1.56451 ns/call, 1.78584 GiB/s
sve: 1.55648 ns/call, 1.79505 GiB/s

=== strlen(4 bytes) ===
libc: 1.55371 ns/call, 2.39767 GiB/s
neon: 1.55062 ns/call, 2.40245 GiB/s
sve: 1.5603 ns/call, 2.38754 GiB/s

=== strlen(5 bytes) ===
libc: 1.5491 ns/call, 3.006 GiB/s
neon: 1.5563 ns/call, 2.9921 GiB/s
sve: 1.58851 ns/call, 2.93143 GiB/s

=== strlen(6 bytes) ===
libc: 1.55341 ns/call, 3.59721 GiB/s
neon: 1.55934 ns/call, 3.58352 GiB/s
sve: 1.56587 ns/call, 3.56858 GiB/s

=== strlen(7 bytes) ===
libc: 1.54874 ns/call, 4.20941 GiB/s
neon: 1.56408 ns/call, 4.16811 GiB/s
sve: 1.56109 ns/call, 4.1761 GiB/s

=== strlen(8 bytes) ===
libc: 1.54798 ns/call, 4.81309 GiB/s
neon: 1.55618 ns/call, 4.78775 GiB/s
sve: 1.66662 ns/call, 4.47046 GiB/s

=== strlen(9 bytes) ===
libc: 1.55501 ns/call, 5.39026 GiB/s
neon: 1.55747 ns/call, 5.38174 GiB/s
sve: 1.69088 ns/call, 4.95712 GiB/s

=== strlen(10 bytes) ===
libc: 1.54885 ns/call, 6.013 GiB/s
neon: 1.55358 ns/call, 5.99467 GiB/s
sve: 1.66546 ns/call, 5.592 GiB/s

=== strlen(11 bytes) ===
libc: 1.54771 ns/call, 6.61915 GiB/s
neon: 1.55496 ns/call, 6.5883 GiB/s
sve: 1.67341 ns/call, 6.12197 GiB/s

=== strlen(12 bytes) ===
libc: 1.54979 ns/call, 7.21121 GiB/s
neon: 1.56757 ns/call, 7.12943 GiB/s
sve: 1.66768 ns/call, 6.70144 GiB/s

=== strlen(13 bytes) ===
libc: 1.54618 ns/call, 7.83041 GiB/s
neon: 1.55778 ns/call, 7.7721 GiB/s
sve: 1.66386 ns/call, 7.27659 GiB/s

=== strlen(14 bytes) ===
libc: 1.54958 ns/call, 8.4142 GiB/s
neon: 1.5529 ns/call, 8.39625 GiB/s
sve: 1.63237 ns/call, 7.98749 GiB/s

=== strlen(15 bytes) ===
libc: 1.54933 ns/call, 9.01671 GiB/s
neon: 1.55451 ns/call, 8.98664 GiB/s
sve: 1.68683 ns/call, 8.2817 GiB/s

=== strlen(16 bytes) ===
libc: 1.54842 ns/call, 9.62349 GiB/s
neon: 1.56699 ns/call, 9.5094 GiB/s
sve: 1.56869 ns/call, 9.49912 GiB/s

=== strlen(64 bytes) ===
libc: 2.05222 ns/call, 29.0439 GiB/s
neon: 2.5751 ns/call, 23.1465 GiB/s
sve: 2.43021 ns/call, 24.5266 GiB/s

=== strlen(256 bytes) ===
libc: 3.41683 ns/call, 69.7777 GiB/s
neon: 9.74771 ns/call, 24.4589 GiB/s
sve: 6.98982 ns/call, 34.1094 GiB/s

=== strlen(1024 bytes) ===
libc: 10.2717 ns/call, 92.8444 GiB/s
neon: 34.388 ns/call, 27.7328 GiB/s
sve: 14.9494 ns/call, 63.7936 GiB/s

=== strlen(4096 bytes) ===
libc: 36.7107 ns/call, 103.912 GiB/s
neon: 145.79 ns/call, 26.1657 GiB/s
sve: 33.2739 ns/call, 114.645 GiB/s

=== strlen(1048576 bytes) ===
libc: 9803.73 ns/call, 99.6113 GiB/s
neon: 35989.3 ns/call, 27.1348 GiB/s
sve: 8660.27 ns/call, 112.764 GiB/s

Metadata

Metadata

Assignees

No one assigned

    Labels

    backend:AArch64 · help wanted (Indicates that a maintainer wants help. Not [good first issue].) · libc

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions