metal : minor fixup in FA kernel (#10143)
Some checks failed
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-musa.Dockerfile platforms:linux/amd64 tag:full-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-musa.Dockerfile platforms:linux/amd64 tag:light-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-musa.Dockerfile platforms:linux/amd64 tag:server-musa]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
Nix aarch64 builds / nix-build-aarch64 (push) Has been cancelled

* metal : minor fixup in FA kernel

ggml-ci

* metal : use the unrolled loop variable

* metal : remove unused var
This commit is contained in:
Georgi Gerganov 2024-11-03 15:18:40 +02:00 committed by GitHub
parent 1839f69130
commit 08828a6d7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3; const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
float4 mq[D4]; float4 mq[D4/NW];
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg; short i = ii + tiisg;
mq[i] = (float4) sq4[i]; mq[ii/NW] = (float4) sq4[i];
} }
// pointer to the mask // pointer to the mask
@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i] * mk); mqk += (float4) (mq[ii/NW] * mk);
} }
// reduce the results from the threads in the simdgroup // reduce the results from the threads in the simdgroup
@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
// O = diag(ms)*O // O = diag(ms)*O
#pragma unroll #pragma unroll
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg; lo[ii/NW] *= ms;
lo[i/NW] *= ms;
} }
} }
@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg; const short i = ii + tiisg;
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
} }
} }
} }