mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
use op param epsilon for norms
This commit is contained in:
parent
3327d84a7f
commit
d5741c07a5
@ -810,12 +810,10 @@ void ggml_vk_norm_(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t nb01,
|
||||
int32_t nrows) {
|
||||
int32_t nrows, float epsilon) {
|
||||
GGML_ASSERT(nb01%sizeof(float) == 0);
|
||||
GGML_ASSERT(ne00%sizeof(float) == 0);
|
||||
|
||||
const float epsilon = 1e-6f; // this is what ggml.c uses for rms norm
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inOff, outOff;
|
||||
uint32_t ne00, nb01;
|
||||
@ -1559,11 +1557,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
{
|
||||
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0));
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user