diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c36eedb01..9e1d14ff5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns}