mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
metal : add kernel_get_rows_i32
ggml-ci
This commit is contained in:
parent
ab62fc3e55
commit
289313716f
@ -87,6 +87,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(group_norm);
|
GGML_METAL_DECL_KERNEL(group_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_i32);
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(group_norm);
|
GGML_METAL_ADD_KERNEL(group_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
||||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||||
|
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
||||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DEL_KERNEL(group_norm);
|
GGML_METAL_DEL_KERNEL(group_norm);
|
||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
||||||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
|
||||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
||||||
|
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_get_rows_i32(
|
||||||
|
device const void * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device int32_t * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
|
constant int64_t & ne10,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant uint64_t & nb2,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
uint3 tptg [[threads_per_threadgroup]]) {
|
||||||
|
const int64_t i10 = tgpig.x;
|
||||||
|
const int64_t i11 = tgpig.y;
|
||||||
|
|
||||||
|
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||||
|
|
||||||
|
const int64_t i02 = i11;
|
||||||
|
|
||||||
|
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||||
|
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||||
|
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||||
#define BLOCK_SIZE_K 32
|
#define BLOCK_SIZE_K 32
|
||||||
|
Loading…
Reference in New Issue
Block a user