Update Vulkan RoPE implementation (#7818)

* Update Vulkan RoPE implementation

* Return nullptr on alloc_buffer when allocation fails, instead of throwing an exception

Minor fixes

* Fix segfault when running out of VRAM

Co-authored-by: slaren <slarengh@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
0cc4m 2024-06-11 21:20:29 +02:00 committed by GitHub
parent 14f83526cd
commit ef52d1d16a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1311 additions and 1227 deletions

View File

@ -886,7 +886,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
#endif #endif
for (size_t i = 0; i < *n_buffers; i++) { for (size_t i = 0; i < *n_buffers; i++) {
ggml_backend_buffer_free(*buffers[i]); ggml_backend_buffer_free((*buffers)[i]);
} }
free(*buffers); free(*buffers);
return false; return false;

File diff suppressed because it is too large Load Diff

View File

@ -150,7 +150,7 @@ struct vk_device {
vk_pipeline pipeline_relu_f32; vk_pipeline pipeline_relu_f32;
vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_diag_mask_inf_f32;
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
vk_pipeline pipeline_rope_f32, pipeline_rope_f16; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_argsort_f32;
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
@ -283,26 +283,15 @@ struct vk_op_diag_mask_push_constants {
struct vk_op_rope_push_constants { struct vk_op_rope_push_constants {
uint32_t ncols; uint32_t ncols;
uint32_t n_dims;
float freq_scale; float freq_scale;
uint32_t p_delta_rows; uint32_t p_delta_rows;
float freq_base; float freq_base;
float ext_factor; float ext_factor;
float attn_factor; float attn_factor;
float corr_dims[4]; float corr_dims[2];
};
struct vk_op_rope_neox_push_constants {
uint32_t ncols;
uint32_t ndims;
float freq_scale;
uint32_t p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[4];
float theta_scale; float theta_scale;
float inv_ndims; uint32_t has_ff;
uint32_t has_freq_facs;
}; };
struct vk_op_soft_max_push_constants { struct vk_op_soft_max_push_constants {
@ -1534,11 +1523,11 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
@ -3905,10 +3894,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
} }
} else { } else {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rope_f32; return ctx->device->pipeline_rope_norm_f32;
} }
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_rope_f16; return ctx->device->pipeline_rope_norm_f16;
} }
} }
return nullptr; return nullptr;
@ -4152,10 +4141,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
} else if (op == GGML_OP_ROPE) { } else if (op == GGML_OP_ROPE) {
const int mode = ((int32_t *) dst->op_params)[2];
const bool is_neox = mode & 2;
if (is_neox) {
// Empty src2 is possible in rope, but the shader needs a buffer // Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z; vk_subbuffer subbuf_z;
if (use_src2) { if (use_src2) {
@ -4166,10 +4151,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
} else {
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
}
} else if (use_src2) { } else if (use_src2) {
ggml_vk_sync_buffers(subctx); ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@ -4391,7 +4372,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const int n_dims = ((int32_t *) dst->op_params)[1]; const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2]; // const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3]; // const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const float freq_base = ((float *) dst->op_params)[5]; const float freq_base = ((float *) dst->op_params)[5];
@ -4401,28 +4382,16 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
const float beta_fast = ((float *) dst->op_params)[9]; const float beta_fast = ((float *) dst->op_params)[9];
const float beta_slow = ((float *) dst->op_params)[10]; const float beta_slow = ((float *) dst->op_params)[10];
const bool is_neox = mode & 2;
#pragma message("TODO: update rope NORM mode to match NEOX mode")
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
float corr_dims[2]; float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
if (is_neox) {
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;
ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims, freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
src2 != nullptr, src2 != nullptr,
}); });
} else {
ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, {
(uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1],
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}
});
}
} }
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -6070,7 +6039,13 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
std::cerr << "ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" << std::endl; std::cerr << "ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" << std::endl;
#endif #endif
ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
vk_buffer dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size);
vk_buffer dev_buffer = nullptr;
try {
dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size);
} catch (const vk::SystemError& e) {
return nullptr;
}
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name); ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name);
@ -6466,7 +6441,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
// return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; // return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
// } break; // } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:
return true; return ggml_is_contiguous(op->src[0]);
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:

View File

@ -2400,7 +2400,7 @@ void main() {
""" """
# ROPE # ROPE
rope_src = """ rope_norm_src = """
#version 450 #version 450
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
@ -2408,17 +2408,21 @@ rope_src = """
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];}; layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint ncols; uint ncols;
uint n_dims;
float freq_scale; float freq_scale;
uint p_delta_rows; uint p_delta_rows;
float freq_base; float freq_base;
float ext_factor; float ext_factor;
float attn_factor; float attn_factor;
float corr_dims[4]; float corr_dims[2];
float theta_scale;
uint has_ff;
} p; } p;
float rope_yarn_ramp(const float low, const float high, const uint i0) { float rope_yarn_ramp(const float low, const float high, const uint i0) {
@ -2450,14 +2454,24 @@ void main() {
return; return;
} }
if (col >= p.n_dims) {
const uint i = row*p.ncols + col;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint i = row*p.ncols + col; const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows; const uint i2 = row/p.p_delta_rows;
const int pos = data_b[i2]; const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, col, cos_theta, sin_theta); rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]); const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]); const float x1 = float(data_a[i + 1]);
@ -2475,22 +2489,21 @@ rope_neox_src = """
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];}; layout (binding = 1) readonly buffer Y {int data_pos[];};
layout (binding = 2) readonly buffer Z {float data_freq_factors[];}; layout (binding = 2) readonly buffer Z {float data_ff[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter { layout (push_constant) uniform parameter {
uint ncols; uint ncols;
uint ndims; uint n_dims;
float freq_scale; float freq_scale;
uint p_delta_rows; uint p_delta_rows;
float freq_base; float freq_base;
float ext_factor; float ext_factor;
float attn_factor; float attn_factor;
float corr_dims[4]; float corr_dims[2];
float theta_scale; float theta_scale;
float inv_ndims; uint has_ff;
uint has_freq_facs;
} p; } p;
float rope_yarn_ramp(const float low, const float high, const uint i0) { float rope_yarn_ramp(const float low, const float high, const uint i0) {
@ -2522,11 +2535,8 @@ void main() {
return; return;
} }
const uint ib = col / p.ndims; if (col >= p.n_dims) {
const uint ic = col % p.ndims; const uint i = row*p.ncols + col;
if (ib > 0) {
const uint i = row*p.ncols + ib*p.ndims + ic;
data_d[i + 0] = data_a[i + 0]; data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1]; data_d[i + 1] = data_a[i + 1];
@ -2534,29 +2544,27 @@ void main() {
return; return;
} }
const uint i = row*p.ncols + ib*p.ndims + ic/2; const uint i = row*p.ncols + col/2;
const uint i2 = row/p.p_delta_rows; const uint i2 = row/p.p_delta_rows;
const int pos = data_b[i2]; const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor; const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
float cos_theta, sin_theta; float cos_theta, sin_theta;
rope_yarn(theta_base, ic, cos_theta, sin_theta); rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]); const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.ndims/2]); const float x1 = float(data_a[i + p.n_dims/2]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
} }
""" """
argsort_src = """ argsort_src = """
#version 450 #version 450
#extension GL_EXT_shader_16bit_storage : require
#define BLOCK_SIZE 1024 #define BLOCK_SIZE 1024
#define ASC 0 #define ASC 0
@ -3039,8 +3047,8 @@ async def main():
tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_norm_f32", rope_norm_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) tasks.append(string_to_spv("rope_norm_f16", rope_norm_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))