mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-24 10:24:35 +00:00)
fix memcpy() crash, add missed cmd in guide, fix softmax (#6622)
* disable mmap to fix memcpy crash, add missed cmd in guide, fix softmax
* refactor to disable mmap for SYCL backend
* fix compile error on other OSes
* refactor the solution: use a host buffer to fix it, instead of disabling mmap
* keep support for mmap()
* use a host buffer to reduce malloc times
* revert to the malloc/free solution, for thread safety
This commit is contained in: parent b5e7285baf, commit de17e3f745
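The heart of the memcpy() fix shows up in the `ggml_backend_sycl_buffer_set_tensor` hunk at the bottom of this diff: tensor data, which may live in an mmap'd model file, is first copied into a plain malloc'd host buffer, and only that buffer is handed to the SYCL queue. A minimal sketch of the pattern, with error handling omitted and a generic `sycl::queue` standing in for the dpct stream the real code uses:

```cpp
#include <cstdlib>
#include <cstring>
#include <sycl/sycl.hpp>

// Upload `size` bytes of (possibly mmap-backed) host data to a device
// allocation. Staging through an ordinary heap buffer works around the
// abnormal memcpy behavior some systems show when the SYCL runtime
// reads directly from an mmap'd region; a fresh malloc/free per call
// keeps the helper thread-safe without any shared scratch state.
static void upload_staged(sycl::queue & q, void * dst_device,
                          const void * src_host, size_t size) {
    char * host_buf = (char *) malloc(size);
    memcpy(host_buf, src_host, size);             // plain CPU copy out of the mmap'd file
    q.memcpy(dst_device, host_buf, size).wait();  // device copy from regular heap memory
    free(host_buf);
}
```

The synchronous `.wait()` matters here: it guarantees the device copy has finished before `free(host_buf)` runs, so no buffer-lifetime tracking is needed.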
@@ -68,7 +68,7 @@ It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS,
 | Intel GPU | Status | Verified Model |
 |-------------------------------|---------|---------------------------------------|
-| Intel Data Center Max Series | Support | Max 1550 |
+| Intel Data Center Max Series | Support | Max 1550, 1100 |
 | Intel Data Center Flex Series | Support | Flex 170 |
 | Intel Arc Series | Support | Arc 770, 730M |
 | Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake |
@@ -84,8 +84,7 @@ It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS,
 - **Execution Unit (EU)**
 - If the iGPU has less than 80 EUs, the inference speed will likely be too slow for practical use.
 
-### Nvidia GPU
-The BLAS acceleration on Nvidia GPU through oneAPI can be obtained using the Nvidia plugins for oneAPI and the cuBLAS backend of the upstream oneMKL library. Details and instructions on how to setup the runtime and library can be found in [this section](#i-setup-environment)
+### Other Vendor GPU
 
 **Verified devices**
 
@@ -94,14 +93,9 @@ The BLAS acceleration on Nvidia GPU through oneAPI can be obtained using the Nvi
 | Ampere Series | Support | A100, A4000 |
 | Ampere Series *(Mobile)* | Support | RTX 40 Series |
 
-*Notes:*
-
-- Support for Nvidia targets through oneAPI is currently limited to Linux platforms.
-
-- Please make sure the native oneAPI MKL *(dedicated to intel CPUs and GPUs)* is not "visible" at this stage to properly setup and use the built-from-source oneMKL with cuBLAS backend in llama.cpp for Nvidia GPUs.
 
 ## Docker
 The docker build option is currently limited to *intel GPU* targets.
 
 ### Build image
 ```sh
 # Using FP16
@@ -168,29 +162,10 @@ Platform #0: Intel(R) OpenCL HD Graphics
 - **Nvidia GPU**
 
 In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed.
-Installation can be verified by running the following:
-```sh
-nvidia-smi
-```
-Please make sure at least one CUDA device is available, which can be displayed like this *(here an A100-40GB Nvidia GPU)*:
-```
-+---------------------------------------------------------------------------------------+
-| NVIDIA-SMI 535.54.03 Driver Version: 535.54.03 CUDA Version: 12.2 |
-|-----------------------------------------+----------------------+----------------------+
-| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
-| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
-| | | MIG M. |
-|=========================================+======================+======================|
-| 0 NVIDIA A100-PCIE-40GB On | 00000000:8D:00.0 Off | 0 |
-| N/A 36C P0 57W / 250W | 4MiB / 40960MiB | 0% Default |
-| | | Disabled |
-+-----------------------------------------+----------------------+----------------------+
-```
 
 2. **Install Intel® oneAPI Base toolkit**
 
-- **Base installation**
+- **For Intel GPU**
 
 The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
 
@@ -202,10 +177,10 @@ Upon a successful installation, SYCL is enabled for the available intel devices,
 
 - **Adding support to Nvidia GPUs**
 
-**oneAPI**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.
+**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.
 
-**oneMKL**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* do not contain the cuBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *cuBLAS* backend enabled is thus required to run it on Nvidia GPUs.
+**oneMKL for cuBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* do not contain the cuBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *cuBLAS* backend enabled is thus required to run it on Nvidia GPUs.
 
 ```sh
 git clone https://github.com/oneapi-src/oneMKL
@@ -237,7 +212,7 @@ When targeting an intel GPU, the user should expect one or more level-zero devic
 
 - **Nvidia GPU**
 
-Similarly, user targetting Nvidia GPUs should expect at least one SYCL-CUDA device [`ext_oneapi_cuda:gpu`] as bellow:
+Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`ext_oneapi_cuda:gpu`] as below:
 ```
 [opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix]
 [opencl:cpu:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix]
@@ -260,6 +235,9 @@ cmake --build .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icp
 
 # Option 2: Use FP32 by default
 cmake --build .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
 
+#build all binary
+cmake --build . --config Release -j -v
 ```
 
 #### Nvidia GPU
@@ -278,6 +256,10 @@ cmake --build .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=i
 
 # Option 2: Use FP32 by default
 cmake --build .. -DLLAMA_SYCL=ON -DLLAMA_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
 
+#build all binary
+cmake --build . --config Release -j -v
+
 ```
 
 ### III. Run the inference
@@ -357,7 +339,6 @@ Otherwise, you can run the script:
 
 *Notes:*
 
-- By default, `mmap` is used to read the model file. In some cases, it causes runtime hang issues. Please disable it by passing `--no-mmap` to the `/bin/main` if faced with the issue.
 - Upon execution, verify the selected device(s) ID(s) in the output log, which can for instance be displayed as follow:
 
 ```sh
@@ -438,7 +419,7 @@ cd build
 
 cmake -G "MinGW Makefiles" .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icx -DCMAKE_BUILD_TYPE=Release -DLLAMA_SYCL_F16=ON
 
-make
+make -j
 ```
 
 Otherwise, run the `win-build-sycl.bat` wrapper which encapsulates the former instructions:
@@ -525,7 +506,6 @@ Otherwise, run the following wrapper script:
 
 Note:
 
-- By default, `mmap` is used to read the model file. In some cases, it causes runtime hang issues. Please disable it by passing `--no-mmap` to the `main.exe` if faced with the issue.
 - Upon execution, verify the selected device(s) ID(s) in the output log, which can for instance be displayed as follow:
 
 ```sh
@@ -557,12 +537,6 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 
 ## Known Issues
 
-- Hanging during startup
-
-  llama.cpp uses *mmap* as the default mode for reading the model file and copying it to the GPU. In some systems, `memcpy` might behave abnormally and therefore hang.
-
-  - **Solution**: add `--no-mmap` or `--mmap 0` flag to the `main` executable.
-
 - `Split-mode:[row]` is not supported.
 
 ## Q&A
@@ -574,7 +548,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 
 - General compiler error:
 
-  - Remove build folder or try a clean-build.
+  - Remove **build** folder or try a clean-build.
 
 - I can **not** see `[ext_oneapi_level_zero:gpu]` after installing the GPU driver on Linux.
 

@@ -20,4 +20,4 @@ cmake .. -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
 #cmake --build . --config Release --target llama-bench
 
 #build all binary
-cmake --build . --config Release -v
+cmake --build . --config Release -j -v

@@ -12,6 +12,7 @@ if [ $# -gt 0 ]; then
     GGML_SYCL_SINGLE_GPU=1
 else
     GGML_SYCL_DEVICE=0
+    GGML_SYCL_SINGLE_GPU=0
 fi
 
 #export GGML_SYCL_DEBUG=1

@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
 #define SYCL_SCALE_BLOCK_SIZE 256
 #define SYCL_CLAMP_BLOCK_SIZE 256
 #define SYCL_ROPE_BLOCK_SIZE 256
-#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
 #define SYCL_ALIBI_BLOCK_SIZE 32
 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
 #define SYCL_QUANTIZE_BLOCK_SIZE 256

@@ -13080,11 +13079,13 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
                               const int nrows_y, const float scale, const float max_bias,
                               dpct::queue_ptr stream) {
     int nth = WARP_SIZE;
-    while (nth < ncols_x && nth < SYCL_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    int max_block_size = g_work_group_size;
+    while (nth < ncols_x && nth < max_block_size) nth *= 2;
+    if (nth>max_block_size) nth = max_block_size;
 
     const sycl::range<3> block_dims(1, 1, nth);
     const sycl::range<3> block_nums(1, 1, nrows_x);
     const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
-    static_assert(SYCL_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
     const uint32_t n_head_kv = nrows_x/nrows_y;
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

@@ -13094,6 +13095,12 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
 
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
+        if (ncols_x > max_block_size) {
+            soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                                               max_bias, m0, m1, n_head_log2, block_nums,
+                                               block_dims, n_local_scratch, stream);
+            return;
+        }
         switch (ncols_x) {
             case 32:
                 soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
|
|||||||
const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
|
const dpct::queue_ptr stream = g_syclStreams[ctx->device][0];
|
||||||
SYCL_CHECK(
|
SYCL_CHECK(
|
||||||
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
|
||||||
|
char* host_buf = (char*)malloc(size);
|
||||||
|
memcpy(host_buf, data, size);
|
||||||
SYCL_CHECK(
|
SYCL_CHECK(
|
||||||
CHECK_TRY_ERROR((*stream)
|
CHECK_TRY_ERROR((*stream)
|
||||||
.memcpy((char *)tensor->data + offset, data, size)
|
.memcpy((char *)tensor->data + offset, host_buf, size)
|
||||||
.wait()));
|
.wait()));
|
||||||
|
free(host_buf);
|
||||||
}
|
}
|
||||||
catch (sycl::exception const &exc) {
|
catch (sycl::exception const &exc) {
|
||||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||||
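Per the commit notes, an intermediate revision cached the host buffer to cut down on malloc traffic ("use host buff to reduce malloc times") but was reverted because this path can be entered from several threads at once. If that optimization were wanted, per-thread storage would restore safety; a hypothetical sketch (`upload_staged_tls` is not part of the patch):

```cpp
#include <cstring>
#include <vector>
#include <sycl/sycl.hpp>

// Hypothetical alternative: reuse one staging buffer per thread instead
// of calling malloc/free on every upload. thread_local storage sidesteps
// the data race a single shared buffer would have.
static void upload_staged_tls(sycl::queue & q, void * dst_device,
                              const void * src_host, size_t size) {
    thread_local std::vector<char> host_buf;
    if (host_buf.size() < size) {
        host_buf.resize(size);  // grow once per thread, then reuse
    }
    memcpy(host_buf.data(), src_host, size);
    q.memcpy(dst_device, host_buf.data(), size).wait();
}
```

The committed code sticks with per-call malloc/free, which is simpler and avoids keeping large per-thread buffers alive for the life of each thread.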