ggml : sync (abort callback, mul / add broadcast, fix alibi) (#2183)

This commit is contained in:
Georgi Gerganov 2023-07-11 22:53:34 +03:00 committed by GitHub
parent 5bf2a27718
commit 20d7740a9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 72 deletions

View File

@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
}; };
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) { if (i >= kx) {
return; return;
} }
dst[i] = x[i] + y[i]; dst[i] = x[i] + y[i%ky];
} }
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i])); dst[i] = x[i] / (1.0f + expf(-x[i]));
} }
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const float eps = 1e-5f;
float mean = 0.0f;
float var = 0.0f;
for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
mean += xi;
var += xi * xi;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
}
mean /= ncols;
var = var / ncols - mean * mean;
const float inv_var = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
}
}
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const float eps = 1e-6; const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int i = 0; i < ncols; i += WARP_SIZE) { for (int col = tid; col < ncols; col += WARP_SIZE) {
const int col = i + tid;
const float xi = x[row*ncols + col]; const float xi = x[row*ncols + col];
tmp += xi * xi; tmp += xi * xi;
} }
@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
} }
const float mean = tmp / ncols; const float mean = tmp / ncols;
const float scale = 1.0f / sqrtf(mean + eps); const float scale = rsqrtf(mean + eps);
for (int i = 0; i < ncols; i += WARP_SIZE) { for (int col = tid; col < ncols; col += WARP_SIZE) {
const int col = i + tid;
dst[row*ncols + col] = scale * x[row*ncols + col]; dst[row*ncols + col] = scale * x[row*ncols + col];
} }
} }
@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
dst[i] = scale * x[i]; dst[i] = scale * x[i];
} }
static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k); add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
} }
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k); silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
} }
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0); GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne0 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low; const int64_t i01_diff = i01_high - i01_low;
const int64_t ne10 = src1->ne[0];
// compute // compute
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main); add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else { } else {
GGML_ASSERT(false); GGML_ASSERT(false);
} }
@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
GGML_ASSERT(dst_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
const int64_t ne10 = src1->ne[0]; const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
for (int64_t i01 = i01_low; i01 < i01_high; i01++) { mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
// compute
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
}
(void) dst; (void) dst;
(void) src0_ddq_i; (void) src0_ddq_i;
@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
(void) i1; (void) i1;
} }
inline void ggml_cuda_op_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
cudaStream_t & cudaStream_main){
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
// compute
norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
(void) src1;
(void) dst;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}
inline void ggml_cuda_op_rms_norm( inline void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true); ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
} }
void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
}
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true); ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
} }
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf; extra->data_device[id] = buf;
@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
} }
func = ggml_cuda_silu; func = ggml_cuda_silu;
break; break;
case GGML_OP_NORM:
if (!any_on_device) {
return false;
}
func = ggml_cuda_norm;
break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
if (!any_on_device) { if (!any_on_device) {
return false; return false;

101
ggml.c
View File

@ -25,6 +25,7 @@
#include <float.h> #include <float.h>
#include <limits.h> #include <limits.h>
#include <stdarg.h> #include <stdarg.h>
#include <signal.h>
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
#include <unistd.h> #include <unistd.h>
@ -4723,7 +4724,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{ {
assert(tensor->nb[0] == sizeof(ggml_fp16_t)); assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
} }
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -4775,7 +4776,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{ {
assert(tensor->nb[0] == sizeof(ggml_fp16_t)); assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
} }
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -5035,11 +5036,15 @@ struct ggml_tensor * ggml_add_impl(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b)); // TODO: support less-strict constraint
// GGML_ASSERT(ggml_can_repeat(b, a));
GGML_ASSERT(ggml_can_repeat_rows(b, a));
bool is_node = false; bool is_node = false;
if (a->grad || b->grad) { if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true; is_node = true;
} }
@ -8297,7 +8302,7 @@ static void ggml_compute_forward_add_f32(
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
@ -8322,23 +8327,23 @@ static void ggml_compute_forward_add_f32(
if (nb10 == sizeof(float)) { if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3
const int i3 = ir/(ne2*ne1); const int64_t i03 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1); const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
#ifdef GGML_USE_ACCELERATE #ifdef GGML_USE_ACCELERATE
vDSP_vadd( vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
#else #else
ggml_vec_add_f32(ne0, ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif #endif
// } // }
// } // }
@ -8346,15 +8351,20 @@ static void ggml_compute_forward_add_f32(
} else { } else {
// src1 is not contiguous // src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3
const int i3 = ir/(ne2*ne1); const int64_t i03 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1); const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) { for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
} }
@ -11717,7 +11727,7 @@ static void ggml_compute_forward_alibi_f32(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz //const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0); const int n = ggml_nrows(src0);
@ -11728,8 +11738,9 @@ static void ggml_compute_forward_alibi_f32(
const int nb2 = src0->nb[2]; const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
assert(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
assert(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(ne1 + n_past == ne0);
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled) // add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11753,7 +11764,7 @@ static void ggml_compute_forward_alibi_f32(
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
} }
pdst[0] = (i-ne0+1) * m_k + src[0]; pdst[0] = i * m_k + src[0];
} }
} }
@ -11782,7 +11793,7 @@ static void ggml_compute_forward_alibi_f16(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz //const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0); const int n = ggml_nrows(src0);
@ -11793,8 +11804,9 @@ static void ggml_compute_forward_alibi_f16(
const int nb2 = src0->nb[2]; const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
assert(nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
assert(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled) // add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11819,7 +11831,7 @@ static void ggml_compute_forward_alibi_f16(
} }
// we return F32 // we return F32
pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
} }
} }
} }
@ -15944,6 +15956,9 @@ struct ggml_compute_state_shared {
// synchronization primitives // synchronization primitives
atomic_int n_active; // num active threads atomic_int n_active; // num active threads
atomic_int node_n; // active graph node atomic_int node_n; // active graph node
bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
void * abort_callback_data;
}; };
struct ggml_compute_state { struct ggml_compute_state {
@ -15975,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1; int node_n = -1;
while (true) { while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return (thread_ret_t) GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning // all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again // do finalize and init here so we don't have synchronize again
@ -16028,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else { } else {
break; break;
} }
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
} }
atomic_store(&state->shared->n_active, n_threads); atomic_store(&state->shared->n_active, n_threads);
@ -16061,7 +16084,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} }
} }
return 0; return GGML_EXIT_SUCCESS;
} }
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
@ -16401,7 +16424,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan; return cplan;
} }
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{ {
GGML_ASSERT(cplan); GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0); GGML_ASSERT(cplan->n_threads > 0);
@ -16427,6 +16450,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
/*.n_threads =*/ n_threads, /*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads, /*.n_active =*/ n_threads,
/*.node_n =*/ -1, /*.node_n =*/ -1,
/*.abort_callback =*/ NULL,
/*.abort_callback_data =*/ NULL,
}; };
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
@ -16450,12 +16475,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us(); const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too // this is a work thread too
ggml_graph_compute_thread(&workers[0]); int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
// don't leave affinity set on the main thread // don't leave affinity set on the main thread
clear_numa_thread_affinity(); clear_numa_thread_affinity();
// join thread pool // join or kill thread pool
if (n_threads > 1) { if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) { for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL); const int rc = ggml_thread_join(workers[j].thrd, NULL);
@ -16479,6 +16504,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0, (double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
} }
return compute_status;
} }
void ggml_graph_reset(struct ggml_cgraph * cgraph) { void ggml_graph_reset(struct ggml_cgraph * cgraph) {

11
ggml.h
View File

@ -201,8 +201,13 @@
#define GGML_MAX_NAME 48 #define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1
#define GGML_UNUSED(x) (void)(x) #define GGML_UNUSED(x) (void)(x)
#define GGML_ASSERT(x) \ #define GGML_ASSERT(x) \
do { \ do { \
if (!(x)) { \ if (!(x)) { \
@ -442,6 +447,10 @@ extern "C" {
// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES]; int n_tasks[GGML_MAX_NODES];
// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
}; };
// computation graph // computation graph
@ -1303,7 +1312,7 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute() // ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data // when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// same as ggml_graph_compute() but the work data is allocated as a part of the context // same as ggml_graph_compute() but the work data is allocated as a part of the context

View File

@ -10,7 +10,9 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wdouble-promotion" #pragma GCC diagnostic ignored "-Wdouble-promotion"
#endif
#define MAX_NARGS 3 #define MAX_NARGS 3

View File

@ -7,7 +7,9 @@
#define MAX_NARGS 2 #define MAX_NARGS 2
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wdouble-promotion" #pragma GCC diagnostic ignored "-Wdouble-promotion"
#endif
// //
// logging // logging