cont : mul mat vec

This commit is contained in:
Georgi Gerganov 2024-11-09 22:56:39 +02:00
parent 626e126e48
commit c34f614b39
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 402 additions and 1069 deletions

View File

@ -487,6 +487,49 @@ typedef struct {
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mv;
typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
#endif
#endif // GGML_COMMON_DECL

View File

@ -2159,28 +2159,32 @@ static void ggml_metal_encode_node(
}
};
ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
@ -2472,30 +2476,34 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne00 >= nth0*nth1);
}
ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
/*.nbi1 =*/ nb21,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb1 =*/ nb1,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
[encoder setBytes:&args length:sizeof(args) atIndex:4];
const int64_t _ne1 = 1;
const int tgz = dst_rows;

File diff suppressed because it is too large Load Diff