From 97bdd26eee11fe109dec00de75690ceef61c03f2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 15 Jul 2024 20:50:47 +0200 Subject: [PATCH] Refactor lora adapter support (#8332) * lora: load to devide buft * add patch tensor function * correct tensor patch * llama_lora_adapter_apply * correct ggml_backend_tensor_copy * add llm_build_mm * fix auto merge * update based on review comments * add convert script * no more transpose A * add f16 convert * add metadata check * add sanity check * fix ftype * add requirements * fix requirements * fix outfile * conversion: only allow selected models * fix types * cuda : do not use dmmv if the tensor does not have enough cols * llama : lora fixes * do not disable mmap with lora Co-authored-by: slaren * llm_build_lora_mm_id * convert_lora : MoE LoRA conversion support * convert_lora : prefer safetensors, similarly to convert_hf * convert_hf : simplify modify_tensors for InternLM2 * convert_lora : lazy conversion * llama : load and use alpha from LoRA adapters * llama : use llm_build_lora_mm in most model graphs * auto scale * Revert "auto scale" This reverts commit 42415a4874e0f963e4aca6796ea5dfb97cd17464. * remove redundant params * Apply suggestions from code review Co-authored-by: slaren * change kv metadata * move add_type to __init__ * convert_hf : move add_type to main() * convert_lora : use the GGUFWriter from Model instead of overwriting it --------- Co-authored-by: slaren Co-authored-by: Francis Couture-Harpin --- common/common.cpp | 13 +- convert_hf_to_gguf.py | 34 +- convert_lora_to_gguf.py | 374 ++++++ ggml/src/ggml-cuda.cu | 3 +- ggml/src/ggml.c | 4 +- gguf-py/gguf/constants.py | 10 + gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/quants.py | 2 +- include/llama.h | 37 +- requirements.txt | 1 + .../requirements-convert_lora_to_gguf.txt | 2 + src/llama.cpp | 1010 +++++++++-------- 12 files changed, 963 insertions(+), 530 deletions(-) create mode 100755 convert_lora_to_gguf.py create mode 100644 requirements/requirements-convert_lora_to_gguf.txt diff --git a/common/common.cpp b/common/common.cpp index 9035c3592..dbb724fbb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -685,7 +685,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa if (arg == "--lora") { CHECK_ARG params.lora_adapter.emplace_back(argv[i], 1.0f); - params.use_mmap = false; return true; } if (arg == "--lora-scaled") { @@ -693,7 +692,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa const char* lora_adapter = argv[i]; CHECK_ARG params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); - params.use_mmap = false; return true; } if (arg == "--lora-base") { @@ -2089,19 +2087,14 @@ std::tuple llama_init_from_gpt_par for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? 
NULL
-                                                      : params.lora_base.c_str(),
-                                                   params.n_threads);
-        if (err != 0) {
+        auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
+        if (adapter == nullptr) {
             fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
             llama_free(lctx);
             llama_free_model(model);
             return std::make_tuple(nullptr, nullptr);
         }
+        llama_lora_adapter_set(lctx, adapter, lora_scale);
     }

     if (params.ignore_eos) {
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 42dace219..a755b0a60 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2264,13 +2264,6 @@ class InternLM2Model(Model):

         special_vocab.add_to_gguf(self.gguf_writer)

-    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
-        if n_head_kv is not None and n_head != n_head_kv:
-            n_head = n_head_kv
-        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
-                .swapaxes(1, 2)
-                .reshape(weights.shape))
-
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -2290,26 +2283,22 @@ class InternLM2Model(Model):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         num_heads = self.hparams["num_attention_heads"]
         num_kv_heads = self.hparams["num_key_value_heads"]
-        hidden_size = self.hparams["hidden_size"]
+        n_embd = self.hparams["hidden_size"]
         q_per_kv = num_heads // num_kv_heads
-        head_dim = hidden_size // num_heads
+        head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv

-        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
-
-        if re.match(qkv_pattern, name):
-            bid = re.findall(qkv_pattern, name)[0]
+        if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch
-            # qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
-            qkv = qkv.T.reshape((-1, num_groups, q_per_kv + 2, head_dim))
-            q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+
+            qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd))
+            q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1]
+
+            # The model weights of q and k require additional reshape.
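            # (editor's illustration, not part of this patch) with, say,
            # num_heads = 32 and num_kv_heads = 8: q_per_kv = 4 and num_groups = 8,
            # so each of the 8 groups packs 4 query heads plus one key head and
            # one value head, which is why qkv is viewed as
            # (num_groups, q_per_kv + 2, head_dim, n_embd) above; the permute
            # below then reorders the q/k rows into the head layout that
            # llama.cpp's RoPE expects, while v only needs flattening.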
- # q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads) - q = self._hf_permute_qk(q.reshape((q.shape[0], -1)).T, num_heads, num_heads) - # k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads) - k = self._hf_permute_qk(k.reshape((k.shape[0], -1)).T, num_heads, num_kv_heads) - # v = rearrange(v, " o g n i -> o (g n i)").T - v = v.reshape((v.shape[0], -1)).T + q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) + k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) + v = v.reshape((-1, v.shape[-1])) + return [ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), q), (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k), @@ -3585,6 +3574,7 @@ def main() -> None: small_first_shard=args.no_tensor_first_split) logger.info("Set model parameters") + model_instance.gguf_writer.add_type(gguf.GGUFType.MODEL) model_instance.set_gguf_parameters() logger.info("Set model tokenizer") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py new file mode 100755 index 000000000..4bb939d45 --- /dev/null +++ b/convert_lora_to_gguf.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from dataclasses import dataclass +import logging +import argparse +import os +import sys +import json +from math import prod +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast + +import torch + +if TYPE_CHECKING: + from torch import Tensor + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf + +# reuse model definitions from convert_hf_to_gguf.py +from convert_hf_to_gguf import LazyTorchTensor, Model + +logger = logging.getLogger("lora-to-gguf") + + +@dataclass +class PartialLoraTensor: + A: Tensor | None = None + B: Tensor | None = None + + +# magic to support tensor shape modifications and splitting +class LoraTorchTensor: + _lora_A: Tensor # (n_rank, row_size) + _lora_B: Tensor # (col_size, n_rank) + _rank: int + + def __init__(self, A: Tensor, B: Tensor): + assert len(A.shape) == len(B.shape) + assert A.shape[-2] == B.shape[-1] + if A.dtype != B.dtype: + A = A.to(torch.float32) + B = B.to(torch.float32) + self._lora_A = A + self._lora_B = B + self._rank = B.shape[-1] + + def get_lora_A_B(self) -> tuple[Tensor, Tensor]: + return (self._lora_A, self._lora_B) + + def __getitem__( + self, + indices: ( + SupportsIndex + | slice + | tuple[SupportsIndex | slice | Tensor, ...] 
# TODO: add ellipsis in the type signature + ), + ) -> LoraTorchTensor: + shape = self.shape + if isinstance(indices, SupportsIndex): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + raise NotImplementedError # can't return a vector + elif isinstance(indices, slice): + if len(shape) > 2: + return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices]) + else: + return LoraTorchTensor(self._lora_A, self._lora_B[indices]) + elif isinstance(indices, tuple): + assert len(indices) > 0 + if indices[-1] is Ellipsis: + return self[indices[:-1]] + # expand ellipsis + indices = tuple( + u + for v in ( + ( + (slice(None, None) for _ in range(len(indices) - 1)) + if i is Ellipsis + else (i,) + ) + for i in indices + ) + for u in v + ) + + if len(indices) < len(shape): + indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) + + # TODO: make sure this is correct + indices_A = ( + *( + ( + j.__index__() % self._lora_A.shape[i] + if isinstance(j, SupportsIndex) + else slice(None, None) + ) + for i, j in enumerate(indices[:-2]) + ), + slice(None, None), + indices[-1], + ) + indices_B = indices[:-1] + return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B]) + else: + raise NotImplementedError # unknown indice type + + @property + def dtype(self) -> torch.dtype: + assert self._lora_A.dtype == self._lora_B.dtype + return self._lora_A.dtype + + @property + def shape(self) -> tuple[int, ...]: + assert len(self._lora_A.shape) == len(self._lora_B.shape) + return (*self._lora_B.shape[:-1], self._lora_A.shape[-1]) + + def size(self, dim=None): + assert dim is None + return self.shape + + def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: + if isinstance(shape[0], tuple): + new_shape: tuple[int, ...] 
= shape[0] + else: + new_shape = cast(tuple[int, ...], shape) + orig_shape = self.shape + if len(new_shape) < 2: + raise NotImplementedError # can't become a vector + + # expand -1 in the shape + if any(dim == -1 for dim in new_shape): + n_elems = prod(orig_shape) + n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) + assert n_elems % n_new_elems == 0 + new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) + + if new_shape[-1] != orig_shape[-1]: + raise NotImplementedError # can't reshape the row size trivially + + shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1]) + shape_B = (*new_shape[:-1], self._rank) + return LoraTorchTensor( + self._lora_A.reshape(shape_A), + self._lora_B.reshape(shape_B), + ) + + def reshape_as(self, other: Tensor) -> LoraTorchTensor: + return self.reshape(*other.shape) + + def view(self, *size: int) -> LoraTorchTensor: + return self.reshape(*size) + + def permute(self, *dims: int) -> LoraTorchTensor: + shape = self.shape + dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims) + if dims[-1] == -1: + # TODO: support higher dimensional A shapes bigger than 1 + assert all(dim == 1 for dim in self._lora_A.shape[:-2]) + return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) + if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: + return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) + else: + # TODO: compose the above two + raise NotImplementedError + + def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor: + shape = self.shape + dims = [i for i in range(len(shape))] + dims[dim0], dims[dim1] = dims[dim1], dims[dim0] + return self.permute(*dims) + + def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: + return self.transpose(axis0, axis1) + + def to(self, *args, **kwargs): + return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) + + @classmethod + def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): + del types # unused + + if kwargs is None: + kwargs = {} + + if func is torch.permute: + return type(args[0]).permute(*args, **kwargs) + elif func is torch.reshape: + return type(args[0]).reshape(*args, **kwargs) + elif func is torch.stack: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + return LoraTorchTensor( + torch.stack([a._lora_A for a in args[0]], dim), + torch.stack([b._lora_B for b in args[0]], dim), + ) + elif func is torch.cat: + assert isinstance(args[0], Sequence) + dim = kwargs.get("dim", 0) + assert dim == 0 + if len(args[0][0].shape) > 2: + return LoraTorchTensor( + torch.cat([a._lora_A for a in args[0]], dim), + torch.cat([b._lora_B for b in args[0]], dim), + ) + elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]): + return LoraTorchTensor( + args[0][0]._lora_A, + torch.cat([b._lora_B for b in args[0]], dim), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + +def get_base_tensor_name(lora_tensor_name: str) -> str: + base_name = lora_tensor_name.replace("base_model.model.", "") + base_name = base_name.replace(".lora_A.weight", ".weight") + base_name = base_name.replace(".lora_B.weight", ".weight") + return base_name + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input. 
{ftype} will be replaced by the outtype.", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--no-lazy", action="store_true", + help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + parser.add_argument( + "--base", type=Path, required=True, + help="directory containing base model file", + ) + parser.add_argument( + "lora_path", type=Path, + help="directory containing LoRA adapter file", + ) + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "auto": gguf.LlamaFileType.GUESSED, + } + + ftype = ftype_map[args.outtype] + + dir_base_model: Path = args.base + dir_lora: Path = args.lora_path + lora_config = dir_lora / "adapter_config.json" + input_model = dir_lora / "adapter_model.safetensors" + + if args.outfile is not None: + fname_out = args.outfile + else: + # output in the same directory as the model by default + fname_out = dir_lora / 'ggml-lora-{ftype}.gguf' + + if os.path.exists(input_model): + # lazy import load_file only if lora is in safetensors format. 
+ from safetensors.torch import load_file + + lora_model = load_file(input_model, device="cpu") + else: + input_model = os.path.join(dir_lora, "adapter_model.bin") + lora_model = torch.load(input_model, map_location="cpu", weights_only=True) + + # load base model + logger.info(f"Loading base model: {dir_base_model.name}") + hparams = Model.load_hparams(dir_base_model) + with torch.inference_mode(): + try: + model_class = Model.from_model_architecture(hparams["architectures"][0]) + except NotImplementedError: + logger.error(f"Model {hparams['architectures'][0]} is not supported") + sys.exit(1) + + class LoraModel(model_class): + model_arch = model_class.model_arch + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + tensor_map: dict[str, PartialLoraTensor] = {} + + for name, tensor in lora_model.items(): + if self.lazy: + tensor = LazyTorchTensor.from_eager(tensor) + base_name = get_base_tensor_name(name) + is_lora_a = ".lora_A.weight" in name + is_lora_b = ".lora_B.weight" in name + if not is_lora_a and not is_lora_b: + if ".base_layer.weight" in name: + continue + logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + sys.exit(1) + + if base_name in tensor_map: + if is_lora_a: + tensor_map[base_name].A = tensor + else: + tensor_map[base_name].B = tensor + else: + if is_lora_a: + tensor_map[base_name] = PartialLoraTensor(A=tensor) + else: + tensor_map[base_name] = PartialLoraTensor(B=tensor) + + for name, tensor in tensor_map.items(): + assert tensor.A is not None + assert tensor.B is not None + yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + dest = super().modify_tensors(data_torch, name, bid) + for dest_name, dest_data in dest: + assert isinstance(dest_data, LoraTorchTensor) + lora_a, lora_b = dest_data.get_lora_A_B() + + yield (dest_name + ".lora_a", lora_a) + yield (dest_name + ".lora_b", lora_b) + + model_instance = LoraModel( + dir_base_model, + ftype, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=False, + eager=args.no_lazy, + model_name=None, + ) + + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + + alpha = lparams["lora_alpha"] + + model_instance.gguf_writer.add_string(gguf.Keys.General.TYPE, gguf.GGUFType.ADAPTER) + model_instance.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") + model_instance.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, float(alpha)) + model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + logger.info("Exporting model...") + model_instance.write() + logger.info(f"Model successfully exported to {model_instance.fname_out}") diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ed784ea1c..39e345b66 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -1876,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2 + && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9a5414787..60b3c5e7a 100644 --- 
a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -19478,7 +19478,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); + fprintf(fp, " rankdir = TB;\n"); for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; @@ -19540,7 +19540,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph } fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (ggml_nelements(node) < 5) { + if (ggml_nelements(node) < 5 && node->data != NULL) { fprintf(fp, " | ("); for (int j = 0; j < ggml_nelements(node); j++) { if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a95a44237..5eb3df706 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -19,6 +19,7 @@ GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h class Keys: class General: + TYPE = "general.type" ARCHITECTURE = "general.architecture" QUANTIZATION_VERSION = "general.quantization_version" ALIGNMENT = "general.alignment" @@ -120,11 +121,20 @@ class Keys: MIDDLE_ID = "tokenizer.ggml.middle_token_id" EOT_ID = "tokenizer.ggml.eot_token_id" + class Adapter: + TYPE = "adapter.type" + LORA_ALPHA = "adapter.lora.alpha" + # # recommended mapping of model tensor names for storage in gguf # +class GGUFType: + MODEL = "model" + ADAPTER = "adapter" + + class MODEL_ARCH(IntEnum): LLAMA = auto() FALCON = auto() diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index cf9554162..b0197961d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -424,6 +424,9 @@ class GGUFWriter: fout.close() self.fout = None + def add_type(self, type_name: str) -> None: + self.add_string(Keys.General.TYPE, type_name) + def add_architecture(self) -> None: self.add_string(Keys.General.ARCHITECTURE, self.arch) diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index b22eec166..16e0a9aaa 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -43,7 +43,7 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np. osize *= dim out = np.empty(shape=osize, dtype=otype) # compute over groups of 16 rows (arbitrary, but seems good for performance) - n_groups = rows.shape[0] // 16 + n_groups = (rows.shape[0] // 16) or 1 np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) return out.reshape(oshape) diff --git a/include/llama.h b/include/llama.h index 3970c3aeb..c57d21f0c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -411,6 +411,9 @@ extern "C" { const char * content; } llama_chat_message; + // lora adapter + struct llama_lora_adapter; + // Helpers for getting default parameters LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); @@ -510,18 +513,28 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); - // Apply a LoRA adapter to a loaded model - // path_base_model is the path to a higher quality model to use as a base for - // the layers modified by the adapter. Can be NULL to use the current loaded model. 
-    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    // will be applied on top of the previous one
-    // Returns 0 on success
-    LLAMA_API int32_t llama_model_apply_lora_from_file(
-            const struct llama_model * model,
-                          const char * path_lora,
-                                 float scale,
-                          const char * path_base_model,
-                               int32_t n_threads);
+    // Load a LoRA adapter from file
+    // The loaded adapter will be associated with the given model, and will be freed when the model is deleted
+    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+            struct llama_model * model,
+            const char * path_lora);
+
+    // Add a loaded LoRA adapter to the given context
+    // This will not modify the model's weights
+    LLAMA_API int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale);
+
+    // Remove a LoRA adapter from the given context
+    // Returns -1 if the adapter is not present in the context
+    LLAMA_API int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter);
+
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be freed when the associated model is deleted
+    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
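// Editor's sketch (illustrative, not part of this patch): typical usage of the
// new adapter API, replacing the old llama_model_apply_lora_from_file() flow:
//
//     struct llama_lora_adapter * adapter = llama_lora_adapter_init(model, "adapter.gguf");
//     if (adapter == NULL) { /* loading failed */ }
//     llama_lora_adapter_set(ctx, adapter, 1.0f);  // enable in this context at scale 1.0
//     llama_lora_adapter_remove(ctx, adapter);     // later: disable in this context
//     // no explicit llama_lora_adapter_free() needed: the adapter is freed with the model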
    // Apply a loaded control vector to a llama_context, or if data is NULL, clear
    // the currently loaded vector.
diff --git a/requirements.txt b/requirements.txt
index 52456c2e6..9e190ae27 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@
 -r ./requirements/requirements-convert_hf_to_gguf.txt
 -r ./requirements/requirements-convert_hf_to_gguf_update.txt
 -r ./requirements/requirements-convert_llama_ggml_to_gguf.txt
+-r ./requirements/requirements-convert_lora_to_gguf.txt
diff --git a/requirements/requirements-convert_lora_to_gguf.txt b/requirements/requirements-convert_lora_to_gguf.txt
new file mode 100644
index 000000000..5758076c4
--- /dev/null
+++ b/requirements/requirements-convert_lora_to_gguf.txt
@@ -0,0 +1,2 @@
+-r ./requirements-convert_hf_to_gguf.txt
+--extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/src/llama.cpp b/src/llama.cpp
index ddf0262d4..07bb42713 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -287,6 +287,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 };

 enum llm_kv {
+    LLM_KV_GENERAL_TYPE,
     LLM_KV_GENERAL_ARCHITECTURE,
     LLM_KV_GENERAL_QUANTIZATION_VERSION,
     LLM_KV_GENERAL_ALIGNMENT,
@@ -377,9 +378,13 @@ enum llm_kv {
     LLM_KV_TOKENIZER_SUFFIX_ID,
     LLM_KV_TOKENIZER_MIDDLE_ID,
     LLM_KV_TOKENIZER_EOT_ID,
+
+    LLM_KV_ADAPTER_TYPE,
+    LLM_KV_ADAPTER_LORA_ALPHA,
 };

 static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
+    { LLM_KV_GENERAL_TYPE,                  "general.type"                  },
     { LLM_KV_GENERAL_ARCHITECTURE,          "general.architecture"          },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION,  "general.quantization_version"  },
     { LLM_KV_GENERAL_ALIGNMENT,             "general.alignment"             },
@@ -470,6 +475,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_SUFFIX_ID,           "tokenizer.ggml.suffix_token_id" },
     { LLM_KV_TOKENIZER_MIDDLE_ID,           "tokenizer.ggml.middle_token_id" },
     { LLM_KV_TOKENIZER_EOT_ID,              "tokenizer.ggml.eot_token_id"    },
+
+    { LLM_KV_ADAPTER_TYPE,                  "adapter.type"       },
+    { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
 };

 struct LLM_KV {
@@ -2703,6 +2711,9 @@ struct llama_model {
     int64_t t_load_us = 0;
     int64_t t_start_us = 0;

+    // keep track of loaded lora adapters
+    std::set<struct llama_lora_adapter *> lora_adapters;
+
     ~llama_model() {
         for (struct ggml_context * ctx : ctxs) {
             ggml_free(ctx);
@@ -2715,6 +2726,9 @@ struct llama_model {
 #endif
             ggml_backend_buffer_free(buf);
         }
+        while (!lora_adapters.empty()) {
+            llama_lora_adapter_free(*lora_adapters.begin());
+        }
     }
 };
@@ -2819,6 +2833,52 @@ struct llama_context {

     // control vectors
     struct llama_control_vector cvec;
+
+    // lora adapters and scales
+    std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
+};
+
+struct llama_lora_weight {
+    struct ggml_tensor * a = nullptr;
+    struct ggml_tensor * b = nullptr;
+    llama_lora_weight() = default;
+    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
+};
+
+struct llama_lora_adapter {
+    struct llama_model * base_model;
+    // map tensor name to lora_a_b
+    std::unordered_map<std::string, struct llama_lora_weight> ab_map;
+    std::vector<struct ggml_context *> ctxs;
+    std::vector<ggml_backend_buffer_t> bufs;
+
+    float alpha;
+
+    llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
+        base_model->lora_adapters.insert(this);
+    }
+
+    llama_lora_weight * get_weight(struct ggml_tensor * w) {
+        std::string name(w->name);
+        auto pos = ab_map.find(name);
+        if (pos != ab_map.end()) {
+            return &pos->second;
+        }
+        return nullptr;
+    }
+
+    ~llama_lora_adapter() {
+        for (struct ggml_context * ctx : ctxs) {
+            ggml_free(ctx);
+        }
+        for (ggml_backend_buffer_t buf : bufs) {
+            ggml_backend_buffer_free(buf);
+        }
+        auto pos = base_model->lora_adapters.find(this);
+        if (pos != base_model->lora_adapters.end()) {
+            base_model->lora_adapters.erase(pos);
+        }
+    }
 };

 static size_t llama_get_device_count(const llama_model & model) {
@@ -7809,6 +7869,58 @@ static void llm_build_kv_store(
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }

+// do mat_mul, while optionally applying lora
+// (for each active adapter this computes res = w*cur + scale*(b*(a*cur)),
+// which is equivalent to multiplying by w' = w + scale*b*a without ever
+// materializing w'; scale folds in the adapter's alpha/rank when alpha is set)
+static struct ggml_tensor * llm_build_lora_mm(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,
+          struct ggml_tensor * cur) {
+    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct ggml_tensor * ab_cur = ggml_mul_mat(
+            ctx0, lora->b,
+            ggml_mul_mat(ctx0, lora->a, cur)
+        );
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
+// do mat_mul_id, while optionally applying lora
+static struct ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,   // struct ggml_tensor * as
+          struct ggml_tensor * cur, // struct ggml_tensor * b
+          struct ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ?
it.second * alpha / rank : it.second; + struct ggml_tensor * ab_cur = ggml_mul_mat_id( + ctx0, lora->b, + ggml_mul_mat_id(ctx0, lora->a, cur, ids), + ids + ); + ab_cur = ggml_scale(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + return res; +} + static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -7843,6 +7955,7 @@ static struct ggml_tensor * llm_build_norm( static struct ggml_tensor * llm_build_ffn( struct ggml_context * ctx, + struct llama_context & lctx, struct ggml_tensor * cur, struct ggml_tensor * up, struct ggml_tensor * up_b, @@ -7858,7 +7971,7 @@ static struct ggml_tensor * llm_build_ffn( llm_ffn_gate_type type_gate, const llm_build_cb & cb, int il) { - struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur; + struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur; cb(tmp, "ffn_up", il); if (up_b) { @@ -7875,12 +7988,12 @@ static struct ggml_tensor * llm_build_ffn( switch (type_gate) { case LLM_FFN_SEQ: { - cur = ggml_mul_mat(ctx, gate, tmp); + cur = llm_build_lora_mm(lctx, ctx, gate, tmp); cb(cur, "ffn_gate", il); } break; case LLM_FFN_PAR: { - cur = ggml_mul_mat(ctx, gate, cur); + cur = llm_build_lora_mm(lctx, ctx, gate, cur); cb(cur, "ffn_gate", il); } break; } @@ -7948,7 +8061,7 @@ static struct ggml_tensor * llm_build_ffn( } if (down) { - cur = ggml_mul_mat(ctx, down, cur); + cur = llm_build_lora_mm(lctx, ctx, down, cur); } if (down_b) { @@ -7969,6 +8082,7 @@ static struct ggml_tensor * llm_build_ffn( static struct ggml_tensor * llm_build_moe_ffn( struct ggml_context * ctx, + struct llama_context & lctx, struct ggml_tensor * cur, struct ggml_tensor * gate_inp, struct ggml_tensor * up_exps, @@ -7985,7 +8099,7 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_embd = cur->ne[0]; int64_t n_tokens = cur->ne[1]; - ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens] + ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] @@ -8017,10 +8131,10 @@ static struct ggml_tensor * llm_build_moe_ffn( } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); - ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(gate, "ffn_moe_gate", il); switch (type_op) { @@ -8041,7 +8155,7 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] cb(par, "ffn_moe_gate_par", il); - ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); experts = ggml_mul(ctx, experts, weights); @@ -8069,9 +8183,7 @@ static struct ggml_tensor * llm_build_moe_ffn( static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, - const llama_model & model, - const llama_hparams & hparams, - const llama_cparams & cparams, 
+ struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -8083,6 +8195,10 @@ static struct ggml_tensor * llm_build_kqv( float kq_scale, const llm_build_cb & cb, int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = lctx.model.hparams; + const llama_cparams & cparams = lctx.cparams; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head(il); const int64_t n_head_kv = hparams.n_head_kv(il); @@ -8181,7 +8297,7 @@ static struct ggml_tensor * llm_build_kqv( ggml_build_forward_expand(graph, cur); if (wo) { - cur = ggml_mul_mat(ctx, wo, cur); + cur = llm_build_lora_mm(lctx, ctx, wo, cur); } if (wo_b) { @@ -8197,9 +8313,7 @@ static struct ggml_tensor * llm_build_kqv( static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, - const llama_model & model, - const llama_hparams & hparams, - const llama_cparams & cparams, + struct llama_context & lctx, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -8214,6 +8328,8 @@ static struct ggml_tensor * llm_build_kv( float kq_scale, const llm_build_cb & cb, int il) { + const llama_hparams & hparams = lctx.model.hparams; + const llama_cparams & cparams = lctx.cparams; // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced @@ -8225,7 +8341,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + cur = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b, q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); @@ -8687,21 +8803,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -8722,7 +8838,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8745,7 +8861,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -8759,7 +8875,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, 
model.layers[il].ffn_gate_exps, @@ -8789,7 +8905,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -8825,13 +8941,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); switch (model.type) { @@ -8857,7 +8973,7 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8879,7 +8995,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -8904,7 +9020,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -8940,13 +9056,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -8962,7 +9078,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -8984,7 +9100,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9007,7 +9123,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9056,7 +9172,7 @@ struct llm_build_context { cur = attn_norm; } - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, 
"wqkv", il); struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); @@ -9083,7 +9199,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9100,7 +9216,7 @@ struct llm_build_context { // feed forward { - cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result + cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9127,7 +9243,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9172,21 +9288,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -9207,7 +9323,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -9239,7 +9355,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -9278,7 +9394,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); // Grok // multiply logits by output_multiplier_scale of 0.5773502691896257 @@ -9329,7 +9445,7 @@ struct llm_build_context { struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -9357,7 +9473,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9380,7 +9496,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "attn_out_norm", il); 
- cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -9409,7 +9525,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); @@ -9451,7 +9567,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -9467,7 +9583,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9491,7 +9607,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -9514,7 +9630,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9546,13 +9662,13 @@ struct llm_build_context { // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); @@ -9561,7 +9677,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9583,7 +9699,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9608,7 +9724,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9660,7 +9776,7 @@ struct llm_build_context { // self-attention if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq); cb(Qcur, 
"Qcur", il); if (model.layers[il].attn_q_norm) { @@ -9670,7 +9786,7 @@ struct llm_build_context { LLM_NORM, cb, il); } - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk); cb(Kcur, "Kcur", il); if (model.layers[il].attn_k_norm) { @@ -9679,14 +9795,14 @@ struct llm_build_context { model.layers[il].attn_k_norm_b, LLM_NORM, cb, il); } - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { // compute Q and K and RoPE them - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); @@ -9735,7 +9851,7 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); if (model.layers[il].bo) { cb(cur, "kqv_wo", il); } @@ -9768,21 +9884,21 @@ struct llm_build_context { // feed-forward network if (model.arch == LLM_ARCH_BERT) { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -9840,7 +9956,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -9856,7 +9972,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -9880,7 +9996,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -9903,7 +10019,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -9950,7 +10066,7 @@ struct llm_build_context { { cur = attn_norm; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, 
cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); if (model.layers[il].bqkv){ @@ -9988,13 +10104,13 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10018,7 +10134,7 @@ struct llm_build_context { model.layers[il].ffn_norm_b, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -10043,7 +10159,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10083,21 +10199,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -10139,7 +10255,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10167,7 +10283,7 @@ struct llm_build_context { // parallel residual cur = inpSA; } - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10193,7 +10309,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10228,7 +10344,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10258,7 +10374,7 
@@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10280,7 +10396,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10305,7 +10421,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10343,17 +10459,17 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -10372,7 +10488,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10393,7 +10509,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10417,7 +10533,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10458,17 +10574,17 @@ struct llm_build_context { // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); @@ -10487,7 +10603,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = 
llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10510,7 +10626,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); ggml_tensor * moe_out = - llm_build_moe_ffn(ctx0, cur, + llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -10523,14 +10639,14 @@ struct llm_build_context { // FFN shared expert { - ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur); + ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur); cb(cur_gate_inp, "ffn_shexp_gate_inp", il); // sigmoid ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); cb(cur_gate, "ffn_shexp_gate", il); - ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur, + ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -10563,7 +10679,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10605,7 +10721,7 @@ struct llm_build_context { struct ggml_tensor * Vcur = nullptr; if (model.layers[il].wqkv) { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10615,9 +10731,9 @@ struct llm_build_context { Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } else { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); } cb(Qcur, "Qcur", il); @@ -10644,7 +10760,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -10659,7 +10775,7 @@ struct llm_build_context { // FF { - ffn_output = llm_build_ffn(ctx0, attn_norm_output, + ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -10683,7 +10799,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, 
ctx0, model.output, cur); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); @@ -10729,7 +10845,7 @@ struct llm_build_context { struct ggml_tensor * Vcur = nullptr; if (model.layers[il].wqkv) { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); @@ -10737,9 +10853,9 @@ struct llm_build_context { Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); } else { - Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq); + Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk); + Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv); } cb(Qcur, "Qcur", il); @@ -10764,7 +10880,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -10788,7 +10904,7 @@ struct llm_build_context { // special-case: the up and gate tensors are merged into a single tensor // TODO: support this in llm_build_ffn { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -10811,7 +10927,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10851,13 +10967,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -10872,7 +10988,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10890,7 +11006,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL, @@ -10916,7 +11032,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10958,7 +11074,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -10974,7 +11090,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -10998,7 +11114,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11021,7 +11137,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11057,7 +11173,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -11085,7 +11201,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11109,7 +11225,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11132,7 +11248,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11170,21 +11286,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); // if (model.layers[il].bq) { // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); // cb(Qcur, "Qcur", il); // } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); // if (model.layers[il].bk) { // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); // cb(Kcur, "Kcur", il); // } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); 
// if (model.layers[il].bv) { // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11205,7 +11321,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11226,7 +11342,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11250,7 +11366,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11288,21 +11404,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11323,7 +11439,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11344,7 +11460,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11368,7 +11484,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11419,21 +11535,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, 
ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11454,7 +11570,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11481,7 +11597,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11515,7 +11631,7 @@ struct llm_build_context { cb(cur, "lmhead_scaling", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11552,13 +11668,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -11576,7 +11692,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } @@ -11598,7 +11714,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11623,7 +11739,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11665,13 +11781,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -11694,7 +11810,7 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il); } 
@@ -11721,7 +11837,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -11751,7 +11867,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); @@ -11796,21 +11912,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -11831,7 +11947,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -11853,7 +11969,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -11877,7 +11993,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -11929,7 +12045,7 @@ struct llm_build_context { cb(cur, "attn_norm", il); // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens} - struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur); + struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur); // split the above in two // => {d_inner, n_tokens} struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0); @@ -11968,14 +12084,14 @@ struct llm_build_context { // ssm { // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens} - struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x); + struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0); struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank); struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state)); // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, 
n_tokens} - dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt); + dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); // Custom operator to optimize the parallel associative scan @@ -12006,7 +12122,7 @@ struct llm_build_context { y = ggml_mul(ctx0, y, ggml_silu(ctx0, z)); // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens} - cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y); } // residual @@ -12025,7 +12141,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12064,21 +12180,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -12124,7 +12240,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12141,7 +12257,7 @@ struct llm_build_context { // feed-forward network { - cur = llm_build_ffn(ctx0, ffn_inp, + cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12168,7 +12284,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); @@ -12221,21 +12337,21 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (hparams.f_clamp_kqv > 0.0f) { Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(Qcur, "Qcur", il); } - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); if (hparams.f_clamp_kqv > 0.0f) { Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(Kcur, "Kcur", il); } - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); if (hparams.f_clamp_kqv > 
0.0f) { Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); @@ -12256,7 +12372,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, nullptr, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12278,7 +12394,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12304,7 +12420,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12344,7 +12460,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); @@ -12383,7 +12499,7 @@ struct llm_build_context { Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12405,7 +12521,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12429,7 +12545,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12464,7 +12580,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -12492,7 +12608,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12517,7 +12633,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -12548,7 +12664,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -12571,7 +12687,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output,
cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12612,13 +12728,13 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_rope_ext( @@ -12635,7 +12751,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } @@ -12657,7 +12773,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12674,7 +12790,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm_exps", il); - cur = llm_build_moe_ffn(ctx0, cur, + cur = llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -12703,7 +12819,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -12857,7 +12973,7 @@ struct llm_build_context { struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); } @@ -12879,7 +12995,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); if ((uint32_t) il < hparams.n_layer_dense_lead) { - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -12889,7 +13005,7 @@ struct llm_build_context { } else { // MoE branch ggml_tensor * moe_out = - llm_build_moe_ffn(ctx0, cur, + llm_build_moe_ffn(ctx0, lctx, cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, model.layers[il].ffn_gate_exps, @@ -12902,7 +13018,7 @@ struct llm_build_context { // FFN shared expert { - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -12967,7 +13083,7 @@ struct llm_build_context { // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { @@ 
-12976,7 +13092,7 @@ struct llm_build_context { } // B1.K - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { @@ -12985,7 +13101,7 @@ struct llm_build_context { } // B1.V - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { @@ -13007,7 +13123,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, NULL, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); @@ -13016,7 +13132,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_sub_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); if (model.layers[il].bo) { cur = ggml_add(ctx0, cur, model.layers[il].bo); @@ -13040,7 +13156,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, NULL, NULL, NULL, @@ -13053,7 +13169,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_sub_norm", il); - cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); cb(cur, "ffn_down", il); @@ -13072,7 +13188,7 @@ struct llm_build_context { cb(cur, "result_norm", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -13174,7 +13290,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up_enc, NULL, NULL, model.layers[il].ffn_gate_enc, NULL, NULL, model.layers[il].ffn_down_enc, NULL, NULL, @@ -13354,7 +13470,7 @@ struct llm_build_context { cb(cur, "ffn_norm", il); // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -13420,7 +13536,7 @@ struct llm_build_context { // self-attention { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -13436,7 +13552,7 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il); } @@ -13460,7 +13576,7 @@ struct llm_build_context { LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = 
llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -13479,7 +13595,7 @@ struct llm_build_context { LLM_NORM, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); @@ -13521,7 +13637,7 @@ struct llm_build_context { struct ggml_tensor * Kcur = nullptr; struct ggml_tensor * Vcur = nullptr; - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); cb(cur, "wqkv", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); @@ -13549,7 +13665,7 @@ struct llm_build_context { ); cb(Kcur, "Kcur_rope", il); - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); @@ -13574,7 +13690,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, + cur = llm_build_ffn(ctx0, lctx, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -13594,7 +13710,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, -1); cb(cur, "result_norm", -1); - cur = ggml_mul_mat(ctx0, model.output, cur); + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -18458,284 +18574,212 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -static int llama_apply_lora_from_file_internal( - const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads -) { - LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); +static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) { + LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); - const int64_t t_start_lora_us = ggml_time_us(); - - llama_file fin(path_lora, "rb"); - - // verify magic and version - { - uint32_t magic = fin.read_u32(); - if (magic != LLAMA_FILE_MAGIC_GGLA) { - LLAMA_LOG_ERROR("%s: bad file magic\n", __func__); - return 1; - } - - uint32_t format_version = fin.read_u32(); - if (format_version != 1) { - LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ ); - return 1; - } - } - - int32_t lora_r = fin.read_u32(); - int32_t lora_alpha = fin.read_u32(); - float scaling = scale * (float)lora_alpha / (float)lora_r; - - LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); - - // load base model - std::unique_ptr<llama_model_loader> ml; - if (path_base_model) { - LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr)); - ml->init_mappings(/*prefetch*/ false); // no prefetching - } - - struct tensor_meta { - std::string name; - ggml_type type; - int32_t ne[2]; - size_t offset; + ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &ctx, }; - std::map<std::string, tensor_meta> tensor_meta_map; - - // load all tensor meta - while (true) {
- if (fin.tell() == fin.size) { - // eof - break; - } - - int32_t n_dims; - int32_t name_len; - int32_t ftype; - - fin.read_raw(&n_dims, sizeof(n_dims)); - fin.read_raw(&name_len, sizeof(name_len)); - fin.read_raw(&ftype, sizeof(ftype)); - - if (n_dims != 1 && n_dims != 2) { - LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims); - return 1; - } - - int32_t ne[2] = { 1, 1 }; - for (int i = 0; i < n_dims; ++i) { - fin.read_raw(&ne[i], sizeof(ne[i])); - } - - std::string name; - { - GGML_ASSERT(name_len < GGML_MAX_NAME); - char buf[GGML_MAX_NAME]; - fin.read_raw(buf, name_len); - name = std::string(buf, name_len); - } - - // check for lora suffix - std::string lora_suffix; - if (name.length() > 6) { - lora_suffix = name.substr(name.length() - 6); - } - if (lora_suffix != ".loraA" && lora_suffix != ".loraB") { - LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str()); - return 1; - } - - // tensor type - ggml_type wtype; - switch (ftype) { - case 0: wtype = GGML_TYPE_F32; break; - case 1: wtype = GGML_TYPE_F16; break; - default: - { - LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n", - __func__, ftype); - return 1; - } - } - - // data offset - size_t offset = fin.tell(); - offset = (offset + 31) & -32; - - // skip tensor data - fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET); - - tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset }); + struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params); + if (!ctx_gguf) { + throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora)); } - bool warned = false; - int n_tensors = 0; - - // apply - ggml_backend_t backend_cpu = ggml_backend_cpu_init(); - if (backend_cpu == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__); - return 1; - } - ggml_backend_cpu_set_n_threads(backend_cpu, n_threads); - - std::vector<no_init<uint8_t>> read_buf; - for (const auto & it : model.tensors_by_name) { - const std::string & base_name = it.first; - ggml_tensor * model_t = it.second; - - if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() || - tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) { - continue; - } - - tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA"); - tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB"); - - ggml_init_params lora_init_params = { - /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), - /* .mem_buffer */ nullptr, - /* .no_alloc */ true, + // check metadata + { + auto get_kv_str = [&](const std::string & key) -> std::string { + int id = gguf_find_key(ctx_gguf, key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id)); }; - ggml_context * lora_ctx = ggml_init(lora_init_params); - if (lora_ctx == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__); - ggml_backend_free(backend_cpu); - return 1; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf, key.c_str()); + return id < 0 ?
0.0f : gguf_get_val_f32(ctx_gguf, id); + }; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE)); + if (general_type != "adapter") { + gguf_free(ctx_gguf); + throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type); } - // create tensors - ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]); - ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]); - ggml_set_name(loraA, metaA.name.c_str()); - ggml_set_name(loraB, metaB.name.c_str()); + auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); + auto general_arch = llm_arch_from_string(general_arch_str); + if (general_arch != model->arch) { + gguf_free(ctx_gguf); + throw std::runtime_error("model arch and LoRA arch mismatch"); + } - ggml_tensor * base_t; - if (ml) { - if (!ml->get_tensor_meta(base_name.c_str())) { - LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); - return 1; + auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE)); + if (adapter_type != "lora") { + gguf_free(ctx_gguf); + throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type); + } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + } + + int n_tensors = gguf_get_n_tensors(ctx_gguf); + + // contexts for each buffer type + std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; + auto get_ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // add a new context + struct ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * buft_ctx = ggml_init(params); + ctx_map[buft] = buft_ctx; + return buft_ctx; + }; + return it->second; + }; + + // bundle lora_a and lora_b into pairs + std::map<std::string, llama_lora_weight> ab_map; + auto str_endswith = [](const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; + }; + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string name(cur->name); + if (str_endswith(name, ".lora_a")) { + replace_all(name, ".lora_a", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(cur, nullptr); + } else { + ab_map[name].a = cur; + } + } else if (str_endswith(name, ".lora_b")) { + replace_all(name, ".lora_b", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_lora_weight(nullptr, cur); + } else { + ab_map[name].b = cur; } } else { - base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str())); } else { - base_t = ggml_dup_tensor(lora_ctx, model_t); - } - ggml_set_name(base_t, base_name.c_str()); - - // allocate in backend buffer - ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type()); - if (lora_buf == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__); - return 1; - } - - // load tensor data - auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) { - read_buf.resize(ggml_nbytes(tensor)); - fin.seek(tensor_meta.offset, SEEK_SET); - fin.read_raw(read_buf.data(), ggml_nbytes(tensor)); - ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size()); - }; - load_tensor(metaA, loraA); - load_tensor(metaB, loraB); - - // load
base model tensor data - if (ml) { - ml->load_data_for(base_t); - } else { - ggml_backend_tensor_copy(model_t, base_t); - } - - if (ggml_is_quantized(base_t->type) && !warned) { - LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, " - "use a f16 or f32 base model with --lora-base\n", __func__); - warned = true; - } - - if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { - LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" - " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]); - ggml_free(lora_ctx); - ggml_backend_buffer_free(lora_buf); - ggml_backend_free(backend_cpu); - return 1; - } - - auto build_lora_graph = [&]() { - // w = w + BA*s - ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); - ggml_set_name(BA, "BA"); - - if (scaling != 1.0f) { - BA = ggml_scale(lora_ctx, BA, scaling); - ggml_set_name(BA, "BA_scaled"); - } - - ggml_tensor * r; - r = ggml_add_inplace(lora_ctx, base_t, BA); - ggml_set_name(r, "r_add"); - - if (base_t->type != model_t->type) { - // convert the result to the model type - r = ggml_cast(lora_ctx, r, model_t->type); - ggml_set_name(r, "r_cast"); - } - - return r; - }; - - ggml_cgraph * gf = ggml_new_graph(lora_ctx); - ggml_tensor * r = build_lora_graph(); - ggml_build_forward_expand(gf, r); - - ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type()); - if (graph_buf == nullptr) { - LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__); - ggml_free(lora_ctx); - ggml_backend_buffer_free(lora_buf); - ggml_backend_free(backend_cpu); - return 1; - } - - ggml_backend_graph_compute(backend_cpu, gf); - - ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r)); - -#if 0 - // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU - //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE); - - // sched compute - ggml_build_forward_expand(gf, build_graph()); - ggml_backend_sched_init_measure(sched, gf); - - // create the graph again, since the previous one was destroyed by the measure - ggml_graph_clear(gf); - ggml_build_forward_expand(gf, build_graph()); - ggml_backend_sched_graph_compute(sched, gf); - ggml_backend_sched_free(sched); -#endif - - ggml_backend_buffer_free(lora_buf); - ggml_backend_buffer_free(graph_buf); - ggml_free(lora_ctx); - - n_tensors++; - if (n_tensors % 4 == 0) { - LLAMA_LOG_INFO("."); + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); } } - ggml_backend_free(backend_cpu); + // add tensors + for (auto & it : ab_map) { + const std::string & name = it.first; + llama_lora_weight & w = it.second; - const int64_t t_lora_us = ggml_time_us() - t_start_lora_us; - LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0); + if (!w.a || !w.b) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); + } + // device buft and device ctx + auto * model_tensor = llama_get_model_tensor(model, name.c_str()); + if (!model_tensor) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model"); + } + struct ggml_context * dev_ctx = get_ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + // validate tensor shape + if (model_tensor->ne[0] != 
w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("tensor '" + name + "' has incorrect shape"); + } + if (w.a->ne[1] != w.b->ne[0]) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } + // save tensor to adapter + struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_set_name(tensor_a, w.a->name); + ggml_set_name(tensor_b, w.b->name); + adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + } + + // allocate tensors / buffers and zero + { + adapter.ctxs.reserve(ctx_map.size()); + adapter.bufs.reserve(ctx_map.size()); + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx_dev = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft); + if (!buf) { + gguf_free(ctx_gguf); + ggml_free(ctx); + throw std::runtime_error("failed to allocate buffer for lora adapter\n"); + } + LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + adapter.ctxs.push_back(ctx_dev); + adapter.bufs.push_back(buf); + } + } + + // set tensor data + { + llama_file gguf_file(path_lora, "rb"); + std::vector<uint8_t> read_buf; + auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) { + size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name)); + size_t size = ggml_nbytes(orig); + read_buf.resize(size); + gguf_file.seek(offs, SEEK_SET); + gguf_file.read_raw(read_buf.data(), size); + ggml_backend_tensor_set(dev, read_buf.data(), 0, size); + }; + for (auto & it : adapter.ab_map) { + auto orig = ab_map[it.first]; + auto dev = it.second; + set_tensor(orig.a, dev.a); + set_tensor(orig.b, dev.b); + } + } + + LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2); + + // free ctx for reading gguf + gguf_free(ctx_gguf); + ggml_free(ctx); +} + +int32_t llama_lora_adapter_set( + struct llama_context * ctx, + struct llama_lora_adapter * adapter, + float scale) { + if (ctx->cparams.flash_attn) { + LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__); + return -1; + } + ctx->lora_adapters[adapter] = scale; return 0; } +int32_t llama_lora_adapter_remove( + struct llama_context * ctx, + struct llama_lora_adapter * adapter) { + auto pos = ctx->lora_adapters.find(adapter); + if (pos != ctx->lora_adapters.end()) { + ctx->lora_adapters.erase(pos); + return 0; + } + return -1; +} + +void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { + delete adapter; +} + // // interface implementation // @@ -19514,12 +19558,14 @@ uint32_t llama_model_quantize( } } -int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) { +struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { try { - return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads); + struct llama_lora_adapter * adapter = new llama_lora_adapter(model); + llama_lora_adapter_init_internal(model, path_lora, *adapter); + return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora
adapter: %s\n", __func__, err.what()); - return 1; + return nullptr; } }