Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-11-15 15:29:53 +00:00)
metal : fix SSM_SCAN state head offset
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
This commit is contained in:
parent: 8b15bc6fa0
commit: 5b8ec2b978
@@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32(
|
|||||||
|
|
||||||
device const int32_t * ids = (device const int32_t *) src7;
|
device const int32_t * ids = (device const int32_t *) src7;
|
||||||
|
|
||||||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
|
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
|
||||||
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
|
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
|
||||||
|
|
||||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||||
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
|
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
|
||||||
@@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group(
|
|||||||
|
|
||||||
device const int32_t * ids = (device const int32_t *) src7;
|
device const int32_t * ids = (device const int32_t *) src7;
|
||||||
|
|
||||||
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03);
|
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03);
|
||||||
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off);
|
device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off);
|
||||||
|
|
||||||
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
||||||
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
|
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}
|
||||||
|
Loading…
Reference in New Issue
Block a user