Compare commits

...

8 Commits

Author SHA1 Message Date
Gabe Goodhart
c671f109e3
Merge 2615459bb2 into ecd5d6b65b 2024-09-22 05:34:54 +02:00
Shankar
ecd5d6b65b
llama: remove redundant loop when constructing ubatch (#9574)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 04:30:34 +02:00
Molly Sophia
2a63caaa69
RWKV v6: RWKV_WKV op CUDA implementation (#9454)
* ggml: CUDA unary op EXP

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* ggml: rwkv_wkv op CUDA impl

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2024-09-22 04:29:12 +02:00
slaren
d09770cae7
ggml-alloc : fix list of allocated tensors with GGML_ALLOCATOR_DEBUG (#9573)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-21 14:24:23 +02:00
Gabe Goodhart
2615459bb2 feat(granitemoe): Implement granitemoe
GraniteMoE follows the mixtral architecture (once the input_linear layers
are split into gate_exps/up_exps). The main delta is the addition of the
same four multipliers used in Granite.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2024-09-17 06:46:02 -06:00
Gabe Goodhart
54ce8cd5d6 fix(granitemoe convert): Split the double-sized input layer into gate and up
After a lot of staring and squinting, it's clear that the standard mixtral
expert implementation is equivalent to the vectorized parallel experts in
granite. The difference is that in granite, the w1 and w3 are concatenated
into a single tensor "input_linear." Rather than reimplementing all of the
math on the llama.cpp side, the much simpler route is to just split this
tensor during conversion and follow the standard mixtral route.

Branch: GraniteMoE

Co-Authored-By: alex.brooks@ibm.com

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2024-09-17 06:45:55 -06:00
Gabe Goodhart
178821231e feat(convert_hf_to_gguf): Add GraniteMoeModel
GraniteMoe has the same configuration deltas as Granite

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2024-09-17 06:42:57 -06:00
Gabe Goodhart
70f19efc40 feat(gguf-py): Add granitemoe architecture
This includes the addition of new tensor names for the new moe layers.
These may not be correct at this point due to the need for the hack in
gguf_writer.py to double-check the length of the shape for these layers.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
2024-09-17 06:42:57 -06:00
11 changed files with 268 additions and 27 deletions

View File

@ -4102,16 +4102,43 @@ class GraniteModel(LlamaModel):
# consistency
if attention_scale := self.hparams.get("attention_multiplier"):
self.gguf_writer.add_attention_scale(attention_scale)
logger.info("gguf: (granite) attention_scale = %s", attention_scale)
if embedding_scale := self.hparams.get("embedding_multiplier"):
self.gguf_writer.add_embedding_scale(embedding_scale)
logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
if residual_scale := self.hparams.get("residual_multiplier"):
self.gguf_writer.add_residual_scale(residual_scale)
if logits_scaling := self.hparams.get("logits_scaling"):
self.gguf_writer.add_logit_scale(logits_scaling)
logger.info("gguf: (granite) residual_scale = %s", residual_scale)
if logits_scale := self.hparams.get("logits_scaling"):
self.gguf_writer.add_logit_scale(logits_scale)
logger.info("gguf: (granite) logits_scale = %s", logits_scale)
@Model.register("GraniteMoeForCausalLM")
class GraniteMoeModel(GraniteModel):
"""Conversion for IBM's GraniteMoeForCausalLM"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
is used. This essentially merges w1 and w3 into a single tensor with 2x
the hidden size that is then split during forward. To keep compativility
with existing mixtral support, we pull them apart here.
"""
if name.endswith("block_sparse_moe.input_linear.weight"):
gate, up = data_torch.chunk(2, dim=-2)
return [
(self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.gate.weight"), gate),
(self.map_tensor_name(f"model.layers.{bid}.block_sparse_moe.input_linear.up.weight"), up),
]
return super().modify_tensors(data_torch, name, bid)
###### CONVERSION LOGIC ######
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor

View File

@ -294,6 +294,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
alloc->free_blocks[0].offset = 0;
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
alloc->max_size = 0;
#ifdef GGML_ALLOCATOR_DEBUG
for (int i = 0; i < 1024; i++) {
alloc->allocated_tensors[i].tensor = NULL;
}
#endif
}
static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {

View File

@ -34,6 +34,7 @@
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/rwkv-wkv.cuh"
#include <algorithm>
#include <array>
@ -2243,6 +2244,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_HARDSWISH:
ggml_cuda_op_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
ggml_cuda_op_exp(ctx, dst);
break;
default:
return false;
}
@ -2345,6 +2349,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_RWKV_WKV:
ggml_cuda_op_rwkv_wkv(ctx, dst);
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
@ -2806,6 +2812,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_EXP:
return ggml_is_contiguous(op->src[0]);
default:
return false;
@ -2967,6 +2974,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

View File

@ -0,0 +1,89 @@
#include "common.cuh"
#include "rwkv-wkv.cuh"
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
const int tid = threadIdx.x;
const int bid = blockIdx.x;
const int head_size = CUDA_WKV_BLOCK_SIZE;
const int batch_i = bid / H;
const int head_i = bid % H;
const int state_size = C * head_size;
const int n_seq_tokens = T / B;
float state[head_size];
__shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
#pragma unroll
for (int i = 0; i < head_size; i++) {
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
}
__syncthreads();
_tf[tid] = tf[head_i * head_size + tid];
__syncthreads();
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
__syncthreads();
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
__syncthreads();
const float _v = v[t];
float y = 0;
for (int j = 0; j < head_size; j += 4) {
const float4& k = (float4&)(_k[j]);
const float4& r = (float4&)(_r[j]);
const float4& tf = (float4&)(_tf[j]);
const float4& td = (float4&)(_td[j]);
float4& s = (float4&)(state[j]);
float4 kv;
kv.x = k.x * _v;
kv.y = k.y * _v;
kv.z = k.z * _v;
kv.w = k.w * _v;
y += r.x * (tf.x * kv.x + s.x);
y += r.y * (tf.y * kv.y + s.y);
y += r.z * (tf.z * kv.z + s.z);
y += r.w * (tf.w * kv.w + s.w);
s.x = s.x * td.x + kv.x;
s.y = s.y * td.y + kv.y;
s.z = s.z * td.z + kv.z;
s.w = s.w * td.w + kv.w;
}
dst[t] = y;
}
#pragma unroll
for (int i = 0; i < head_size; i++) {
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
}
}
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * k_d = (const float *)dst->src[0]->data;
const float * v_d = (const float *)dst->src[1]->data;
const float * r_d = (const float *)dst->src[2]->data;
const float * tf_d = (const float *)dst->src[3]->data;
const float * td_d = (const float *)dst->src[4]->data;
const float * s_d = (const float *)dst->src[5]->data;
const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_WKV_BLOCK_SIZE 64
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -95,6 +95,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void exp_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = expf(x[i]);
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -189,6 +198,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -354,6 +368,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;

View File

@ -8,6 +8,7 @@
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_EXP_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
@ -32,6 +33,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()
class MODEL_TENSOR(IntEnum):
@ -392,6 +393,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -1242,6 +1244,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GRANITE_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

View File

@ -256,6 +256,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -296,6 +297,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.input_linear.up", # granitemoe
),
MODEL_TENSOR.FFN_UP_SHEXP: (
@ -328,6 +330,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.input_linear.gate", # granitemoe
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
@ -368,6 +371,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (

View File

@ -215,6 +215,7 @@ enum llm_arch {
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_UNKNOWN,
};
@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -1478,6 +1480,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -2396,7 +2415,7 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;
// Additional scale factors (Granite)
// Additional scale factors (Granite/Granite MoE)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
@ -3056,18 +3075,14 @@ struct llama_sbatch {
} else {
// simple split
if (batch->n_seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
}
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
}
}
if (batch->seq_id) {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id = batch->seq_id + seq.offset;
}
} else {
for (size_t i = 0; i < length; ++i) {
ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
@ -6052,6 +6067,7 @@ static void llm_load_hparams(
}
} break;
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
@ -6060,6 +6076,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_3B; break;
case 40: model.type = e_model::MODEL_3B; break;
// Add additional layer/vocab/etc checks here for other model sizes
default: model.type = e_model::MODEL_UNKNOWN;
@ -6771,7 +6788,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
if (model.arch == LLM_ARCH_GRANITE) {
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
@ -6945,6 +6962,7 @@ static bool llm_load_tensors(
case LLM_ARCH_REFACT:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@ -15872,6 +15890,7 @@ static struct ggml_cgraph * llama_build_graph(
switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
result = llm.build_llama();
} break;
@ -19173,6 +19192,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2

View File

@ -1543,6 +1543,36 @@ struct test_ssm_scan : public test_case {
}
};
// GGML_OP_RWKV_WKV
struct test_rwkv_wkv : public test_case {
const ggml_type type;
const int64_t head_count;
const int64_t head_size;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
}
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t n_tokens = n_seq_tokens * n_seqs;
ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
return out;
}
};
// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
const ggml_type type_a;
@ -3337,6 +3367,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {