diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 425a25895..2e1048906 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -656,10 +656,25 @@ if (GGML_KOMPUTE) message(FATAL_ERROR "glslc not found") endif() + # Function to extract #define value from header file + function(get_define_value HEADER DEFINE_NAME RESULT_VAR) + file(STRINGS ${HEADER} DEFINE_LINE REGEX "^#define[\t ]+${DEFINE_NAME}[\t ]+.*") + if(DEFINE_LINE) + string(REGEX REPLACE "^#define[\t ]+${DEFINE_NAME}[\t ]+([0-9]+).*" "\\1" DEFINE_VALUE ${DEFINE_LINE}) + set(${RESULT_VAR} ${DEFINE_VALUE} PARENT_SCOPE) + else() + message(WARNING "Define ${DEFINE_NAME} not found in ${HEADER}") + set(${RESULT_VAR} "" PARENT_SCOPE) + endif() + endfunction() + function(compile_shader) set(options) set(oneValueArgs) set(multiValueArgs SOURCES) + set(GGML_HEADER_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml.h") + message(STATUS "GGML_HEADER_PATH: ${GGML_HEADER_PATH}") + get_define_value(${GGML_HEADER_PATH} GGML_ROPE_TYPE_NEOX GGML_ROPE_TYPE_NEOX_VALUE) cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) foreach(source ${compile_shader_SOURCES}) get_filename_component(filename ${source} NAME) @@ -671,7 +686,10 @@ if (GGML_KOMPUTE) ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp - COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} + ${GGML_HEADER_PATH} + COMMAND ${glslc_executable} --target-env=vulkan1.2 + -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} + -DGGML_ROPE_TYPE_NEOX=${GGML_ROPE_TYPE_NEOX_VALUE} COMMENT "Compiling ${source} to ${spv_file}" )