diff --git a/CMakeLists.txt b/CMakeLists.txt index 1a434f07b..8d68ad183 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,114 @@ option(LLAMA_OPENBLAS "llama: use OpenBLAS" option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +INCLUDE(CheckCSourceRuns) + +SET(AVX_CODE " + #include + int main() + { + __m256 a; + a = _mm256_set1_ps(0); + return 0; + } +") + +SET(AVX512_CODE " + #include + int main() + { + __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0); + __m512i b = a; + __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); + return 0; + } +") + +SET(AVX2_CODE " + #include + int main() + { + __m256i a = {0}; + a = _mm256_abs_epi16(a); + __m256i x; + _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code + return 0; + } +") + +SET(FMA_CODE " + #include + int main() + { + __m256 acc = _mm256_setzero_ps(); + const __m256 d = _mm256_setzero_ps(); + const __m256 p = _mm256_setzero_ps(); + acc = _mm256_fmadd_ps( d, p, acc ); + return 0; + } +") + +MACRO(CHECK_SSE type flags) + SET(__FLAG_I 1) + SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + FOREACH(__FLAG ${flags}) + IF(NOT ${type}_FOUND) + SET(CMAKE_REQUIRED_FLAGS ${__FLAG}) + CHECK_C_SOURCE_RUNS("${${type}_CODE}" HAS_${type}_${__FLAG_I}) + IF(HAS_${type}_${__FLAG_I}) + SET(${type}_FOUND TRUE CACHE BOOL "${type} support") + SET(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags") + ENDIF() + MATH(EXPR __FLAG_I "${__FLAG_I}+1") + ENDIF() + ENDFOREACH() + SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + IF(NOT ${type}_FOUND) + SET(${type}_FOUND FALSE CACHE BOOL "${type} support") + SET(${type}_FLAGS "" CACHE STRING "${type} flags") + ENDIF() + + MARK_AS_ADVANCED(${type}_FOUND ${type}_FLAGS) + +ENDMACRO() + +CHECK_SSE("AVX" " ;-mavx;/arch:AVX") +CHECK_SSE("AVX2" " ;-mavx2 -mfma;/arch:AVX2") +CHECK_SSE("AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512") +CHECK_SSE("FMA" " ;-mfma;") + +IF(${AVX_FOUND}) + set(LLAMA_AVX ON) +ELSE() + set(LLAMA_AVX OFF) +ENDIF() + +IF (${FMA_FOUND}) + set(LLAMA_FMA ON) +ELSE() + set(LLAMA_FMA OFF) +ENDIF() + +IF(${AVX2_FOUND}) + set(LLAMA_AVX2 ON) +ELSE() + set(LLAMA_AVX2 OFF) +ENDIF() + +IF(${AVX512_FOUND}) + set(LLAMA_AVX512 ON) +ELSE() + set(LLAMA_AVX512 OFF) +ENDIF() + # # Compile flags #