mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
f6793491b5
* update: awq support llama-7b model * update: change order * update: benchmark results for llama2-7b * update: mistral 7b v1 benchmark * update: support 4 models * fix: Readme * update: ready for PR * update: readme * fix: readme * update: change order import * black * format code * update: work for bot mpt and awqmpt * update: readme * Rename to llm_build_ffn_mpt_awq * Formatted other files * Fixed params count * fix: remove code * update: more detail for mpt * fix: readme * fix: readme * update: change folder architecture * fix: common.cpp * fix: readme * fix: remove ggml_repeat * update: cicd * update: cicd * uppdate: remove use_awq arg * update: readme * llama : adapt plamo to new ffn ggml-ci --------- Co-authored-by: Trần Đức Nam <v.namtd12@vinai.io> Co-authored-by: Le Hoang Anh <v.anhlh33@vinai.io> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
255 lines
9.0 KiB
Python
255 lines
9.0 KiB
Python
"""
|
|
Implements AWQ (Activation-aware Weight Quantization) for llama.cpp use cases.
|
|
Original paper: https://arxiv.org/abs/2306.00978
|
|
|
|
This code is based on versions of the AWQ implementation found in the following repositories:
|
|
* https://github.com/mit-han-lab/llm-awq
|
|
* https://github.com/casper-hansen/AutoAWQ
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from transformers import AutoModelForCausalLM, AutoConfig
|
|
from transformers.models.bloom.modeling_bloom import BloomGelu
|
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
|
from transformers.activations import GELUActivation
|
|
|
|
|
|
class ScaledActivation(nn.Module):
    """
    Wrap an activation module and divide its output by per-feature scales.

    Used after AWQ scale migration: the following linear layer's input columns
    were multiplied by ``scales``, so the activation output is divided by the
    same factors to keep the network's function unchanged.

    Args:
        module (nn.Module): The activation function to wrap.
        scales (torch.Tensor): Scale factors, one per feature of the last
            dimension of the activation output (flattened before use).
    """

    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        """
        Apply the wrapped activation, then divide by the scales.

        The scales are broadcast over the last dimension of ``x``, so inputs of
        any rank keep their shape. The previous ``view(1, 1, -1)`` added a
        leading dimension to rank-1 and rank-2 inputs.

        Returns:
            torch.Tensor: ``self.act(x) / scales``, same shape as ``self.act(x)``.
        """
        return self.act(x) / self.scales.view(-1).to(x.device)
|
|
|
|
|
|
def set_op_by_name(layer, name, new_module):
    """
    Install ``new_module`` at the dot-separated path ``name`` inside ``layer``.

    Every path component except the last is resolved from ``layer`` downwards:
    all-digit components index into the parent container (e.g. a
    ``ModuleList``); any other component is read as an attribute. The final
    component is then assigned on the resolved parent with ``setattr``.

    Args:
        layer (nn.Module): Root module in which the replacement happens.
        name (str): Dot-separated path to the submodule, e.g. ``"blocks.0.act"``.
        new_module (nn.Module): Module to place at that path.
    """
    *parent_path, leaf = name.split(".")

    parent = layer
    for part in parent_path:
        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)

    setattr(parent, leaf, new_module)
|
|
|
|
|
|
def get_op_by_name(module, op_name):
    """
    Look up a submodule of ``module`` by its qualified name.

    The name is matched against the names yielded by ``module.named_modules()``
    (the empty string refers to ``module`` itself).

    Args:
        module (nn.Module): Module to search.
        op_name (str): Qualified submodule name, e.g. ``"model.norm"``.

    Returns:
        nn.Module: The first submodule whose name equals ``op_name``.

    Raises:
        ValueError: If no submodule has that name.
    """
    match = next(
        (sub for sub_name, sub in module.named_modules() if sub_name == op_name),
        None,
    )
    if match is None:
        raise ValueError(f"Cannot find op {op_name} in module {module}")
    return match
|
|
|
|
|
|
@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    """
    Move per-channel scales out of a normalization layer into the following linears.

    The norm's weight (and bias, if present) is divided by ``scales`` and every
    linear layer's input columns are multiplied by them. The composition
    ``fc(ln(x))`` is therefore unchanged. All parameters are checked for NaNs
    afterwards.

    Args:
        ln (nn.Module): Normalization module with a ``weight`` (e.g. LayerNorm
            or an RMSNorm variant); its optional ``bias`` is scaled too.
        fcs (Union[nn.Linear, List[nn.Linear]]): Linear layer(s) consuming the
            norm output. A non-list value is treated as a single layer.
        scales (torch.Tensor): 1-D tensor of size (num_features,).
    """
    fc_list = fcs if isinstance(fcs, list) else [fcs]
    s = scales.to(ln.weight.device)

    ln.weight.div_(s)
    ln_bias = getattr(ln, "bias", None)
    if ln_bias is not None:
        ln_bias.div_(s)

    # One scale per input column of each linear layer.
    column_scale = s.view(1, -1)
    for fc in fc_list:
        fc.weight.mul_(column_scale)

    # Sanity check: the rescaling must not have produced NaNs anywhere.
    for mod in [ln, *fc_list]:
        for param in mod.parameters():
            assert not torch.isnan(param).any()
|
|
|
|
|
|
@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    """
    Move per-channel scales from the output of ``fc1`` into the input of ``fc2``.

    The last ``len(scales)`` output channels of ``fc1`` are divided by
    ``scales``. That covers both the weight rows and the matching bias entries.
    The input columns of ``fc2`` are multiplied by the same scales. When
    ``fc2``'s input is exactly those channels, ``fc2(fc1(x))`` is preserved.

    Args:
        fc1 (nn.Linear): Producer layer whose trailing output channels are scaled down.
        fc2 (nn.Linear): Consumer layer whose input columns are scaled up.
        scales (torch.Tensor): 1-D tensor of size (num_features,).

    Raises:
        AssertionError: If either layer is not ``nn.Linear``, or a NaN appears
            in the scaled parameters.
    """
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)
    n = scales.size(0)

    # Only the trailing rows are scaled. When fc1 is a fused projection,
    # presumably only its last slice feeds fc2 (TODO confirm per model).
    fc1.weight[-n:].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        # Scale exactly the same output channels as the weight rows above.
        # Dividing the full bias breaks when n < out_features.
        fc1.bias[-n:].div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1).to(fc2.weight.device))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0
|
|
|
|
|
|
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    """
    Fold per-channel scales into the input columns of the linear layer after a GELU.

    Only ``fc`` is modified: its weight columns are multiplied by ``scales``.
    ``gelu`` is type-checked but not changed. Dividing the activation output
    by ``scales`` is the caller's responsibility, for example by wrapping it in
    ``ScaledActivation``.

    Args:
        gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The activation preceding ``fc``.
        fc (nn.Linear): The fully-connected layer whose input columns are scaled.
        scales (torch.Tensor): A 1D tensor of size (num_features,).

    Raises:
        AssertionError: If ``gelu`` is not ``nn.GELU``, ``BloomGelu`` or
            ``GELUActivation``; if ``fc`` is not ``nn.Linear``; or if a NaN
            appears in ``fc``'s parameters after scaling.
    """
    assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
|
|
|
|
|
|
@torch.no_grad()
def apply_scale(module, scales_list, input_feat_dict=None):
    """
    Apply AWQ scale-migration entries to the submodules of ``module``.

    For each entry, the preceding op and the layers it feeds are looked up by
    name. The scales are then folded in according to the type of the
    preceding op:

    * ``nn.Linear``: :func:`scale_fc_fc` (exactly one following layer).
    * LayerNorm / RMSNorm-like: :func:`scale_ln_fcs`.
    * GELU variants: the activation is replaced by ``ScaledActivation`` and
      :func:`scale_gelu_fc` is applied (exactly one following layer).

    The touched modules are moved to the GPU while being modified, or stay on
    the CPU when no GPU is available, and are moved back to the CPU afterwards.

    Args:
        module (nn.Module): The module containing the layers to be scaled.
        scales_list (List[Tuple[str, List[str], torch.Tensor]]): Tuples of
            ``(prev_op_name, layer_names, scales)``. Names are resolved
            relative to ``module``; ``scales`` is a 1-D tensor of size
            (num_features,).
        input_feat_dict (Optional[Dict[str, torch.Tensor]]): If given, the
            cached input features of each scaled layer are divided in place by
            the same scales, so they match the rescaled model.

    Raises:
        NotImplementedError: If the preceding op's type is not supported.
        AssertionError: If a Linear/GELU entry does not list exactly one layer.
    """
    # Do not hard-code `.cuda()`: fall back to the CPU on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.to(device)
        for layer in layers:
            layer.to(device)
        # `scales` is moved to the right device inside each scale_* helper.

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)) or "rmsnorm" in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
            # Only layers[0] is compensated below; extra layers would be left unscaled.
            assert len(layers) == 1
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # Apply the scaling to the input features if given; this prepares them for clipping.
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
|
|
|
|
|
|
@torch.no_grad()
def apply_clip(module, clip_list):
    """
    Clamp the weights of named layers to per-group AWQ clipping bounds.

    For each entry, the layer's weight is reshaped to
    ``(max_val.shape[0], max_val.shape[1], -1)``, clamped element-wise to
    ``[-max_val, max_val]`` with broadcasting, and reshaped back to its
    original shape. The layer is processed on the GPU when one is available,
    otherwise on the CPU, and left on the CPU afterwards.

    Args:
        module (nn.Module): The module containing the layers to be clipped.
        clip_list (List[Tuple[str, torch.Tensor]]): Tuples of ``(name, max_val)``.
            ``name`` is resolved relative to ``module``. ``max_val`` must have
            at least two dims and broadcast against the grouped weight
            described above, presumably shaped (out_features, n_groups, 1) as
            produced by the AWQ clip search. TODO: confirm.
    """
    # Do not hard-code `.cuda()`: fall back to the CPU on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.to(device)
        max_val = max_val.to(layer.weight.device)

        org_shape = layer.weight.shape
        # Group the weight so each (row, group) slice is clamped by its own bound.
        grouped = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(grouped, -max_val, max_val).reshape(org_shape)

        layer.cpu()
|
|
|
|
|
|
def add_scale_weights(model_path, scale_path, tmp_path):
    """
    Apply pre-computed AWQ (Activation-aware Weight Quantization) results to a model and save it.

    The steps are:

    1. Load the Hugging Face causal LM at ``model_path``.
    2. Apply the scaling factors (``awq_results["scale"]``) and clipping bounds
       (``awq_results["clip"]``) loaded from ``scale_path``.
    3. Save the modified model to ``tmp_path``.
    4. Copy the model's ``tokenizer*`` files next to it.

    Args:
        model_path (str): Path to the pre-trained model to be equipped with AWQ.
        scale_path (str): Path to the AWQ results (.pt file) holding "scale" and "clip" entries.
        tmp_path (str): Directory where the equipped model and tokenizer files are written.
    """
    import shutil
    from pathlib import Path

    # trust_remote_code=True lets custom model code from the model repository run.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True
    )
    model.eval()

    # NOTE(security): torch.load unpickles the file; only use AWQ results from a trusted source.
    awq_results = torch.load(str(scale_path), map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

    model.save_pretrained(str(tmp_path))

    # Copy tokenizer files with shutil instead of shelling out to `cp`.
    # This avoids shell quoting/injection issues with the paths and works
    # where `cp` is unavailable.
    dest_dir = Path(tmp_path)
    for src in sorted(Path(model_path).glob("tokenizer*")):
        if src.is_file():
            shutil.copy(src, dest_dir)
|