use op param epsilon for norms

This commit is contained in:
Aaron Miller 2023-10-11 18:40:07 -07:00 committed by cebtenzzre
parent 3327d84a7f
commit d5741c07a5

View File

@ -810,12 +810,10 @@ void ggml_vk_norm_(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inOff, uint32_t outOff,
int32_t ne00, int32_t nb01,
int32_t nrows) {
int32_t nrows, float epsilon) {
GGML_ASSERT(nb01%sizeof(float) == 0);
GGML_ASSERT(ne00%sizeof(float) == 0);
const float epsilon = 1e-6f; // this is what ggml.c uses for rms norm
struct PushConstants {
uint32_t inOff, outOff;
uint32_t ne00, nb01;
@ -1559,11 +1557,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
} break;
case GGML_OP_NORM:
{
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
} break;
case GGML_OP_RMS_NORM:
{
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
} break;
case GGML_OP_MUL_MAT:
{