From 5aec2cfaac386eb09aebb75b805860828f00de91 Mon Sep 17 00:00:00 2001 From: Tameem <113388789+AhmadTameem@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:27:40 +0500 Subject: [PATCH] ggml : add RISC-V vector intrinsics support (#2929) * added support for RISCV CFLAGS & native compile + cross compile options * Add RISC-V Vector Intrinsics Support Added RVV intrinsics for following ggml_vec_dot_q4_0_q8_0 ggml_vec_dot_q4_1_q8_1 ggml_vec_dot_q5_0_q8_0 ggml_vec_dot_q5_1_q8_1 ggml_vec_dot_q8_0_q8_0 Co-authored-by: Sharafat Signed-off-by: Ahmad Tameem --------- Signed-off-by: Ahmad Tameem Co-authored-by: moiz.hussain Co-authored-by: Sharafat --- Makefile | 13 ++++ ggml.c | 227 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) diff --git a/Makefile b/Makefile index b56df3d8a..8f73297f4 100644 --- a/Makefile +++ b/Makefile @@ -35,6 +35,11 @@ ifndef UNAME_M UNAME_M := $(shell uname -m) endif +ifdef RISCV_CROSS_COMPILE +CC := riscv64-unknown-linux-gnu-gcc +CXX := riscv64-unknown-linux-gnu-g++ +endif + CCV := $(shell $(CC) --version | head -n 1) CXXV := $(shell $(CXX) --version | head -n 1) @@ -150,6 +155,9 @@ endif # Architecture specific # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue + +ifndef RISCV + ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) # Use all CPU extensions that are available: CFLAGS += -march=native -mtune=native @@ -198,6 +206,11 @@ ifneq ($(filter ppc64%,$(UNAME_M)),) endif endif +else + CFLAGS += -march=rv64gcv -mabi=lp64d + CXXFLAGS += -march=rv64gcv -mabi=lp64d +endif + ifndef LLAMA_NO_K_QUANTS CFLAGS += -DGGML_USE_K_QUANTS CXXFLAGS += -DGGML_USE_K_QUANTS diff --git a/ggml.c b/ggml.c index 46ce4a581..cf3955f7f 100644 --- a/ggml.c +++ b/ggml.c @@ -301,6 +301,10 @@ typedef double ggml_float; #endif #endif +#ifdef __riscv_v_intrinsic +#include +#endif + #ifdef __F16C__ #ifdef _MSC_VER @@ -2677,6 +2681,41 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * } *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; #else // scalar float sumf = 0.0; @@ -2803,6 +2842,38 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * } *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; #else // scalar float sumf = 0.0; @@ -3037,6 +3108,76 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * } *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + // These temp values are for masking and shift operations + uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, + 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // temporary registers + vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl); + vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl); + vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl); + vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl); + vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl); + vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl); + vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl); + + // narrowing + vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl); + vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + + vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl); + vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + + // load + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); + vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; #else // scalar float sumf = 0.0; @@ -3293,6 +3434,72 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * } *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + // These temp values are for shift operations + uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // temporary registers + vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl); + vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl); + + // load qh + vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl); + vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); + vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl); + vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl); + + // narrowing + vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl); + vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); + + vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl); + vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); + + // load + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); + + vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); + vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); + + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); + sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; #else // scalar float sumf = 0.0; @@ -3404,6 +3611,26 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * } *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; #else // scalar float sumf = 0.0;