Merge branch 'master' into gg/flash-attn

This commit is contained in:
Georgi Gerganov 2024-01-31 18:49:43 +02:00
commit 2ddc9bbef1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
30 changed files with 1758 additions and 1597 deletions

View File

@ -565,6 +565,31 @@ jobs:
path: |
cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip
windows-latest-cmake-sycl:
runs-on: windows-latest
defaults:
run:
shell: bash
env:
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/62641e01-1e8d-4ace-91d6-ae03f7f8a71f/w_BaseKit_p_2024.0.0.49563_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel
steps:
- name: Clone
id: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install
run: scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL
- name: Build
id: cmake_build
run: examples/sycl/win-build-sycl.bat
ios-xcode-build:
runs-on: macos-latest

View File

@ -1,6 +1,12 @@
name: EditorConfig Checker
on:
workflow_dispatch: # allows manual triggering
inputs:
create_release:
description: 'Create new release'
required: true
type: boolean
push:
branches:
- master

1
.gitignore vendored
View File

@ -89,3 +89,4 @@ examples/jeopardy/results.txt
poetry.lock
poetry.toml
nppBackup

View File

@ -507,7 +507,11 @@ if (LLAMA_SYCL)
set(GGML_HEADERS_SYCL ggml.h ggml-sycl.h)
set(GGML_SOURCES_SYCL ggml-sycl.cpp)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
if (WIN32)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl sycl7 OpenCL mkl_sycl_blas_dll.lib mkl_intel_ilp64_dll.lib mkl_sequential_dll.lib mkl_core_dll.lib)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
endif()
endif()
if (LLAMA_KOMPUTE)

View File

@ -8,10 +8,14 @@
[Linux](#linux)
[Windows](#windows)
[Environment Variable](#environment-variable)
[Known Issue](#known-issue)
[Q&A](#q&a)
[Todo](#todo)
## Background
@ -33,7 +37,7 @@ For Intel CPU, recommend to use llama.cpp for X86 (Intel MKL building).
|OS|Status|Verified|
|-|-|-|
|Linux|Support|Ubuntu 22.04|
|Windows|Ongoing| |
|Windows|Support|Windows 11|
## Intel GPU
@ -42,7 +46,7 @@ For Intel CPU, recommend to use llama.cpp for X86 (Intel MKL building).
|-|-|-|
|Intel Data Center Max Series| Support| Max 1550|
|Intel Data Center Flex Series| Support| Flex 170|
|Intel Arc Series| Support| Arc 770|
|Intel Arc Series| Support| Arc 770, 730M|
|Intel built-in Arc GPU| Support| built-in Arc GPU in Meteor Lake|
|Intel iGPU| Support| iGPU in i5-1250P, i7-1165G7|
@ -131,6 +135,7 @@ cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
#build all binary
cmake --build . --config Release -v
cd ..
```
or
@ -195,7 +200,7 @@ GGML_SYCL_DEVICE=0 ./build/bin/main -m models/llama-2-7b.Q4_0.gguf -p "Building
or run by script:
```
./examples/sycl/run_llama2.sh
./examples/sycl/run-llama2.sh
```
Note:
@ -205,11 +210,175 @@ Note:
5. Check the device ID in output
Like
Like:
```
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
```
## Windows
### Setup Environment
1. Install Intel GPU driver.
Please install Intel GPU driver by official guide: [Install GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html).
2. Install Intel® oneAPI Base toolkit.
a. Please follow the procedure in [Get the Intel® oneAPI Base Toolkit ](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html).
Recommend to install to default folder: **/opt/intel/oneapi**.
Following guide uses the default folder as example. If you use other folder, please modify the following guide info with your folder.
b. Enable oneAPI running environment:
- In Search, input 'oneAPI'.
Search & open "Intel oneAPI command prompt for Intel 64 for Visual Studio 2022"
- In Run:
In CMD:
```
"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64
```
c. Check GPU
In oneAPI command line:
```
sycl-ls
```
There should be one or more level-zero devices. Like **[ext_oneapi_level_zero:gpu:0]**.
Output (example):
```
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
[opencl:cpu:1] Intel(R) OpenCL, 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Iris(R) Xe Graphics OpenCL 3.0 NEO [31.0.101.5186]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Iris(R) Xe Graphics 1.3 [1.3.28044]
```
3. Install cmake & make
a. Download & install cmake for windows: https://cmake.org/download/
b. Download & install make for windows provided by mingw-w64: https://www.mingw-w64.org/downloads/
### Build locally:
In oneAPI command line window:
```
mkdir -p build
cd build
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: for FP16
:: faster for long-prompt inference
:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
:: for FP32
cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
:: build example/main only
:: make main
:: build all binary
make -j
cd ..
```
or
```
.\examples\sycl\win-build-sycl.bat
```
Note:
- By default, it will build for all binary files. It will take more time. To reduce the time, we recommend to build for **example/main** only.
### Run
1. Put model file to folder **models**
2. Enable oneAPI running environment
- In Search, input 'oneAPI'.
Search & open "Intel oneAPI command prompt for Intel 64 for Visual Studio 2022"
- In Run:
In CMD:
```
"C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64
```
3. List device ID
Run without parameter:
```
build\bin\ls-sycl-device.exe
or
build\bin\main.exe
```
Check the ID in startup log, like:
```
found 4 SYCL devices:
Device 0: Intel(R) Arc(TM) A770 Graphics, compute capability 1.3,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
Device 1: Intel(R) FPGA Emulation Device, compute capability 1.2,
max compute_units 24, max work group size 67108864, max sub group size 64, global mem size 67065057280
Device 2: 13th Gen Intel(R) Core(TM) i7-13700K, compute capability 3.0,
max compute_units 24, max work group size 8192, max sub group size 64, global mem size 67065057280
Device 3: Intel(R) Arc(TM) A770 Graphics, compute capability 3.0,
max compute_units 512, max work group size 1024, max sub group size 32, global mem size 16225243136
```
|Attribute|Note|
|-|-|
|compute capability 1.3|Level-zero running time, recommended |
|compute capability 3.0|OpenCL running time, slower than level-zero in most cases|
4. Set device ID and execute llama.cpp
Set device ID = 0 by **set GGML_SYCL_DEVICE=0**
```
set GGML_SYCL_DEVICE=0
build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0
```
or run by script:
```
.\examples\sycl\win-run-llama2.bat
```
Note:
- By default, mmap is used to read model file. In some cases, it leads to the hang issue. Recommend to use parameter **--no-mmap** to disable mmap() to skip this issue.
5. Check the device ID in output
Like:
```
Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
```
## Environment Variable
@ -220,7 +389,7 @@ Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
|LLAMA_SYCL|ON (mandatory)|Enable build with SYCL code path. <br>For FP32/FP16, LLAMA_SYCL=ON is mandatory.|
|LLAMA_SYCL_F16|ON (optional)|Enable FP16 build with SYCL code path. Faster for long-prompt inference. <br>For FP32, not set it.|
|CMAKE_C_COMPILER|icx|Use icx compiler for SYCL code path|
|CMAKE_CXX_COMPILER|icpx|use icpx for SYCL code path|
|CMAKE_CXX_COMPILER|icpx (Linux), icx (Windows)|use icpx/icx for SYCL code path|
#### Running
@ -232,19 +401,24 @@ Using device **0** (Intel(R) Arc(TM) A770 Graphics) as main device
## Known Issue
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
Miss to enable oneAPI running environment.
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
- Hang during startup
llama.cpp use mmap as default way to read model file and copy to GPU. In some system, memcpy will be abnormal and block.
Solution: add **--no-mmap**.
## Q&A
- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`.
Miss to enable oneAPI running environment.
Install oneAPI base toolkit and enable it by: `source /opt/intel/oneapi/setvars.sh`.
- In Windows, no result, not error.
Miss to enable oneAPI running environment.
## Todo
- Support to build in Windows.

View File

@ -10,6 +10,9 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics
- Remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD: https://github.com/ggerganov/llama.cpp/pull/5240
- Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138
- [SYCL backend](README-sycl.md) is ready (1/28/2024), support Linux/Windows in Intel GPUs (iGPU, Arc/Flex/Max series)
- New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow
- Collecting Apple Silicon performance stats:
- M-series: https://github.com/ggerganov/llama.cpp/discussions/4167
@ -604,7 +607,7 @@ Building the program with BLAS support may lead to some performance improvements
llama.cpp based on SYCL is used to support Intel GPU (Data Center Max series, Flex series, Arc series, Built-in GPU and iGPU).
For detailed info, please refer to [llama.cpp for SYCL](README_sycl.md).
For detailed info, please refer to [llama.cpp for SYCL](README-sycl.md).
### Prepare Data & Run

View File

@ -583,20 +583,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gpu_layers = std::stoi(argv[i]);
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
} else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_gpu_layers_draft = std::stoi(argv[i]);
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
} else if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) {
invalid_param = true;
@ -637,11 +637,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
if (split_arg.size() >= LLAMA_MAX_DEVICES) {
if (split_arg.size() >= llama_max_devices()) {
invalid_param = true;
break;
}
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
} else {
@ -989,30 +989,30 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
if (llama_mlock_supported()) {
if (llama_supports_mlock()) {
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
if (llama_mmap_supported()) {
if (llama_supports_mmap()) {
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
printf(" --numa attempt optimizations that help on some NUMA systems\n");
printf(" if run without this previously, it is recommended to drop the system page cache before using this\n");
printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n");
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT, --tensor-split SPLIT\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
#endif // LLAMA_SUPPORTS_GPU_OFFLOAD
if (llama_supports_gpu_offload()) {
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT, --tensor-split SPLIT\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
}
printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
printf(" -gan N, --grp-attn-n N\n");
@ -1520,7 +1520,9 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
@ -1649,7 +1651,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);

View File

@ -43,40 +43,40 @@ extern char const *LLAMA_BUILD_TARGET;
int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
// pinging @cebtenzzre
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
float yarn_beta_fast = 32.0f; // YaRN low correction dim
float yarn_beta_slow = 1.0f; // YaRN high correction dim
int32_t yarn_orig_ctx = 0; // YaRN original context length
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
// pinging @cebtenzzre
// // sampling parameters
struct llama_sampling_params sparams;

View File

@ -1363,12 +1363,12 @@ bool consume_common_train_arg(
*invalid_param = true;
return true;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params->n_gpu_layers = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
if (llama_supports_gpu_offload()) {
params->n_gpu_layers = std::stoi(argv[i]);
} else {
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
}
} else if (arg == "-h" || arg == "--help") {
params->print_usage = true;
return true;

View File

@ -88,7 +88,7 @@ int main(int argc, char ** argv) {
llama_model_params model_params = llama_model_default_params();
const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);
const std::vector<float> t_split(llama_max_devices(), 0.0f);
model_params.n_gpu_layers = n_gpu_layers;
model_params.tensor_split = t_split.data();

View File

@ -160,7 +160,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> mul_mat_q;
std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
std::vector<std::vector<float>> tensor_split;
int reps;
bool verbose;
output_formats output_format;
@ -179,7 +179,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* mul_mat_q */ {true},
/* tensor_split */ {{}},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* reps */ 5,
/* verbose */ false,
/* output_format */ MARKDOWN
@ -380,10 +380,10 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
const std::regex regex{R"([;/]+)"};
std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
GGML_ASSERT(split_arg.size() <= llama_max_devices());
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
std::vector<float> tensor_split(llama_max_devices());
for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) {
tensor_split[i] = std::stof(split_arg[i]);
} else {
@ -459,7 +459,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool mul_mat_q;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
std::vector<float> tensor_split;
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
@ -563,6 +563,7 @@ struct test {
static const bool cuda;
static const bool opencl;
static const bool vulkan;
static const bool kompute;
static const bool metal;
static const bool gpu_blas;
static const bool blas;
@ -581,7 +582,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool mul_mat_q;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
std::vector<float> tensor_split;
int n_prompt;
int n_gen;
std::string test_time;
@ -647,6 +648,9 @@ struct test {
if (vulkan) {
return "Vulkan";
}
if (kompute) {
return "Kompute";
}
if (metal) {
return "Metal";
}
@ -662,7 +666,7 @@ struct test {
static const std::vector<std::string> & get_fields() {
static const std::vector<std::string> fields = {
"build_commit", "build_number",
"cuda", "opencl", "vulkan", "metal", "gpu_blas", "blas",
"cuda", "opencl", "vulkan", "kompute", "metal", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "type_k", "type_v",
@ -686,8 +690,9 @@ struct test {
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
if (field == "cuda" || field == "opencl" || field == "vulkan"|| field == "metal" || field == "gpu_blas" || field == "blas" ||
field == "f16_kv" || field == "no_kv_offload" || field == "mul_mat_q") {
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "f16_kv" || field == "no_kv_offload" ||
field == "mul_mat_q") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@ -699,7 +704,7 @@ struct test {
std::vector<std::string> get_values() const {
std::string tensor_split_str;
int max_nonzero = 0;
for (int i = 0; i < LLAMA_MAX_DEVICES; i++) {
for (size_t i = 0; i < llama_max_devices(); i++) {
if (tensor_split[i] > 0) {
max_nonzero = i;
}
@ -714,7 +719,8 @@ struct test {
}
std::vector<std::string> values = {
build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
std::to_string(cuda), std::to_string(opencl), std::to_string(vulkan), std::to_string(vulkan),
std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
@ -743,6 +749,7 @@ const int test::build_number = LLAMA_BUILD_NUMBER;
const bool test::cuda = !!ggml_cpu_has_cublas();
const bool test::opencl = !!ggml_cpu_has_clblast();
const bool test::vulkan = !!ggml_cpu_has_vulkan();
const bool test::kompute = !!ggml_cpu_has_kompute();
const bool test::metal = !!ggml_cpu_has_metal();
const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
const bool test::blas = !!ggml_cpu_has_blas();

View File

@ -111,17 +111,71 @@ llama_print_timings: eval time = 1279.03 ms / 18 runs ( 71.06 m
llama_print_timings: total time = 34570.79 ms
```
## Orin compile and run
### compile
```sh
make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=sm_87 LLAMA_CUDA_F16=1 -j 32
```
### run on Orin
### case 1
**input**
```sh
./llava-cli \
-m /data/local/tmp/ggml-model-q4_k.gguf \
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
--image /data/local/tmp/demo.jpeg \
-p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" \
--n-gpu-layers 999
```
**output**
```sh
encode_image_with_clip: image encoded in 296.62 ms by CLIP ( 2.06 ms per image patch)
Susan Wise Bauer
llama_print_timings: load time = 1067.64 ms
llama_print_timings: sample time = 1.53 ms / 6 runs ( 0.25 ms per token, 3934.43 tokens per second)
llama_print_timings: prompt eval time = 306.84 ms / 246 tokens ( 1.25 ms per token, 801.72 tokens per second)
llama_print_timings: eval time = 91.50 ms / 6 runs ( 15.25 ms per token, 65.58 tokens per second)
llama_print_timings: total time = 1352.63 ms / 252 tokens
```
### case 2
**input**
```sh
./llava-cli \
-m /data/local/tmp/ggml-model-q4_k.gguf \
--mmproj /data/local/tmp/mmproj-model-f16.gguf \
-p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:" \
--n-gpu-layers 999
```
**output**
```sh
encode_image_with_clip: image encoded in 302.15 ms by CLIP ( 2.10 ms per image patch)
The image features a cat lying in the grass.
llama_print_timings: load time = 1057.07 ms
llama_print_timings: sample time = 3.27 ms / 11 runs ( 0.30 ms per token, 3360.83 tokens per second)
llama_print_timings: prompt eval time = 213.60 ms / 232 tokens ( 0.92 ms per token, 1086.14 tokens per second)
llama_print_timings: eval time = 166.65 ms / 11 runs ( 15.15 ms per token, 66.01 tokens per second)
llama_print_timings: total time = 1365.47 ms / 243 tokens
```
## Minor shortcomings
The `n_patch` of output in `ldp` is 1/4 of the input. In order to implement quickly, we uniformly modified `clip_n_patches` function to a quarter. when counting the time consumption, the calculated time will be 4 times bigger than the real cost.
## TODO
- [ ] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid`
- [x] Support non-CPU backend for the new operators, such as `depthwise`, `hardswish`, `hardsigmoid`
- [ ] Optimize LDP projector performance
- Optimize the structure definition to avoid unnecessary memory rearrangements, to reduce the use of `ggml_permute_cpy`;
- Optimize operator implementation (ARM CPU/NVIDIA GPU): such as depthwise conv, hardswish, hardsigmoid, etc.
- [ ] run MobileVLM on `Jetson Orin`
- [x] run MobileVLM on `Jetson Orin`
- [ ] Support more model variants, such as `MobileVLM-3B`.

View File

@ -1789,28 +1789,28 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
if (llama_mlock_supported())
if (llama_supports_mlock())
{
printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
if (llama_mmap_supported())
if (llama_supports_mmap())
{
printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
printf(" --numa attempt optimizations that help on some NUMA systems\n");
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row)\n");
#endif
if (llama_supports_gpu_offload()) {
printf(" -ngl N, --n-gpu-layers N\n");
printf(" number of layers to store in VRAM\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row)\n");
}
printf(" -m FNAME, --model FNAME\n");
printf(" model path (default: %s)\n", params.model.c_str());
printf(" -a ALIAS, --alias ALIAS\n");
@ -2066,13 +2066,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]);
#else
LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
if (llama_supports_gpu_offload()) {
params.n_gpu_layers = std::stoi(argv[i]);
} else {
LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
"See main README.md for information on enabling GPU BLAS support",
{{"n_gpu_layers", params.n_gpu_layers}});
#endif
}
}
else if (arg == "--split-mode" || arg == "-sm")
{
@ -2115,9 +2115,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
GGML_ASSERT(split_arg.size() <= llama_max_devices());
for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device)
for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device)
{
if (i_device < split_arg.size())
{

View File

@ -1,7 +1,9 @@
/*MIT license
Copyright (C) 2024 Intel Corporation
SPDX-License-Identifier: MIT
*/
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
#include "ggml-sycl.h"

View File

@ -0,0 +1,23 @@
:: MIT license
:: Copyright (C) 2024 Intel Corporation
:: SPDX-License-Identifier: MIT
mkdir -p build
cd build
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: for FP16
:: faster for long-prompt inference
:: cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
:: for FP32
cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release
:: build example/main only
:: make main
:: build all binary
make -j
cd ..

View File

@ -0,0 +1,13 @@
:: MIT license
:: Copyright (C) 2024 Intel Corporation
:: SPDX-License-Identifier: MIT
INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
set GGML_SYCL_DEVICE=0
rem set GGML_SYCL_DEBUG=1
.\build\bin\main.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0

View File

@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256
#define CUDA_Q8_0_NE_ALIGN 2048
@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
}
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.y;
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
@ -6160,21 +6182,73 @@ static __global__ void im2col_f32_f16(
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int oh = blockIdx.y;
const int batch = blockIdx.z / IC;
const int ic = blockIdx.z % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
(blockIdx.y * OW + ix) * CHW +
(blockIdx.z * (KW * KH) + ky * KW + kx);
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = blockIdx.z * offset_delta;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = max(0, start_h);
const int eh = min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = max(0, start_w);
const int ew = min(iw, start_w + kw);
const To scale = 1. / (kh * kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
Ti cur = __ldg(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += cur * scale; break;
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
}
}
static void im2col_f32_f16_cuda(const float* x, half* dst,
template <typename T>
static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int offset_delta,
int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, IC);
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
(void) src1_dd;
}
static void ggml_cuda_op_hardsigmoid(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_hardswish(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_leaky_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
(void) src1_dd;
}
static void ggml_cuda_op_pool2d(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
(void) src1;
(void) src1_dd;
}
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
if(dst->type == GGML_TYPE_F16) {
im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0;
(void) src0_dd;
@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
}
static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
}
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
}
@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
}
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
case GGML_UNARY_OP_HARDSIGMOID:
func = ggml_cuda_hardsigmoid;
break;
case GGML_UNARY_OP_HARDSWISH:
func = ggml_cuda_hardswish;
break;
default:
return false;
}
@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
case GGML_OP_POOL_2D:
func = ggml_cuda_pool2d;
break;
case GGML_OP_SUM_ROWS:
func = ggml_cuda_sum_rows;
break;
@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
return true;
@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ROPE:
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:

View File

@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ROPE_F16,
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@ -515,6 +516,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@ -645,6 +647,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
return true;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
@ -2031,7 +2037,7 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
@ -2039,6 +2045,7 @@ static bool ggml_metal_graph_compute(
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int32_t N = src1->ne[is_2D ? 3 : 2];
@ -2059,8 +2066,8 @@ static bool ggml_metal_graph_compute(
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
switch (dst->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
default: GGML_ASSERT(false);
};

View File

@ -1775,9 +1775,29 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
kernel void kernel_im2col_f16(
typedef void (im2col_t)(
device const float * x,
device half * dst,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
constant int32_t & IH,
constant int32_t & CHW,
constant int32_t & s0,
constant int32_t & s1,
constant int32_t & p0,
constant int32_t & p1,
constant int32_t & d0,
constant int32_t & d1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col(
device const float * x,
device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
kernel void kernel_upscale_f32(
device const char * src0,
device char * dst,

View File

@ -1,7 +1,14 @@
/*MIT license
Copyright (C) 2024 Intel Corporation
SPDX-License-Identifier: MIT
*/
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include <algorithm>
#include <assert.h>

View File

@ -1,7 +1,8 @@
/*MIT license
Copyright (C) 2024 Intel Corporation
SPDX-License-Identifier: MIT
*/
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
#pragma once

File diff suppressed because it is too large Load Diff

View File

@ -817,7 +817,7 @@ static void ggml_vk_load_shaders() {
// mulmat
std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, vk_device.subgroup_size * 2, 64, 2, 4, 4, vk_device.subgroup_size };
std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, vk_device.subgroup_size, 32, 2, 4, 2, vk_device.subgroup_size };
std::initializer_list<uint32_t> warptile_s = { vk_device.subgroup_size, 32, 32, 8, 32, 32, 2, 2, 2, vk_device.subgroup_size };
std::initializer_list<uint32_t> warptile_s = { vk_device.subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, vk_device.subgroup_size };
std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
@ -2873,7 +2873,8 @@ static void ggml_vk_op_f32(vk_context * ctx, const ggml_tensor * src0, const ggm
if (op == GGML_OP_CPY) {
GGML_ASSERT(!transfer_src0);
GGML_ASSERT(!transfer_src1);
d_sz = dst->ne[1] * dst->nb[1];
x_sz = ggml_nbytes(src0);
d_sz = ggml_nbytes(dst);
if (extra->offset + d_sz >= d_D->size) {
d_sz = VK_WHOLE_SIZE;
@ -4556,8 +4557,15 @@ GGML_CALL static bool ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml
}
ggml_vk_preallocate_buffers();
int last_node = cgraph->n_nodes - 1;
// If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
while (last_node > 0 && cgraph->nodes[last_node]->backend != GGML_BACKEND_GPU) {
last_node -= 1;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(cgraph->nodes[i], i == cgraph->n_nodes - 1);
ggml_vk_build_graph(cgraph->nodes[i], i == last_node);
}
ggml_compute_params params = {};

135
ggml.c
View File

@ -5413,7 +5413,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
int s0,
int p0,
int d0) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@ -5491,16 +5491,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
int p1,
int d0,
int d1) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC1, KH, KW] => [1, OC, 1, KH * KW]
ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC1, KH, KW] => [1, OC, 1, KH * KW]
struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
@ -5521,7 +5520,8 @@ struct ggml_tensor * ggml_im2col(
int p1,
int d0,
int d1,
bool is_2D) {
bool is_2D,
enum ggml_type dst_type) {
if(is_2D) {
GGML_ASSERT(a->ne[2] == b->ne[2]);
@ -5545,7 +5545,7 @@ struct ggml_tensor * ggml_im2col(
is_2D ? b->ne[3] : 1,
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
ggml_set_op_params(result, params, sizeof(params));
@ -5570,7 +5570,7 @@ struct ggml_tensor * ggml_conv_2d(
int p1,
int d0,
int d1) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@ -5696,12 +5696,13 @@ struct ggml_tensor * ggml_pool_2d(
is_node = true;
}
struct ggml_tensor * result;
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_set_op_params(result, params, sizeof(params));
@ -5709,7 +5710,6 @@ struct ggml_tensor * ggml_pool_2d(
result->op = GGML_OP_POOL_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
@ -12608,6 +12608,92 @@ static void ggml_compute_forward_conv_transpose_1d(
}
}
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
static void ggml_compute_forward_im2col_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = is_2D ? ne13 : ne12;
const int64_t IC = is_2D ? ne12 : ne11;
const int64_t IH = is_2D ? ne11 : 1;
const int64_t IW = ne10;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
const int64_t OH = is_2D ? ne2 : 1;
const int64_t OW = ne1;
int ofs0 = is_2D ? nb13 : nb12;
int ofs1 = is_2D ? nb12 : nb11;
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
{
float * const wdata = (float *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
} else {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
}
}
}
}
}
}
}
}
}
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
@ -12698,14 +12784,14 @@ static void ggml_compute_forward_im2col(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
GGML_ASSERT(false);
ggml_compute_forward_im2col_f32(params, src0, src1, dst);
} break;
default:
{
@ -12896,8 +12982,8 @@ static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
const struct ggml_tensor * src,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@ -17297,12 +17383,16 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
struct ggml_cplan cplan;
memset(&cplan, 0, sizeof(struct ggml_cplan));
int max_tasks = 1;
// thread scheduling for the different operations + work buffer size estimation
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
const int n_tasks = ggml_get_n_tasks(node, n_threads);
max_tasks = MAX(max_tasks, n_tasks);
size_t cur = 0;
switch (node->op) {
@ -17475,7 +17565,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
work_size += CACHE_LINE_SIZE*(n_threads - 1);
}
cplan.n_threads = n_threads;
cplan.n_threads = MIN(max_tasks, n_threads);
cplan.work_size = work_size;
cplan.work_data = NULL;
@ -20791,6 +20881,14 @@ int ggml_cpu_has_vulkan(void) {
#endif
}
int ggml_cpu_has_kompute(void) {
#if defined(GGML_USE_KOMPUTE)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_sycl(void) {
#if defined(GGML_USE_SYCL)
return 1;
@ -20800,7 +20898,8 @@ int ggml_cpu_has_sycl(void) {
}
int ggml_cpu_has_gpublas(void) {
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_sycl();
return ggml_cpu_has_cublas() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
ggml_cpu_has_sycl();
}
int ggml_cpu_has_sse3(void) {

4
ggml.h
View File

@ -1496,7 +1496,8 @@ extern "C" {
int p1,
int d0,
int d1,
bool is_2D);
bool is_2D,
enum ggml_type dst_type);
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
@ -2284,6 +2285,7 @@ extern "C" {
GGML_API int ggml_cpu_has_cublas (void);
GGML_API int ggml_cpu_has_clblast (void);
GGML_API int ggml_cpu_has_vulkan (void);
GGML_API int ggml_cpu_has_kompute (void);
GGML_API int ggml_cpu_has_gpublas (void);
GGML_API int ggml_cpu_has_sse3 (void);
GGML_API int ggml_cpu_has_ssse3 (void);

View File

@ -19,8 +19,8 @@ shader_int8_ext = """
# Type-specific defines
shader_f16_defines = """
#define QUANT_K 32
#define QUANT_R 2
#define QUANT_K 1
#define QUANT_R 1
#define A_TYPE float16_t
"""

287
llama.cpp
View File

@ -2715,10 +2715,10 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XSS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XSS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
default: return "unknown, may not work";
}
@ -4702,126 +4702,6 @@ struct llm_build_context {
ctx0 = nullptr;
}
}
struct ggml_cgraph * build_orion() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
// if (model.layers[il].bq) {
// Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
// cb(Qcur, "Qcur", il);
// }
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
// if (model.layers[il].bk) {
// Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
// cb(Kcur, "Kcur", il);
// }
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
// if (model.layers[il].bv) {
// Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
// cb(Vcur, "Vcur", il);
// }
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -6625,6 +6505,125 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_orion() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
// if (model.layers[il].bq) {
// Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
// cb(Qcur, "Qcur", il);
// }
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
// if (model.layers[il].bk) {
// Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
// cb(Kcur, "Kcur", il);
// }
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
// if (model.layers[il].bv) {
// Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
// cb(Vcur, "Vcur", il);
// }
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
};
static struct ggml_cgraph * llama_build_graph(
@ -6914,11 +6913,6 @@ static int llama_decode_internal(
n_threads = std::min(4, n_threads);
}
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
if ((ggml_cpu_has_cublas() || ggml_cpu_has_vulkan()) && fully_offloaded) {
n_threads = 1;
}
#ifdef GGML_USE_MPI
const int64_t n_layer = hparams.n_layer;
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
@ -10131,18 +10125,45 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
return result;
}
int32_t llama_max_devices(void) {
return LLAMA_MAX_DEVICES;
size_t llama_max_devices(void) {
#if defined(GGML_USE_METAL)
return 1;
#elif defined(GGML_USE_CUBLAS)
return GGML_CUDA_MAX_DEVICES;
#elif defined(GGML_USE_SYCL)
return GGML_SYCL_MAX_DEVICES;
#else
return 1;
#endif
}
bool llama_mmap_supported(void) {
bool llama_supports_mmap(void) {
return llama_mmap::SUPPORTED;
}
bool llama_mlock_supported(void) {
bool llama_supports_mlock(void) {
return llama_mlock::SUPPORTED;
}
bool llama_supports_gpu_offload(void) {
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
return true;
#else
return false;
#endif
}
// deprecated:
bool llama_mmap_supported(void) {
return llama_supports_mmap();
}
bool llama_mlock_supported(void) {
return llama_supports_mlock();
}
void llama_backend_init(bool numa) {
ggml_time_init();
@ -10174,8 +10195,8 @@ int64_t llama_time_us(void) {
}
struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params) {
const char * path_model,
struct llama_model_params params) {
ggml_time_init();
llama_model * model = new llama_model;

29
llama.h
View File

@ -3,15 +3,7 @@
#include "ggml.h"
#include "ggml-backend.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
#elif defined(GGML_USE_SYCL)
#include "ggml-sycl.h"
#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
#else
#define LLAMA_MAX_DEVICES 1
#endif // GGML_USE_CUBLAS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
@ -49,12 +41,6 @@
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 4
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
#define LLAMA_SUPPORTS_GPU_OFFLOAD
#endif
#ifdef __cplusplus
extern "C" {
#endif
@ -201,7 +187,7 @@ extern "C" {
// LLAMA_SPLIT_LAYER: ignored
int32_t main_gpu;
// proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
// proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
const float * tensor_split;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
@ -338,9 +324,14 @@ extern "C" {
LLAMA_API int64_t llama_time_us(void);
LLAMA_API int32_t llama_max_devices(void);
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
LLAMA_API size_t llama_max_devices(void);
LLAMA_API bool llama_supports_mmap (void);
LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead");
LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead");
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);

View File

@ -0,0 +1,19 @@
:: MIT license
:: Copyright (C) 2024 Intel Corporation
:: SPDX-License-Identifier: MIT
set URL=%1
set COMPONENTS=%2
curl.exe --output %TEMP%\webimage.exe --url %URL% --retry 5 --retry-delay 5
start /b /wait %TEMP%\webimage.exe -s -x -f webimage_extracted --log extract.log
del %TEMP%\webimage.exe
if "%COMPONENTS%"=="" (
webimage_extracted\bootstrapper.exe -s --action install --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=.
) else (
webimage_extracted\bootstrapper.exe -s --action install --components=%COMPONENTS% --eula=accept -p=NEED_VS2017_INTEGRATION=0 -p=NEED_VS2019_INTEGRATION=0 -p=NEED_VS2022_INTEGRATION=0 --log-dir=.
)
set installer_exit_code=%ERRORLEVEL%
rd /s/q "webimage_extracted"
exit /b %installer_exit_code%

View File

@ -227,6 +227,14 @@ static std::string var_to_str(ggml_type type) {
return ggml_type_name(type);
}
static std::string var_to_str(ggml_op_pool pool) {
switch (pool) {
case GGML_OP_POOL_AVG: return "avg";
case GGML_OP_POOL_MAX: return "max";
default: return std::to_string(pool);
}
}
#define VARS_TO_STR1(a) VAR_TO_STR(a)
#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@ -238,6 +246,7 @@ static std::string var_to_str(ggml_type type) {
#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
#ifdef GGML_USE_SYCL
static bool inline _isinf(float f) {
@ -1162,10 +1171,45 @@ struct test_alibi : public test_case {
}
};
// GGML_OP_POOL2D
struct test_pool2d : public test_case {
enum ggml_op_pool pool_type;
const ggml_type type_input;
const std::array<int64_t, 4> ne_input;
// kernel size
const int k0;
const int k1;
// stride
const int s0;
const int s1;
// padding
const int p0;
const int p1;
std::string vars() override {
return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
}
test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
ggml_type type_input = GGML_TYPE_F32,
std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
int k0 = 3, int k1 = 3,
int s0 = 1, int s1 = 1,
int p0 = 1, int p1 = 1)
: pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
return out;
}
};
// GGML_OP_IM2COL
struct test_im2col : public test_case {
const ggml_type type_input;
const ggml_type type_kernel;
const ggml_type dst_type;
const std::array<int64_t, 4> ne_input;
const std::array<int64_t, 4> ne_kernel;
// stride
@ -1181,22 +1225,22 @@ struct test_im2col : public test_case {
const bool is_2D;
std::string vars() override {
return VARS_TO_STR11(type_input, type_kernel, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
}
test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16,
test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
int s0 = 1, int s1 = 1,
int p0 = 1, int p1 = 1,
int d0 = 1, int d1 = 1,
bool is_2D = true)
: type_input(type_input), type_kernel(type_kernel), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
: type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D);
ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
return out;
}
};
@ -1982,6 +2026,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
for (int k1 : {1, 3}) {
for (int s0 : {1, 2}) {
for (int s1 : {1, 2}) {
for (int p0 : {0, 1}) {
for (int p1 : {0, 1}) {
test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
}
}
}
}
}
}
}
}
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
@ -2119,7 +2184,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
test_cases.emplace_back(new test_alibi());
test_cases.emplace_back(new test_im2col());
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));