mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
llama : add AWQ for llama, llama2, mpt, and mistral models (#4593)
* update: awq support llama-7b model * update: change order * update: benchmark results for llama2-7b * update: mistral 7b v1 benchmark * update: support 4 models * fix: Readme * update: ready for PR * update: readme * fix: readme * update: change order import * black * format code * update: work for bot mpt and awqmpt * update: readme * Rename to llm_build_ffn_mpt_awq * Formatted other files * Fixed params count * fix: remove code * update: more detail for mpt * fix: readme * fix: readme * update: change folder architecture * fix: common.cpp * fix: readme * fix: remove ggml_repeat * update: cicd * update: cicd * uppdate: remove use_awq arg * update: readme * llama : adapt plamo to new ffn ggml-ci --------- Co-authored-by: Trần Đức Nam <v.namtd12@vinai.io> Co-authored-by: Le Hoang Anh <v.anhlh33@vinai.io> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
879b690a9e
commit
f6793491b5
116
awq-py/README.md
Normal file
116
awq-py/README.md
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
# AWQ: Activation-aware Weight Quantization for LLM - version apply to llamacpp
|
||||||
|
[[Paper](https://arxiv.org/abs/2306.00978)][[Original Repo](https://github.com/mit-han-lab/llm-awq)][[Easy-to-use Repo](https://github.com/casper-hansen/AutoAWQ)]
|
||||||
|
|
||||||
|
**Supported models:**
|
||||||
|
|
||||||
|
- [X] LLaMA
|
||||||
|
- [x] LLaMA 2
|
||||||
|
- [X] MPT
|
||||||
|
- [X] Mistral AI v0.1
|
||||||
|
- [ ] Bloom
|
||||||
|
- [ ] Mixtral MoE
|
||||||
|
|
||||||
|
**TODO:**
|
||||||
|
- [x] Update version work with both MPT and MPT-AWQ model
|
||||||
|
- [ ] Add OPT model
|
||||||
|
- [ ] Add Bloom model
|
||||||
|
- [ ] Add Mixtral MoE
|
||||||
|
- [ ] Support w3, w2
|
||||||
|
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
|
||||||
|
- [Install](##Install)
|
||||||
|
- [Convert](##Convert)
|
||||||
|
- [Quantize](##Quantize)
|
||||||
|
- [Test](##Test)
|
||||||
|
- [Benchmark](##Benchmark)
|
||||||
|
- [Results](##Results)
|
||||||
|
|
||||||
|
## Install
|
||||||
|
Install requirements
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
Get the pre-computed AWQ search results for multiple model families, including LLaMA, LLaMA2, MPT, OPT
|
||||||
|
```bash
|
||||||
|
git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
|
||||||
|
```
|
||||||
|
|
||||||
|
## Convert
|
||||||
|
Example for llama model
|
||||||
|
```bash
|
||||||
|
# For llama7b and llama2 models
|
||||||
|
python convert.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/llama_7b_fp16.gguf
|
||||||
|
# For mistral and mpt models
|
||||||
|
python convert-hf-to-gguf.py models/mpt-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/mpt_7b_fp16.gguf
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quantize
|
||||||
|
```bash
|
||||||
|
# We only benchmark and confirm the results on q4_0, q4_1, and q2_k types.
|
||||||
|
./quantize models/llama_7b_fp16.gguf models/llama_7b_q4_0.gguf q4_0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Test
|
||||||
|
```bash
|
||||||
|
# For all models.
|
||||||
|
./build/bin/main -m models/llama_7b_q4_0.gguf -n 128 --prompt "Once upon a time"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Benchmark
|
||||||
|
The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512.
|
||||||
|
```bash
|
||||||
|
# For llama and llama2, and mistral models.
|
||||||
|
./perplexity -m models/llama_7b_q4_0.gguf -f datasets/wikitext-2-raw/wiki.test.raw
|
||||||
|
```
|
||||||
|
|
||||||
|
## Results
|
||||||
|
Results are run on OpenBLAS (CPU) and CuBLAS (GPU) for fair comparison
|
||||||
|
We use three types of llamacpp quantization methods to work with our version, including q4_0, q4_1, and q2_k
|
||||||
|
|
||||||
|
### Llama 7B (Build with OpenBLAS)
|
||||||
|
|
||||||
|
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||||
|
|-----------:|--------------|-------:|-------:|-------:|-------:|
|
||||||
|
|Llama 7B | perplexity | 5.9066 | 6.1214 | 6.0643 | 6.5808 |
|
||||||
|
|Llama 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||||
|
|Llama 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|AWQ-LLama 7B| perplexity | 5.9175 | 6.0252 | 5.9987 | 6.3692 |
|
||||||
|
|AWQ-LLama 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||||
|
|AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|
||||||
|
|
||||||
|
### Llama2 7B (Build with CuBLAS)
|
||||||
|
|
||||||
|
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||||
|
|------------:|--------------|-------:|-------:|-------:|-------:|
|
||||||
|
|Llama2 7B | perplexity | 5.8664 | 6.0260 | 6.0656 | 6.4496 |
|
||||||
|
|Llama2 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||||
|
|Llama2 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|AWQ-LLama2 7B| perplexity | 5.8801 | 6.0054 | 5.9849 | 6.3650 |
|
||||||
|
|AWQ-LLama2 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||||
|
|AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|
||||||
|
|
||||||
|
### Mistral 7B v0.1 (Build with CuBLAS)
|
||||||
|
|
||||||
|
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||||
|
|-------------:|--------------|-------:|-------:|-------:|-------:|
|
||||||
|
|Mistral 7B | perplexity | 5.6931 | 5.8202 | 5.8268 | 6.1645 |
|
||||||
|
|Mistral 7B | file size | 14.5G | 4.1G | 4.5G | 3.1G |
|
||||||
|
|Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|AWQ-Mistral 7B| perplexity | 5.6934 | 5.8020 | 5.7691 | 6.0426 |
|
||||||
|
|AWQ-Mistral 7B| file size | 14.5G | 4.1G | 4.5G | 3.1G |
|
||||||
|
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|
||||||
|
### MPT 7B (Build with OpenBLAS)
|
||||||
|
|
||||||
|
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||||
|
|---------:|--------------|-------:|-------:|-------:|--------:|
|
||||||
|
|MPT 7B | perplexity | 8.4369 | 8.7956 | 8.6265 | 11.4913 |
|
||||||
|
|MPT 7B | file size | 13.7G | 3.9G | 4.3G | 2.8G |
|
||||||
|
|MPT 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||||
|
|AWQ-MPT 7B| perplexity | 8.4944 | 8.7053 | 8.6750 | 10.2873|
|
||||||
|
|AWQ-MPT 7B| file size | 13.7G | 3.9G | 4.3G | 2.8G |
|
||||||
|
|AWQ-MPT 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
254
awq-py/awq/apply_awq.py
Normal file
254
awq-py/awq/apply_awq.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
"""
|
||||||
|
Implements the AWQ for llama.cpp use cases.
|
||||||
|
Original paper: https://arxiv.org/abs/2306.00978
|
||||||
|
|
||||||
|
This code is based on versions of the AWQ implementation found in the following repositories:
|
||||||
|
* https://github.com/mit-han-lab/llm-awq
|
||||||
|
* https://github.com/casper-hansen/AutoAWQ
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from transformers import AutoModelForCausalLM, AutoConfig
|
||||||
|
from transformers.models.bloom.modeling_bloom import BloomGelu
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
from transformers.activations import GELUActivation
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledActivation(nn.Module):
|
||||||
|
"""
|
||||||
|
ScaledActivation module wraps an existing activation function and applies a
|
||||||
|
scale factor to its output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The activation function to be scaled.
|
||||||
|
scales (torch.Tensor): A tensor of size (num_features,) containing the initial
|
||||||
|
scale factors for each feature.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The scaled output of the activation function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, module, scales):
|
||||||
|
super().__init__()
|
||||||
|
self.act = module
|
||||||
|
self.scales = nn.Parameter(scales.data)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
|
||||||
|
|
||||||
|
|
||||||
|
def set_op_by_name(layer, name, new_module):
|
||||||
|
"""
|
||||||
|
Set the new module for given module's name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer (nn.Module): The layer in which to replace the submodule.
|
||||||
|
name (str): The path to the submodule to be replaced, using dot notation
|
||||||
|
to access nested modules.
|
||||||
|
new_module (nn.Module): The new module to replace the existing one.
|
||||||
|
"""
|
||||||
|
levels = name.split(".")
|
||||||
|
if len(levels) > 1:
|
||||||
|
mod_ = layer
|
||||||
|
for l_idx in range(len(levels) - 1):
|
||||||
|
if levels[l_idx].isdigit():
|
||||||
|
mod_ = mod_[int(levels[l_idx])]
|
||||||
|
else:
|
||||||
|
mod_ = getattr(mod_, levels[l_idx])
|
||||||
|
setattr(mod_, levels[-1], new_module)
|
||||||
|
else:
|
||||||
|
setattr(layer, name, new_module)
|
||||||
|
|
||||||
|
|
||||||
|
def get_op_by_name(module, op_name):
|
||||||
|
"""
|
||||||
|
Retrieves a submodule within a given layer based on its name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The layer containing the submodule to find.
|
||||||
|
op_name (str): The name of the submodule.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The requested submodule found within the given layer.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the specified submodule cannot be found within the layer.
|
||||||
|
"""
|
||||||
|
for name, m in module.named_modules():
|
||||||
|
if name == op_name:
|
||||||
|
return m
|
||||||
|
raise ValueError(f"Cannot find op {op_name} in module {module}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def scale_ln_fcs(ln, fcs, scales):
|
||||||
|
"""
|
||||||
|
Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ln (nn.LayerNorm): The LayerNorm module to be scaled.
|
||||||
|
fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
|
||||||
|
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(fcs, list):
|
||||||
|
fcs = [fcs]
|
||||||
|
|
||||||
|
scales = scales.to(ln.weight.device)
|
||||||
|
|
||||||
|
ln.weight.div_(scales)
|
||||||
|
if hasattr(ln, "bias") and ln.bias is not None:
|
||||||
|
ln.bias.div_(scales)
|
||||||
|
|
||||||
|
for fc in fcs:
|
||||||
|
fc.weight.mul_(scales.view(1, -1))
|
||||||
|
|
||||||
|
for p in ln.parameters():
|
||||||
|
assert torch.isnan(p).sum() == 0
|
||||||
|
for fc in fcs:
|
||||||
|
for p in fc.parameters():
|
||||||
|
assert torch.isnan(p).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def scale_fc_fc(fc1, fc2, scales):
|
||||||
|
"""
|
||||||
|
Scales the weights of two fully-connected layers in a specific pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fc1 (nn.Linear): The first fully-connected layer to be scaled.
|
||||||
|
fc2 (nn.Linear): The second fully-connected layer to be scaled.
|
||||||
|
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||||
|
"""
|
||||||
|
assert isinstance(fc1, nn.Linear)
|
||||||
|
assert isinstance(fc2, nn.Linear)
|
||||||
|
|
||||||
|
scales = scales.to(fc1.weight.device)
|
||||||
|
|
||||||
|
fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
|
||||||
|
if fc1.bias is not None:
|
||||||
|
fc1.bias.div_(scales.view(-1))
|
||||||
|
|
||||||
|
fc2.weight.mul_(scales.view(1, -1))
|
||||||
|
|
||||||
|
for p in fc1.parameters():
|
||||||
|
assert torch.isnan(p).sum() == 0
|
||||||
|
for p in fc2.parameters():
|
||||||
|
assert torch.isnan(p).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def scale_gelu_fc(gelu, fc, scales):
|
||||||
|
"""
|
||||||
|
Scales the weight of a GELU activation and a fully-connected layer proportionally.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
|
||||||
|
fc (nn.Linear): The fully-connected layer to be scaled.
|
||||||
|
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
|
||||||
|
TypeError: If the `fc` module is not of type `nn.Linear`.
|
||||||
|
"""
|
||||||
|
assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
|
||||||
|
assert isinstance(fc, nn.Linear)
|
||||||
|
|
||||||
|
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
|
||||||
|
|
||||||
|
for p in fc.parameters():
|
||||||
|
assert torch.isnan(p).sum() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def apply_scale(module, scales_list, input_feat_dict=None):
|
||||||
|
"""
|
||||||
|
Applies different scaling strategies to layers based on their type and hierarchy within a given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The module containing the layers to be scaled.
|
||||||
|
scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
|
||||||
|
* prev_op_name (str): The name of the preceding operation or module,
|
||||||
|
relative to which the layers to be scaled are located.
|
||||||
|
* layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
|
||||||
|
* scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
|
||||||
|
input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
|
||||||
|
input features (optional).
|
||||||
|
"""
|
||||||
|
for prev_op_name, layer_names, scales in scales_list:
|
||||||
|
prev_op = get_op_by_name(module, prev_op_name)
|
||||||
|
layers = [get_op_by_name(module, name) for name in layer_names]
|
||||||
|
|
||||||
|
prev_op.cuda()
|
||||||
|
for layer in layers:
|
||||||
|
layer.cuda()
|
||||||
|
scales.cuda()
|
||||||
|
|
||||||
|
if isinstance(prev_op, nn.Linear):
|
||||||
|
assert len(layers) == 1
|
||||||
|
scale_fc_fc(prev_op, layers[0], scales)
|
||||||
|
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
|
||||||
|
scale_ln_fcs(prev_op, layers, scales)
|
||||||
|
elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
|
||||||
|
new_module = ScaledActivation(prev_op, scales)
|
||||||
|
set_op_by_name(module, prev_op_name, new_module)
|
||||||
|
scale_gelu_fc(prev_op, layers[0], scales)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
|
||||||
|
|
||||||
|
# apply the scaling to input feat if given; prepare it for clipping
|
||||||
|
if input_feat_dict is not None:
|
||||||
|
for layer_name in layer_names:
|
||||||
|
inp = input_feat_dict[layer_name]
|
||||||
|
inp.div_(scales.view(1, -1).to(inp.device))
|
||||||
|
|
||||||
|
prev_op.cpu()
|
||||||
|
for layer in layers:
|
||||||
|
layer.cpu()
|
||||||
|
scales.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def apply_clip(module, clip_list):
|
||||||
|
"""
|
||||||
|
Applies element-wise clipping to the weight of a specific layer within a given module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (nn.Module): The module containing the layer to be clipped.
|
||||||
|
clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
|
||||||
|
* name (str): The name of the layer to be clipped, relative to the root of the module.
|
||||||
|
* max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
|
||||||
|
"""
|
||||||
|
for name, max_val in clip_list:
|
||||||
|
layer = get_op_by_name(module, name)
|
||||||
|
layer.cuda()
|
||||||
|
max_val = max_val.to(layer.weight.device)
|
||||||
|
org_shape = layer.weight.shape
|
||||||
|
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
|
||||||
|
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
|
||||||
|
layer.weight.data = layer.weight.data.reshape(org_shape)
|
||||||
|
layer.cpu()
|
||||||
|
|
||||||
|
|
||||||
|
def add_scale_weights(model_path, scale_path, tmp_path):
|
||||||
|
"""
|
||||||
|
Adds pre-computed Activation Weight Quantization (AWQ) results to a model,
|
||||||
|
including scaling factors and clipping bounds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path (str): Path to the pre-trained model to be equipped with AWQ.
|
||||||
|
scale_path (str): Path to the AWQ scale factors (.pt file).
|
||||||
|
tmp_path (str): Path to the temporary directory where the equipped model will be saved.
|
||||||
|
"""
|
||||||
|
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path, config=config, trust_remote_code=True
|
||||||
|
)
|
||||||
|
model.eval()
|
||||||
|
awq_results = torch.load(str(scale_path), map_location="cpu")
|
||||||
|
apply_scale(model, awq_results["scale"])
|
||||||
|
apply_clip(model, awq_results["clip"])
|
||||||
|
model.save_pretrained(str(tmp_path))
|
||||||
|
os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
|
2
awq-py/requirements.txt
Normal file
2
awq-py/requirements.txt
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
torch>=2.0.0
|
||||||
|
transformers>=4.32.0
|
@ -46,7 +46,7 @@ class Model:
|
|||||||
self.part_names = self._get_part_names()
|
self.part_names = self._get_part_names()
|
||||||
self.hparams = Model.load_hparams(self.dir_model)
|
self.hparams = Model.load_hparams(self.dir_model)
|
||||||
self.model_arch = self._get_model_architecture()
|
self.model_arch = self._get_model_architecture()
|
||||||
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess)
|
self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False)
|
||||||
|
|
||||||
def set_vocab(self):
|
def set_vocab(self):
|
||||||
self._set_vocab_gpt2()
|
self._set_vocab_gpt2()
|
||||||
@ -59,7 +59,7 @@ class Model:
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||||
else:
|
else:
|
||||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", weights_only=True))
|
||||||
|
|
||||||
with ctx as model_part:
|
with ctx as model_part:
|
||||||
for name in model_part.keys():
|
for name in model_part.keys():
|
||||||
@ -464,7 +464,11 @@ class MPTModel(Model):
|
|||||||
data = data_torch.squeeze().numpy()
|
data = data_torch.squeeze().numpy()
|
||||||
|
|
||||||
# map tensor names
|
# map tensor names
|
||||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
if "scales" in name:
|
||||||
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
|
||||||
|
new_name = new_name.replace("scales", "act.scales")
|
||||||
|
else:
|
||||||
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
if new_name is None:
|
if new_name is None:
|
||||||
print(f"Can not map tensor {name!r}")
|
print(f"Can not map tensor {name!r}")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
@ -1095,6 +1099,9 @@ def parse_args() -> argparse.Namespace:
|
|||||||
"--vocab-only", action="store_true",
|
"--vocab-only", action="store_true",
|
||||||
help="extract only the vocab",
|
help="extract only the vocab",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--awq-path", type=Path, default=None,
|
||||||
|
help="Path to scale awq cache file")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--outfile", type=Path,
|
"--outfile", type=Path,
|
||||||
help="path to write to; default: based on input",
|
help="path to write to; default: based on input",
|
||||||
@ -1115,6 +1122,20 @@ def parse_args() -> argparse.Namespace:
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
dir_model = args.model
|
dir_model = args.model
|
||||||
|
|
||||||
|
if args.awq_path:
|
||||||
|
sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
|
||||||
|
from awq.apply_awq import add_scale_weights
|
||||||
|
tmp_model_path = args.model / "weighted_model"
|
||||||
|
dir_model = tmp_model_path
|
||||||
|
if tmp_model_path.is_dir():
|
||||||
|
print(f"{tmp_model_path} exists as a weighted model.")
|
||||||
|
else:
|
||||||
|
tmp_model_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
print("Saving new weighted model ...")
|
||||||
|
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
|
||||||
|
print(f"Saved weighted model at {tmp_model_path}.")
|
||||||
|
|
||||||
if not dir_model.is_dir():
|
if not dir_model.is_dir():
|
||||||
print(f'Error: {args.model} is not a directory', file=sys.stderr)
|
print(f'Error: {args.model} is not a directory', file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
14
convert.py
14
convert.py
@ -1187,6 +1187,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
# We currently only support Q8_0 output on little endian systems.
|
# We currently only support Q8_0 output on little endian systems.
|
||||||
output_choices.append("q8_0")
|
output_choices.append("q8_0")
|
||||||
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
||||||
|
parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None)
|
||||||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||||
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
|
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
|
||||||
@ -1200,6 +1201,19 @@ def main(args_in: list[str] | None = None) -> None:
|
|||||||
parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
|
parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
|
||||||
|
|
||||||
args = parser.parse_args(args_in)
|
args = parser.parse_args(args_in)
|
||||||
|
if args.awq_path:
|
||||||
|
sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
|
||||||
|
from awq.apply_awq import add_scale_weights
|
||||||
|
tmp_model_path = args.model / "weighted_model"
|
||||||
|
if tmp_model_path.is_dir():
|
||||||
|
print(f"{tmp_model_path} exists as a weighted model.")
|
||||||
|
else:
|
||||||
|
tmp_model_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
print("Saving new weighted model ...")
|
||||||
|
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
|
||||||
|
print(f"Saved weighted model at {tmp_model_path}.")
|
||||||
|
args.model = tmp_model_path
|
||||||
|
|
||||||
if args.dump_single:
|
if args.dump_single:
|
||||||
model_plus = lazy_load_file(args.model)
|
model_plus = lazy_load_file(args.model)
|
||||||
do_dump_model(model_plus)
|
do_dump_model(model_plus)
|
||||||
|
@ -120,6 +120,7 @@ class MODEL_TENSOR(IntEnum):
|
|||||||
FFN_GATE = auto()
|
FFN_GATE = auto()
|
||||||
FFN_DOWN = auto()
|
FFN_DOWN = auto()
|
||||||
FFN_UP = auto()
|
FFN_UP = auto()
|
||||||
|
FFN_ACT = auto()
|
||||||
FFN_GATE_EXP = auto()
|
FFN_GATE_EXP = auto()
|
||||||
FFN_DOWN_EXP = auto()
|
FFN_DOWN_EXP = auto()
|
||||||
FFN_UP_EXP = auto()
|
FFN_UP_EXP = auto()
|
||||||
@ -169,6 +170,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||||
|
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
|
||||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
|
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
|
||||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
|
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
|
||||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
|
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",
|
||||||
@ -269,6 +271,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||||||
MODEL_TENSOR.FFN_NORM,
|
MODEL_TENSOR.FFN_NORM,
|
||||||
MODEL_TENSOR.FFN_DOWN,
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
MODEL_TENSOR.FFN_UP,
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.FFN_ACT,
|
||||||
],
|
],
|
||||||
MODEL_ARCH.GPTJ: [
|
MODEL_ARCH.GPTJ: [
|
||||||
MODEL_TENSOR.TOKEN_EMBD,
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
@ -188,6 +188,11 @@ class TensorNameMap:
|
|||||||
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
|
"model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# AWQ-activation gate
|
||||||
|
MODEL_TENSOR.FFN_ACT: (
|
||||||
|
"transformer.blocks.{bid}.ffn.act", # mpt
|
||||||
|
),
|
||||||
|
|
||||||
# Feed-forward gate
|
# Feed-forward gate
|
||||||
MODEL_TENSOR.FFN_GATE: (
|
MODEL_TENSOR.FFN_GATE: (
|
||||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
||||||
|
27
llama.cpp
27
llama.cpp
@ -354,6 +354,7 @@ enum llm_tensor {
|
|||||||
LLM_TENSOR_FFN_GATE,
|
LLM_TENSOR_FFN_GATE,
|
||||||
LLM_TENSOR_FFN_DOWN,
|
LLM_TENSOR_FFN_DOWN,
|
||||||
LLM_TENSOR_FFN_UP,
|
LLM_TENSOR_FFN_UP,
|
||||||
|
LLM_TENSOR_FFN_ACT,
|
||||||
LLM_TENSOR_FFN_DOWN_EXP,
|
LLM_TENSOR_FFN_DOWN_EXP,
|
||||||
LLM_TENSOR_FFN_GATE_EXP,
|
LLM_TENSOR_FFN_GATE_EXP,
|
||||||
LLM_TENSOR_FFN_UP_EXP,
|
LLM_TENSOR_FFN_UP_EXP,
|
||||||
@ -473,6 +474,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -1285,6 +1287,7 @@ struct llama_hparams {
|
|||||||
float f_clamp_kqv;
|
float f_clamp_kqv;
|
||||||
float f_max_alibi_bias;
|
float f_max_alibi_bias;
|
||||||
|
|
||||||
|
|
||||||
bool operator!=(const llama_hparams & other) const {
|
bool operator!=(const llama_hparams & other) const {
|
||||||
if (this->vocab_only != other.vocab_only) return true;
|
if (this->vocab_only != other.vocab_only) return true;
|
||||||
if (this->n_vocab != other.n_vocab) return true;
|
if (this->n_vocab != other.n_vocab) return true;
|
||||||
@ -1388,6 +1391,7 @@ struct llama_layer {
|
|||||||
// ff bias
|
// ff bias
|
||||||
struct ggml_tensor * ffn_down_b; // b2
|
struct ggml_tensor * ffn_down_b; // b2
|
||||||
struct ggml_tensor * ffn_up_b; // b3
|
struct ggml_tensor * ffn_up_b; // b3
|
||||||
|
struct ggml_tensor * ffn_act;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_kv_cell {
|
struct llama_kv_cell {
|
||||||
@ -3471,7 +3475,6 @@ static bool llm_load_tensors(
|
|||||||
case LLM_ARCH_MPT:
|
case LLM_ARCH_MPT:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||||
|
|
||||||
// output
|
// output
|
||||||
{
|
{
|
||||||
ggml_backend_type backend_norm;
|
ggml_backend_type backend_norm;
|
||||||
@ -3509,6 +3512,9 @@ static bool llm_load_tensors(
|
|||||||
|
|
||||||
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
||||||
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||||
|
|
||||||
|
// AWQ ScaleActivation layer
|
||||||
|
layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend, false);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_STABLELM:
|
case LLM_ARCH_STABLELM:
|
||||||
@ -4039,6 +4045,7 @@ static struct ggml_tensor * llm_build_ffn(
|
|||||||
struct ggml_tensor * gate_b,
|
struct ggml_tensor * gate_b,
|
||||||
struct ggml_tensor * down,
|
struct ggml_tensor * down,
|
||||||
struct ggml_tensor * down_b,
|
struct ggml_tensor * down_b,
|
||||||
|
struct ggml_tensor * act_scales,
|
||||||
llm_ffn_op_type type_op,
|
llm_ffn_op_type type_op,
|
||||||
llm_ffn_gate_type type_gate,
|
llm_ffn_gate_type type_gate,
|
||||||
const llm_build_cb & cb,
|
const llm_build_cb & cb,
|
||||||
@ -4083,6 +4090,10 @@ static struct ggml_tensor * llm_build_ffn(
|
|||||||
{
|
{
|
||||||
cur = ggml_gelu(ctx, cur);
|
cur = ggml_gelu(ctx, cur);
|
||||||
cb(cur, "ffn_gelu", il);
|
cb(cur, "ffn_gelu", il);
|
||||||
|
if (act_scales != NULL) {
|
||||||
|
cur = ggml_div(ctx, cur, act_scales);
|
||||||
|
cb(cur, "ffn_act", il);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_RELU:
|
case LLM_FFN_RELU:
|
||||||
{
|
{
|
||||||
@ -4401,6 +4412,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
} else {
|
} else {
|
||||||
@ -4580,6 +4592,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -4694,6 +4707,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -4798,6 +4812,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
NULL,
|
||||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5002,6 +5017,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
NULL,
|
||||||
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5088,6 +5104,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5183,6 +5200,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
NULL,
|
||||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5268,11 +5286,11 @@ struct llm_build_context {
|
|||||||
NULL,
|
NULL,
|
||||||
LLM_NORM, cb, il);
|
LLM_NORM, cb, il);
|
||||||
cb(cur, "ffn_norm", il);
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
cur = llm_build_ffn(ctx0, cur,
|
cur = llm_build_ffn(ctx0, cur,
|
||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
model.layers[il].ffn_act,
|
||||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5381,6 +5399,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5493,6 +5512,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5600,6 +5620,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||||
NULL, NULL,
|
NULL, NULL,
|
||||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||||
|
NULL,
|
||||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||||
cb(ffn_output, "ffn_out", il);
|
cb(ffn_output, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5703,6 +5724,7 @@ struct llm_build_context {
|
|||||||
model.layers[il].ffn_up, NULL,
|
model.layers[il].ffn_up, NULL,
|
||||||
model.layers[il].ffn_gate, NULL,
|
model.layers[il].ffn_gate, NULL,
|
||||||
model.layers[il].ffn_down, NULL,
|
model.layers[il].ffn_down, NULL,
|
||||||
|
NULL,
|
||||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
cb(cur, "ffn_out", il);
|
cb(cur, "ffn_out", il);
|
||||||
}
|
}
|
||||||
@ -5887,6 +5909,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
|
|||||||
{ "ffn_gate", OFFLOAD_FUNC },
|
{ "ffn_gate", OFFLOAD_FUNC },
|
||||||
{ "ffn_gate_b", OFFLOAD_FUNC },
|
{ "ffn_gate_b", OFFLOAD_FUNC },
|
||||||
{ "ffn_gate_par", OFFLOAD_FUNC },
|
{ "ffn_gate_par", OFFLOAD_FUNC },
|
||||||
|
{ "ffn_act", OFFLOAD_FUNC },
|
||||||
{ "ffn_down", OFFLOAD_FUNC },
|
{ "ffn_down", OFFLOAD_FUNC },
|
||||||
{ "ffn_down_b", OFFLOAD_FUNC },
|
{ "ffn_down_b", OFFLOAD_FUNC },
|
||||||
{ "ffn_out", OFFLOAD_FUNC },
|
{ "ffn_out", OFFLOAD_FUNC },
|
||||||
|
Loading…
Reference in New Issue
Block a user