diff --git a/Makefile b/Makefile index dfa32d516..e92335c7c 100644 --- a/Makefile +++ b/Makefile @@ -901,9 +901,12 @@ ggml/src/ggml-metal.o: \ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ ggml/src/ggml-metal.metal \ - ggml/src/ggml-common.h + ggml/src/ggml-common.h \ + ggml/src/ggml-metal-impl.h @echo "Embedding Metal library" - @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal.metal > ggml/src/ggml-metal-embed.metal + @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal.metal > ggml/src/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal-embed.metal + $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a05f8c505..670389569 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -62,9 +62,10 @@ if (GGML_METAL) add_compile_definitions(GGML_METAL_USE_BF16) endif() - # copy ggml-common.h and ggml-metal.metal to bin directory - configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) - configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) + # copy metal files to bin directory + configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) + configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -72,25 +73,28 @@ if (GGML_METAL) add_compile_definitions(GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ggml-common.h + DEPENDS ggml-metal.metal ggml-common.h ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) @@ -128,8 +132,9 @@ if (GGML_METAL) COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal - DEPENDS ggml-metal.metal ggml-common.h + DEPENDS ggml-metal.metal ggml-common.h ggml-metal-impl.h COMMENT "Compiling Metal kernels" ) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index d25100693..050161393 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,246 +418,6 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); -#if defined(GGML_COMMON_DECL_METAL_KARGS) -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t dim; -} ggml_metal_kargs_concat; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - uint64_t offs; -} ggml_metal_kargs_bin; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_repeat; - -typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int64_t ne0; - int64_t ne1; - int64_t ne2; - int64_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_cpy; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t n_past; - int32_t n_dims; - int32_t n_ctx_orig; - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; -} ggml_metal_kargs_rope; - -typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne11; - int32_t ne_12_2; // assume K and V are same shape - int32_t ne_12_3; - uint64_t nb_12_1; - uint64_t nb_12_2; - uint64_t nb_12_3; - uint64_t nb31; - int32_t ne1; - int32_t ne2; - float scale; - float max_bias; - float m0; - float m1; - uint16_t n_head_log2; - float logit_softcap; -} ggml_metal_kargs_flash_attn_ext; - -typedef struct { - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mm; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mv; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; -} ggml_metal_kargs_mul_mm_id; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; - uint64_t nb1; -} ggml_metal_kargs_mul_mv_id; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_rms_norm; -#endif - #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal-impl.h b/ggml/src/ggml-metal-impl.h new file mode 100644 index 000000000..53c135496 --- /dev/null +++ b/ggml/src/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 6af5ef2d1..d9060400f 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2,10 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" - -#define GGML_COMMON_DECL_C -#define GGML_COMMON_DECL_METAL_KARGS -#include "ggml-common.h" +#import "ggml-metal-impl.h" #import diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a0b4dfaff..90a7bccf9 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1,7 +1,7 @@ #define GGML_COMMON_DECL_METAL -#define GGML_COMMON_DECL_METAL_KARGS #define GGML_COMMON_IMPL_METAL #include "ggml-common.h" +#include "ggml-metal-impl.h" #include