mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-09-22 21:16:20 +00:00
Compare commits
8 Commits
e6939a97b5
...
c671f109e3
Author | SHA1 | Date | |
---|---|---|---|
|
c671f109e3 | ||
|
ecd5d6b65b | ||
|
2a63caaa69 | ||
|
d09770cae7 | ||
|
2615459bb2 | ||
|
54ce8cd5d6 | ||
|
178821231e | ||
|
70f19efc40 |
@ -4102,16 +4102,43 @@ class GraniteModel(LlamaModel):
|
||||
# consistency
|
||||
if attention_scale := self.hparams.get("attention_multiplier"):
|
||||
self.gguf_writer.add_attention_scale(attention_scale)
|
||||
logger.info("gguf: (granite) attention_scale = %s", attention_scale)
|
||||
if embedding_scale := self.hparams.get("embedding_multiplier"):
|
||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||
logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
|
||||
if residual_scale := self.hparams.get("residual_multiplier"):
|
||||
self.gguf_writer.add_residual_scale(residual_scale)
|
||||
if logits_scaling := self.hparams.get("logits_scaling"):
|
||||
self.gguf_writer.add_logit_scale(logits_scaling)
|
||||
logger.info("gguf: (granite) residual_scale = %s", residual_scale)
|
||||
if logits_scale := self.hparams.get("logits_scaling"):
|
||||
self.gguf_writer.add_logit_scale(logits_scale)
|
||||
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
|
||||
|
||||
|
||||
@Model.register("GraniteMoeForCausalLM")
|
||||
class GraniteMoeModel(GraniteModel):
|
||||
"""Conversion for IBM's GraniteMoeForCausalLM"""
|
||||
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
|
||||
is used. This essentially merges w1 and w3 into a single tensor with 2x
|
||||
the hidden size that is then split during forward. To keep compativility
|
||||
with existing mixtral support, we pull them apart here.
|
||||
"""
|
||||
|
||||
if name.endswith("block_sparse_moe.input_linear.weight"):
|
||||
gate, up = data_torch.chunk(2, dim=-2)
|
||||
return [
|
||||
(self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.gate.weight"), gate),
|
||||
(self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.up.weight"), up),
|
||||
]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
# tree of lazy tensors
|
||||
class LazyTorchTensor(gguf.LazyBase):
|
||||
_tensor_type = torch.Tensor
|
||||
|
@ -294,6 +294,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
|
||||
alloc->free_blocks[0].offset = 0;
|
||||
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
|
||||
alloc->max_size = 0;
|
||||
|
||||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
alloc->allocated_tensors[i].tensor = NULL;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/rwkv-wkv.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
ggml_cuda_op_hardswish(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_EXP:
|
||||
ggml_cuda_op_exp(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
ggml_cuda_op_rwkv_wkv(ctx, dst);
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
89
ggml/src/ggml-cuda/rwkv-wkv.cu
Normal file
@ -0,0 +1,89 @@
|
||||
#include "common.cuh"
|
||||
#include "rwkv-wkv.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = CUDA_WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
__syncthreads();
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4& k = (float4&)(_k[j]);
|
||||
const float4& r = (float4&)(_r[j]);
|
||||
const float4& tf = (float4&)(_tf[j]);
|
||||
const float4& td = (float4&)(_td[j]);
|
||||
float4& s = (float4&)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
y += r.x * (tf.x * kv.x + s.x);
|
||||
y += r.y * (tf.y * kv.y + s.y);
|
||||
y += r.z * (tf.z * kv.z + s.z);
|
||||
y += r.w * (tf.w * kv.w + s.w);
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
}
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * tf_d = (const float *)dst->src[3]->data;
|
||||
const float * td_d = (const float *)dst->src[4]->data;
|
||||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[3];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[2];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
|
||||
|
||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
5
ggml/src/ggml-cuda/rwkv-wkv.cuh
Normal file
@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
|
||||
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
static __global__ void exp_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = expf(x[i]);
|
||||
}
|
||||
|
||||
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
if (i >= k) {
|
||||
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
|
||||
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
|
||||
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
|
||||
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
|
||||
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
@ -8,6 +8,7 @@
|
||||
#define CUDA_RELU_BLOCK_SIZE 256
|
||||
#define CUDA_SIGMOID_BLOCK_SIZE 256
|
||||
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
|
||||
#define CUDA_EXP_BLOCK_SIZE 256
|
||||
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
|
||||
NEMOTRON = auto()
|
||||
EXAONE = auto()
|
||||
GRANITE = auto()
|
||||
GRANITE_MOE = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -392,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||
MODEL_ARCH.EXAONE: "exaone",
|
||||
MODEL_ARCH.GRANITE: "granite",
|
||||
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@ -1242,6 +1244,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GRANITE_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -256,6 +256,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
@ -296,6 +297,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.input_linear.up", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
@ -328,6 +330,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.input_linear.gate", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
@ -368,6 +371,7 @@ class TensorNameMap:
|
||||
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
|
@ -215,6 +215,7 @@ enum llm_arch {
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_RWKV6,
|
||||
LLM_ARCH_GRANITE,
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -1478,6 +1480,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -2396,7 +2415,7 @@ struct llama_hparams {
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
float f_logit_scale = 0.0f;
|
||||
|
||||
// Additional scale factors (Granite)
|
||||
// Additional scale factors (Granite/Granite MoE)
|
||||
float f_residual_scale = 0.0f;
|
||||
float f_embedding_scale = 0.0f;
|
||||
float f_attention_scale = 0.0f;
|
||||
@ -3056,18 +3075,14 @@ struct llama_sbatch {
|
||||
} else {
|
||||
// simple split
|
||||
if (batch->n_seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
||||
}
|
||||
}
|
||||
if (batch->seq_id) {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id = batch->seq_id + seq.offset;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
|
||||
@ -6052,6 +6067,7 @@ static void llm_load_hparams(
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
@ -6060,6 +6076,7 @@ static void llm_load_hparams(
|
||||
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_3B; break;
|
||||
case 40: model.type = e_model::MODEL_3B; break;
|
||||
// Add additional layer/vocab/etc checks here for other model sizes
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
@ -6771,7 +6788,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_GRANITE) {
|
||||
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
|
||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||
@ -6945,6 +6962,7 @@ static bool llm_load_tensors(
|
||||
case LLM_ARCH_REFACT:
|
||||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
@ -15872,6 +15890,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
result = llm.build_llama();
|
||||
} break;
|
||||
@ -19173,6 +19192,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
case LLM_ARCH_CHATGLM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RWKV_WKV
|
||||
struct test_rwkv_wkv : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
const int64_t head_size;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t n_tokens = n_seq_tokens * n_seqs;
|
||||
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
|
||||
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
|
||||
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_MUL_MAT
|
||||
struct test_mul_mat : public test_case {
|
||||
const ggml_type type_a;
|
||||
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
|
Loading…
Reference in New Issue
Block a user