mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
Merge branch 'cuda-cublas-opts' into gg/phi-2
This commit is contained in:
commit
d2f1e0dacc
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, BinaryIO, Sequence
|
from typing import Any, BinaryIO, Sequence
|
||||||
@ -11,43 +10,15 @@ from typing import Any, BinaryIO, Sequence
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
if 'NO_LOCAL_GGUF' not in os.environ:
|
||||||
|
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
|
||||||
|
import gguf
|
||||||
|
|
||||||
|
|
||||||
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1}
|
||||||
|
|
||||||
|
|
||||||
HF_SUBLAYER_TO_GGML = {
|
|
||||||
"self_attn.q_proj": "attn_q",
|
|
||||||
"self_attn.k_proj": "attn_k",
|
|
||||||
"self_attn.v_proj": "attn_v",
|
|
||||||
"self_attn.o_proj": "attn_output",
|
|
||||||
"mlp.gate_proj": "ffn_gate",
|
|
||||||
"mlp.down_proj": "ffn_down",
|
|
||||||
"mlp.up_proj": "ffn_up",
|
|
||||||
"input_layernorm": "attn_norm",
|
|
||||||
"post_attention_layernorm": "ffn_norm",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def translate_tensor_name(t: str) -> str:
|
|
||||||
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
|
|
||||||
if match:
|
|
||||||
nn = match.group(1)
|
|
||||||
sub_layer = match.group(2)
|
|
||||||
lora_type = match.group(3)
|
|
||||||
|
|
||||||
sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
|
|
||||||
if sub_layer_renamed is None:
|
|
||||||
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
output_string = (
|
|
||||||
f"blk.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
|
|
||||||
)
|
|
||||||
return output_string
|
|
||||||
else:
|
|
||||||
print(f"Error: unrecognized tensor {t}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
||||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||||
fout.write(struct.pack("i", 1)) # file version
|
fout.write(struct.pack("i", 1)) # file version
|
||||||
@ -61,9 +32,7 @@ def write_file_header(fout: BinaryIO, params: dict[str, Any]) -> None:
|
|||||||
fout.write(struct.pack("i", int(params["lora_alpha"])))
|
fout.write(struct.pack("i", int(params["lora_alpha"])))
|
||||||
|
|
||||||
|
|
||||||
def write_tensor_header(
|
def write_tensor_header(fout: BinaryIO, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
|
||||||
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
|
|
||||||
) -> None:
|
|
||||||
sname = name.encode("utf-8")
|
sname = name.encode("utf-8")
|
||||||
fout.write(
|
fout.write(
|
||||||
struct.pack(
|
struct.pack(
|
||||||
@ -78,11 +47,12 @@ def write_tensor_header(
|
|||||||
fout.seek((fout.tell() + 31) & -32)
|
fout.seek((fout.tell() + 31) & -32)
|
||||||
|
|
||||||
|
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) < 2:
|
||||||
print(f"Usage: python {sys.argv[0]} <path>")
|
print(f"Usage: python {sys.argv[0]} <path> [arch]")
|
||||||
print(
|
print(
|
||||||
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
||||||
)
|
)
|
||||||
|
print(f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
||||||
@ -90,6 +60,14 @@ input_model = os.path.join(sys.argv[1], "adapter_model.bin")
|
|||||||
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
||||||
|
|
||||||
model = torch.load(input_model, map_location="cpu")
|
model = torch.load(input_model, map_location="cpu")
|
||||||
|
arch_name = sys.argv[2] if len(sys.argv) == 3 else "llama"
|
||||||
|
|
||||||
|
if arch_name not in gguf.MODEL_ARCH_NAMES.values():
|
||||||
|
print(f"Error: unsupported architecture {arch_name}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
arch = list(gguf.MODEL_ARCH_NAMES.keys())[list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)]
|
||||||
|
name_map = gguf.TensorNameMap(arch, 200) # 200 layers ought to be enough for anyone
|
||||||
|
|
||||||
with open(input_json, "r") as f:
|
with open(input_json, "r") as f:
|
||||||
params = json.load(f)
|
params = json.load(f)
|
||||||
@ -117,6 +95,7 @@ with open(output_path, "wb") as fout:
|
|||||||
|
|
||||||
write_file_header(fout, params)
|
write_file_header(fout, params)
|
||||||
for k, v in model.items():
|
for k, v in model.items():
|
||||||
|
orig_k = k
|
||||||
if k.endswith(".default.weight"):
|
if k.endswith(".default.weight"):
|
||||||
k = k.replace(".default.weight", ".weight")
|
k = k.replace(".default.weight", ".weight")
|
||||||
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
if k in ["llama_proj.weight", "llama_proj.bias"]:
|
||||||
@ -129,7 +108,32 @@ with open(output_path, "wb") as fout:
|
|||||||
v = v.float()
|
v = v.float()
|
||||||
|
|
||||||
t = v.detach().numpy()
|
t = v.detach().numpy()
|
||||||
tname = translate_tensor_name(k)
|
|
||||||
|
prefix = "base_model.model."
|
||||||
|
if k.startswith(prefix):
|
||||||
|
k = k[len(prefix) :]
|
||||||
|
|
||||||
|
lora_suffixes = (".lora_A.weight", ".lora_B.weight")
|
||||||
|
if k.endswith(lora_suffixes):
|
||||||
|
suffix = k[-len(lora_suffixes[0]):]
|
||||||
|
k = k[: -len(lora_suffixes[0])]
|
||||||
|
else:
|
||||||
|
print(f"Error: unrecognized tensor name {orig_k}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
tname = name_map.get_name(k)
|
||||||
|
if tname is None:
|
||||||
|
print(f"Error: could not map tensor name {orig_k}")
|
||||||
|
print(" Note: the arch parameter must be specified if the model is not llama")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if suffix == ".lora_A.weight":
|
||||||
|
tname += ".weight.loraA"
|
||||||
|
elif suffix == ".lora_B.weight":
|
||||||
|
tname += ".weight.loraB"
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||||
t.tofile(fout)
|
t.tofile(fout)
|
||||||
|
@ -34,7 +34,8 @@ export async function* llama(prompt, params = {}, config = {}) {
|
|||||||
headers: {
|
headers: {
|
||||||
'Connection': 'keep-alive',
|
'Connection': 'keep-alive',
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'text/event-stream'
|
'Accept': 'text/event-stream',
|
||||||
|
...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
|
||||||
},
|
},
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
});
|
});
|
||||||
|
@ -235,7 +235,8 @@
|
|||||||
grammar: '',
|
grammar: '',
|
||||||
n_probs: 0, // no completion_probabilities,
|
n_probs: 0, // no completion_probabilities,
|
||||||
image_data: [],
|
image_data: [],
|
||||||
cache_prompt: true
|
cache_prompt: true,
|
||||||
|
api_key: ''
|
||||||
})
|
})
|
||||||
|
|
||||||
/* START: Support for storing prompt templates and parameters in browsers LocalStorage */
|
/* START: Support for storing prompt templates and parameters in browsers LocalStorage */
|
||||||
@ -790,6 +791,10 @@
|
|||||||
<fieldset>
|
<fieldset>
|
||||||
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
|
||||||
</fieldset>
|
</fieldset>
|
||||||
|
<fieldset>
|
||||||
|
<label for="api_key">API Key</label>
|
||||||
|
<input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />
|
||||||
|
</fieldset>
|
||||||
</details>
|
</details>
|
||||||
</form>
|
</form>
|
||||||
`
|
`
|
||||||
|
@ -36,6 +36,7 @@ using json = nlohmann::json;
|
|||||||
struct server_params
|
struct server_params
|
||||||
{
|
{
|
||||||
std::string hostname = "127.0.0.1";
|
std::string hostname = "127.0.0.1";
|
||||||
|
std::string api_key;
|
||||||
std::string public_path = "examples/server/public";
|
std::string public_path = "examples/server/public";
|
||||||
int32_t port = 8080;
|
int32_t port = 8080;
|
||||||
int32_t read_timeout = 600;
|
int32_t read_timeout = 600;
|
||||||
@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|||||||
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
|
||||||
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
||||||
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
||||||
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
||||||
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
||||||
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
||||||
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
||||||
@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||||||
}
|
}
|
||||||
sparams.public_path = argv[i];
|
sparams.public_path = argv[i];
|
||||||
}
|
}
|
||||||
|
else if (arg == "--api-key")
|
||||||
|
{
|
||||||
|
if (++i >= argc)
|
||||||
|
{
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
sparams.api_key = argv[i];
|
||||||
|
}
|
||||||
else if (arg == "--timeout" || arg == "-to")
|
else if (arg == "--timeout" || arg == "-to")
|
||||||
{
|
{
|
||||||
if (++i >= argc)
|
if (++i >= argc)
|
||||||
@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
httplib::Server svr;
|
httplib::Server svr;
|
||||||
|
|
||||||
|
// Middleware for API key validation
|
||||||
|
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
||||||
|
// If API key is not set, skip validation
|
||||||
|
if (sparams.api_key.empty()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for API key in the header
|
||||||
|
auto auth_header = req.get_header_value("Authorization");
|
||||||
|
std::string prefix = "Bearer ";
|
||||||
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
||||||
|
std::string received_api_key = auth_header.substr(prefix.size());
|
||||||
|
if (received_api_key == sparams.api_key) {
|
||||||
|
return true; // API key is valid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// API key is invalid or not provided
|
||||||
|
res.set_content("Unauthorized: Invalid API Key", "text/plain");
|
||||||
|
res.status = 401; // Unauthorized
|
||||||
|
|
||||||
|
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||||
|
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
svr.set_default_headers({{"Server", "llama.cpp"},
|
svr.set_default_headers({{"Server", "llama.cpp"},
|
||||||
{"Access-Control-Allow-Origin", "*"},
|
{"Access-Control-Allow-Origin", "*"},
|
||||||
{"Access-Control-Allow-Headers", "content-type"}});
|
{"Access-Control-Allow-Headers", "content-type"}});
|
||||||
@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
|
|||||||
res.set_content(data.dump(), "application/json");
|
res.set_content(data.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.request_completion(data, false, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
|
|||||||
});
|
});
|
||||||
|
|
||||||
// TODO: add mount point without "/v1" prefix -- how?
|
// TODO: add mount point without "/v1" prefix -- how?
|
||||||
svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||||
|
|
||||||
const int task_id = llama.request_completion(data, false, false, -1);
|
const int task_id = llama.request_completion(data, false, false, -1);
|
||||||
@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (!validate_api_key(req, res)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
const int task_id = llama.request_completion(data, true, false, -1);
|
const int task_id = llama.request_completion(data, true, false, -1);
|
||||||
if (!json_value(data, "stream", false)) {
|
if (!json_value(data, "stream", false)) {
|
||||||
@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
|
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
|
||||||
{
|
{
|
||||||
|
if (res.status == 401)
|
||||||
|
{
|
||||||
|
res.set_content("Unauthorized", "text/plain");
|
||||||
|
}
|
||||||
if (res.status == 400)
|
if (res.status == 400)
|
||||||
{
|
{
|
||||||
res.set_content("Invalid request", "text/plain");
|
res.set_content("Invalid request", "text/plain");
|
||||||
}
|
}
|
||||||
else if (res.status != 500)
|
else if (res.status == 404)
|
||||||
{
|
{
|
||||||
res.set_content("File Not Found", "text/plain");
|
res.set_content("File Not Found", "text/plain");
|
||||||
res.status = 404;
|
res.status = 404;
|
||||||
@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
|
|||||||
// to make it ctrl+clickable:
|
// to make it ctrl+clickable:
|
||||||
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
|
||||||
|
|
||||||
LOG_INFO("HTTP server listening", {
|
std::unordered_map<std::string, std::string> log_data;
|
||||||
{"hostname", sparams.hostname},
|
log_data["hostname"] = sparams.hostname;
|
||||||
{"port", sparams.port},
|
log_data["port"] = std::to_string(sparams.port);
|
||||||
});
|
|
||||||
|
|
||||||
|
if (!sparams.api_key.empty()) {
|
||||||
|
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("HTTP server listening", log_data);
|
||||||
// run the HTTP server in a thread - see comment below
|
// run the HTTP server in a thread - see comment below
|
||||||
std::thread t([&]()
|
std::thread t([&]()
|
||||||
{
|
{
|
||||||
|
62
ggml-cuda.cu
62
ggml-cuda.cu
@ -7406,27 +7406,20 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||||
}
|
}
|
||||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
|
||||||
size_t dst_as = 0;
|
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const half beta_f16 = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
|
||||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
row_diff, src1_ncols, ne10,
|
row_diff, src1_ncols, ne10,
|
||||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
&alpha, src0_ptr, CUDA_R_16F, ne00,
|
||||||
src1_ptr, CUDA_R_16F, ne10,
|
src1_ptr, CUDA_R_16F, ne10,
|
||||||
&beta_f16, dst_f16, CUDA_R_16F, ldc,
|
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_32F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
|
||||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
|
||||||
|
|
||||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
|
||||||
|
|
||||||
if (src0_as != 0) {
|
if (src0_as != 0) {
|
||||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||||
}
|
}
|
||||||
@ -8306,8 +8299,8 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
|||||||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void k_compute_batched_ptrs(
|
__global__ static void k_compute_batched_ptrs(
|
||||||
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
|
const half * src0_as_f16, const half * src1_as_f16, float * dst_f32,
|
||||||
const void ** ptrs_src, void ** ptrs_dst,
|
const void ** ptrs_src, void ** ptrs_dst,
|
||||||
int ne12, int ne13,
|
int ne12, int ne13,
|
||||||
int ne23,
|
int ne23,
|
||||||
@ -8327,7 +8320,7 @@ static __global__ void k_compute_batched_ptrs(
|
|||||||
|
|
||||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
||||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
|
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f32 + i12* nb2 + i13* nb3 ;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
@ -8382,9 +8375,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||||
|
|
||||||
size_t dst_as = 0;
|
|
||||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
|
||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
GGML_ASSERT(ne13 % ne03 == 0);
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
|
||||||
@ -8392,8 +8382,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
const int64_t r2 = ne12/ne02;
|
const int64_t r2 = ne12/ne02;
|
||||||
const int64_t r3 = ne13/ne03;
|
const int64_t r3 = ne13/ne03;
|
||||||
|
|
||||||
const half alpha_f16 = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const half beta_f16 = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// use cublasGemmEx
|
// use cublasGemmEx
|
||||||
@ -8406,10 +8396,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
&alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
||||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
||||||
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
|
&beta, ( char *) dst_ddf + i12* dst->nb[2] + i13* dst->nb[3] , CUDA_R_32F, ne01,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_32F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -8421,11 +8411,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
&alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||||
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
|
&beta, ( char *) dst_ddf, CUDA_R_32F, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||||
ne12*ne13,
|
ne12*ne13,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_32F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
} else {
|
} else {
|
||||||
// use cublasGemmBatchedEx
|
// use cublasGemmBatchedEx
|
||||||
@ -8442,7 +8432,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
|
|
||||||
dim3 block_dims(ne13, ne12);
|
dim3 block_dims(ne13, ne12);
|
||||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||||
src0_as_f16, src1_as_f16, dst_f16,
|
src0_as_f16, src1_as_f16, dst_ddf,
|
||||||
ptrs_src, ptrs_dst,
|
ptrs_src, ptrs_dst,
|
||||||
ne12, ne13,
|
ne12, ne13,
|
||||||
ne23,
|
ne23,
|
||||||
@ -8455,11 +8445,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
CUBLAS_CHECK(
|
CUBLAS_CHECK(
|
||||||
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
ne01, ne11, ne10,
|
ne01, ne11, ne10,
|
||||||
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
&alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||||
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||||
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
|
&beta, ( void **) (ptrs_dst + 0*ne23), CUDA_R_32F, ne01,
|
||||||
ne23,
|
ne23,
|
||||||
CUBLAS_COMPUTE_16F,
|
CUBLAS_COMPUTE_32F,
|
||||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||||
|
|
||||||
if (ptrs_src_s != 0) {
|
if (ptrs_src_s != 0) {
|
||||||
@ -8471,11 +8461,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
|
||||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
|
||||||
|
|
||||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
237
ggml.c
237
ggml.c
@ -9584,16 +9584,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// off1 = offset in i11 and i1
|
|
||||||
// cne1 = ne11 and ne1
|
|
||||||
// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
|
|
||||||
// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
|
|
||||||
static void ggml_compute_forward_mul_mat(
|
static void ggml_compute_forward_mul_mat(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * src0,
|
||||||
const struct ggml_tensor * src1,
|
const struct ggml_tensor * src1,
|
||||||
struct ggml_tensor * dst,
|
struct ggml_tensor * dst) {
|
||||||
int64_t off1, int64_t cne1) {
|
|
||||||
int64_t t0 = ggml_perf_time_us();
|
int64_t t0 = ggml_perf_time_us();
|
||||||
UNUSED(t0);
|
UNUSED(t0);
|
||||||
|
|
||||||
@ -9661,9 +9656,9 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
const int64_t i03 = i13/r3;
|
const int64_t i03 = i13/r3;
|
||||||
const int64_t i02 = i12/r2;
|
const int64_t i02 = i12/r2;
|
||||||
|
|
||||||
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
||||||
const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
|
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
|
||||||
float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
|
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
||||||
|
|
||||||
if (type != GGML_TYPE_F32) {
|
if (type != GGML_TYPE_F32) {
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
@ -9680,7 +9675,7 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
}
|
}
|
||||||
|
|
||||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||||
cne1, ne01, ne10,
|
ne1, ne01, ne10,
|
||||||
1.0f, y, ne10,
|
1.0f, y, ne10,
|
||||||
x, ne00,
|
x, ne00,
|
||||||
0.0f, d, ne01);
|
0.0f, d, ne01);
|
||||||
@ -9721,8 +9716,8 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
|
|
||||||
const int64_t nr0 = ne01; // src0 rows
|
const int64_t nr0 = ne01; // src0 rows
|
||||||
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
|
const int64_t nr1 = ne1*ne12*ne13; // src1 rows
|
||||||
|
|
||||||
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
||||||
|
|
||||||
@ -9764,9 +9759,9 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||||
const int64_t i13 = (ir1/(ne12*cne1));
|
const int64_t i13 = (ir1/(ne12*ne1));
|
||||||
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
|
||||||
const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
|
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
|
||||||
|
|
||||||
// broadcast src0 into src1
|
// broadcast src0 into src1
|
||||||
const int64_t i03 = i13/r3;
|
const int64_t i03 = i13/r3;
|
||||||
@ -9806,28 +9801,191 @@ static void ggml_compute_forward_mul_mat(
|
|||||||
|
|
||||||
static void ggml_compute_forward_mul_mat_id(
|
static void ggml_compute_forward_mul_mat_id(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
const struct ggml_tensor * src0,
|
const struct ggml_tensor * ids,
|
||||||
const struct ggml_tensor * src1,
|
const struct ggml_tensor * src1,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
// during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
|
|
||||||
ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const struct ggml_tensor * ids = src0;
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const enum ggml_type type = src0->type;
|
||||||
|
|
||||||
|
const bool src1_cont = ggml_is_contiguous(src1);
|
||||||
|
|
||||||
|
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||||
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
|
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||||
|
|
||||||
|
GGML_ASSERT(ne0 == ne01);
|
||||||
|
GGML_ASSERT(ne1 == ne11);
|
||||||
|
GGML_ASSERT(ne2 == ne12);
|
||||||
|
GGML_ASSERT(ne3 == ne13);
|
||||||
|
|
||||||
|
// we don't support permuted src0 or src1
|
||||||
|
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||||
|
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||||
|
|
||||||
|
// dst cannot be transposed or permuted
|
||||||
|
GGML_ASSERT(nb0 == sizeof(float));
|
||||||
|
GGML_ASSERT(nb0 <= nb1);
|
||||||
|
GGML_ASSERT(nb1 <= nb2);
|
||||||
|
GGML_ASSERT(nb2 <= nb3);
|
||||||
|
|
||||||
|
// broadcast factors
|
||||||
|
const int64_t r2 = ne12/ne02;
|
||||||
|
const int64_t r3 = ne13/ne03;
|
||||||
|
|
||||||
|
// row groups
|
||||||
const int id = ggml_get_op_params_i32(dst, 0);
|
const int id = ggml_get_op_params_i32(dst, 0);
|
||||||
const int n_as = ggml_get_op_params_i32(dst, 1);
|
const int n_as = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||||
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
(char *) params->wdata :
|
||||||
|
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||||
|
|
||||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||||
|
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
|
||||||
|
|
||||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
|
||||||
ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
|
|
||||||
|
if (params->type == GGML_TASK_INIT) {
|
||||||
|
char * wdata = params->wdata;
|
||||||
|
if (src1->type != vec_dot_type) {
|
||||||
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
|
|
||||||
|
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
||||||
|
assert(src1->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||||
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||||
|
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||||
|
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
||||||
|
wdata += row_size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize matrix_row_counts
|
||||||
|
GGML_ASSERT(wdata == wdata_src1_end);
|
||||||
|
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||||
|
|
||||||
|
// group rows by src0 matrix
|
||||||
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||||
|
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
||||||
|
|
||||||
|
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||||
|
MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
|
||||||
|
matrix_row_counts[row_id] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (params->type == GGML_TASK_FINALIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute each matrix multiplication in sequence
|
||||||
|
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||||
|
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||||
|
|
||||||
|
if (cne1 == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
|
||||||
|
|
||||||
|
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
|
|
||||||
|
const int64_t nr0 = ne01; // src0 rows
|
||||||
|
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
|
||||||
|
|
||||||
|
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
||||||
|
|
||||||
|
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||||
|
|
||||||
|
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||||
|
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||||
|
|
||||||
|
const int64_t ith0 = ith % nth0;
|
||||||
|
const int64_t ith1 = ith / nth0;
|
||||||
|
|
||||||
|
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
||||||
|
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
||||||
|
|
||||||
|
const int64_t ir010 = dr0*ith0;
|
||||||
|
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
||||||
|
|
||||||
|
const int64_t ir110 = dr1*ith1;
|
||||||
|
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
||||||
|
|
||||||
|
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
|
||||||
|
|
||||||
|
// threads with no work simply yield (not sure if it helps)
|
||||||
|
if (ir010 >= ir011 || ir110 >= ir111) {
|
||||||
|
sched_yield();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(ne12 % ne02 == 0);
|
||||||
|
assert(ne13 % ne03 == 0);
|
||||||
|
|
||||||
|
// block-tiling attempt
|
||||||
|
const int64_t blck_0 = 16;
|
||||||
|
const int64_t blck_1 = 16;
|
||||||
|
|
||||||
|
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||||
|
float tmp[16];
|
||||||
|
|
||||||
|
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||||
|
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||||
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||||
|
const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
|
||||||
|
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
||||||
|
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
|
||||||
|
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
|
||||||
|
|
||||||
|
// broadcast src0 into src1
|
||||||
|
const int64_t i03 = i13/r3;
|
||||||
|
const int64_t i02 = i12/r2;
|
||||||
|
|
||||||
|
const int64_t i1 = i11;
|
||||||
|
const int64_t i2 = i12;
|
||||||
|
const int64_t i3 = i13;
|
||||||
|
|
||||||
|
const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
|
||||||
|
|
||||||
|
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||||
|
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||||
|
// the original src1 data pointer, so we should index using the indices directly
|
||||||
|
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||||
|
const char * src1_col = (const char *) wdata +
|
||||||
|
(src1_cont || src1->type != vec_dot_type
|
||||||
|
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
|
||||||
|
: (i11*nb11 + i12*nb12 + i13*nb13));
|
||||||
|
|
||||||
|
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
|
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||||
|
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||||
|
//}
|
||||||
|
|
||||||
|
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||||
|
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
|
||||||
|
}
|
||||||
|
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef MMID_MATRIX_ROW
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_out_prod
|
// ggml_compute_forward_out_prod
|
||||||
@ -14217,7 +14375,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
|
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
@ -16017,7 +16175,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
// FIXME: blas
|
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
@ -16351,20 +16508,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
const struct ggml_tensor * a = node->src[2];
|
const struct ggml_tensor * src0 = node->src[2];
|
||||||
const struct ggml_tensor * b = node->src[1];
|
const struct ggml_tensor * src1 = node->src[1];
|
||||||
const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
|
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
if (src1->type != vec_dot_type) {
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
|
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||||
if (a->type != GGML_TYPE_F32) {
|
|
||||||
// here we need memory just for single 2D matrix from src0
|
|
||||||
cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
#endif
|
|
||||||
if (b->type != vec_dot_type) {
|
|
||||||
cur = ggml_row_size(vec_dot_type, ggml_nelements(b));
|
|
||||||
}
|
}
|
||||||
|
const int n_as = ggml_get_op_params_i32(node, 1);
|
||||||
|
cur = GGML_PAD(cur, sizeof(int64_t)); // align
|
||||||
|
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||||
|
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
{
|
{
|
||||||
|
155
llama.cpp
155
llama.cpp
@ -1521,6 +1521,10 @@ struct llama_context {
|
|||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
#ifndef NDEBUG
|
||||||
|
// guard against access to unset logits
|
||||||
|
std::vector<bool> logits_valid;
|
||||||
|
#endif
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
|
||||||
// input embedding (1-dimensional array: [n_embd])
|
// input embedding (1-dimensional array: [n_embd])
|
||||||
@ -6386,6 +6390,14 @@ static int llama_decode_internal(
|
|||||||
{
|
{
|
||||||
auto & logits_out = lctx.logits;
|
auto & logits_out = lctx.logits;
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
auto & logits_valid = lctx.logits_valid;
|
||||||
|
logits_valid.clear();
|
||||||
|
logits_valid.resize(n_tokens);
|
||||||
|
|
||||||
|
logits_out.clear();
|
||||||
|
#endif
|
||||||
|
|
||||||
if (batch.logits) {
|
if (batch.logits) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
@ -6393,13 +6405,22 @@ static int llama_decode_internal(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
logits_valid[i] = true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all) {
|
} else if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * n_tokens);
|
logits_out.resize(n_vocab * n_tokens);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
std::fill(logits_valid.begin(), logits_valid.end(), true);
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
logits_valid[n_tokens - 1] = true;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -8862,53 +8883,60 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
|
|
||||||
const int64_t t_start_lora_us = ggml_time_us();
|
const int64_t t_start_lora_us = ggml_time_us();
|
||||||
|
|
||||||
auto fin = std::ifstream(path_lora, std::ios::binary);
|
llama_file fin(path_lora, "rb");
|
||||||
if (!fin) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_lora);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify magic and version
|
// verify magic and version
|
||||||
{
|
{
|
||||||
uint32_t magic;
|
uint32_t magic = fin.read_u32();
|
||||||
fin.read((char *) &magic, sizeof(magic));
|
if (magic != LLAMA_FILE_MAGIC_GGLA) {
|
||||||
uint32_t format_version;
|
LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
|
||||||
fin.read((char *) &format_version, sizeof(format_version));
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t format_version = fin.read_u32();
|
||||||
if (format_version != 1) {
|
if (format_version != 1) {
|
||||||
LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t lora_r;
|
int32_t lora_r = fin.read_u32();
|
||||||
int32_t lora_alpha;
|
int32_t lora_alpha = fin.read_u32();
|
||||||
fin.read((char *) &lora_r, sizeof(lora_r));
|
|
||||||
fin.read((char *) &lora_alpha, sizeof(lora_alpha));
|
|
||||||
float scaling = scale * (float)lora_alpha / (float)lora_r;
|
float scaling = scale * (float)lora_alpha / (float)lora_r;
|
||||||
|
|
||||||
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||||
|
|
||||||
|
// create a name -> tensor map of the model to accelerate lookups
|
||||||
|
// find the max tensor size to estimate the required temporary buffer size
|
||||||
|
size_t max_tensor_size = 0;
|
||||||
|
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
||||||
|
for (const auto & kv : model.tensors_by_name) {
|
||||||
|
model_tensors.insert(kv);
|
||||||
|
size_t f32_size = ggml_nelements(kv.second) * sizeof(float);
|
||||||
|
max_tensor_size = std::max(max_tensor_size, f32_size);
|
||||||
|
}
|
||||||
|
|
||||||
// create a temporary ggml context to store the lora tensors
|
// create a temporary ggml context to store the lora tensors
|
||||||
// todo: calculate size from biggest possible tensor
|
// TODO: use ggml-alloc
|
||||||
std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
|
size_t lora_ctx_size = max_tensor_size * 3;
|
||||||
|
LLAMA_LOG_INFO("%s: allocating %.f MB for lora temporary buffer\n", __func__, lora_ctx_size / 1024.0 / 1024.0);
|
||||||
|
std::vector<uint8_t> lora_buf(lora_ctx_size);
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = lora_buf.size();
|
params.mem_size = lora_buf.size();
|
||||||
params.mem_buffer = lora_buf.data();
|
params.mem_buffer = lora_buf.data();
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
|
|
||||||
ggml_context * lora_ctx = ggml_init(params);
|
using unique_context = std::unique_ptr<ggml_context, decltype(&ggml_free)>;
|
||||||
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
|
||||||
|
|
||||||
// create a name -> tensor map of the model to accelerate lookups
|
unique_context lora_ctx(nullptr, ggml_free);
|
||||||
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
lora_ctx.reset(ggml_init(params));
|
||||||
for (const auto & kv : model.tensors_by_name) {
|
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
||||||
model_tensors.insert(kv);
|
|
||||||
}
|
|
||||||
|
|
||||||
// load base model
|
// load base model
|
||||||
std::unique_ptr<llama_model_loader> ml;
|
std::unique_ptr<llama_model_loader> ml;
|
||||||
ggml_context * base_ctx = NULL;
|
|
||||||
|
unique_context base_ctx(nullptr, ggml_free);
|
||||||
std::vector<uint8_t> base_buf;
|
std::vector<uint8_t> base_buf;
|
||||||
if (path_base_model) {
|
if (path_base_model) {
|
||||||
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||||
@ -8917,6 +8945,7 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
size_t ctx_size;
|
size_t ctx_size;
|
||||||
size_t mmapped_size;
|
size_t mmapped_size;
|
||||||
ml->calc_sizes(ctx_size, mmapped_size);
|
ml->calc_sizes(ctx_size, mmapped_size);
|
||||||
|
|
||||||
base_buf.resize(ctx_size);
|
base_buf.resize(ctx_size);
|
||||||
|
|
||||||
ggml_init_params base_params;
|
ggml_init_params base_params;
|
||||||
@ -8924,9 +8953,9 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
base_params.mem_buffer = base_buf.data();
|
base_params.mem_buffer = base_buf.data();
|
||||||
base_params.no_alloc = ml->use_mmap;
|
base_params.no_alloc = ml->use_mmap;
|
||||||
|
|
||||||
base_ctx = ggml_init(base_params);
|
base_ctx.reset(ggml_init(base_params));
|
||||||
|
|
||||||
// maybe this should in llama_model_loader
|
// maybe this should be in llama_model_loader
|
||||||
if (ml->use_mmap) {
|
if (ml->use_mmap) {
|
||||||
ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
|
ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
|
||||||
}
|
}
|
||||||
@ -8939,27 +8968,35 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
std::vector<uint8_t> work_buffer;
|
std::vector<uint8_t> work_buffer;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
if (fin.tell() == fin.size) {
|
||||||
|
// eof
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t n_dims;
|
int32_t n_dims;
|
||||||
int32_t length;
|
int32_t name_len;
|
||||||
int32_t ftype;
|
int32_t ftype;
|
||||||
|
|
||||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
fin.read_raw(&n_dims, sizeof(n_dims));
|
||||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
fin.read_raw(&name_len, sizeof(name_len));
|
||||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
fin.read_raw(&ftype, sizeof(ftype));
|
||||||
if (fin.eof()) {
|
|
||||||
break;
|
if (n_dims != 1 && n_dims != 2) {
|
||||||
|
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t ne[2] = { 1, 1 };
|
int32_t ne[2] = { 1, 1 };
|
||||||
for (int i = 0; i < n_dims; ++i) {
|
for (int i = 0; i < n_dims; ++i) {
|
||||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
fin.read_raw(&ne[i], sizeof(ne[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(name_len <= 1024);
|
||||||
char buf[1024];
|
char buf[1024];
|
||||||
fin.read(buf, length);
|
fin.read_raw(buf, name_len);
|
||||||
name = std::string(buf, length);
|
name = std::string(buf, name_len);
|
||||||
}
|
}
|
||||||
|
|
||||||
// check for lora suffix and get the type of tensor
|
// check for lora suffix and get the type of tensor
|
||||||
@ -8973,7 +9010,7 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
std::string lora_type = name.substr(pos + lora_suffix.length());
|
std::string lora_type = name.substr(pos + lora_suffix.length());
|
||||||
std::string base_name = name;
|
std::string base_name = name;
|
||||||
base_name.erase(pos);
|
base_name.erase(pos);
|
||||||
// LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
|
// LLAMA_LOG_INFO("%s: %s => %s (lora type %s) \n", __func__, name.c_str(), base_name.c_str(), lora_type.c_str());
|
||||||
|
|
||||||
if (model_tensors.find(base_name) == model_tensors.end()) {
|
if (model_tensors.find(base_name) == model_tensors.end()) {
|
||||||
LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
|
LLAMA_LOG_ERROR("%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
|
||||||
@ -8992,22 +9029,15 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ggml_tensor * lora_tensor;
|
ggml_tensor * lora_tensor = ggml_new_tensor_2d(lora_ctx.get(), wtype, ne[0], ne[1]);
|
||||||
if (n_dims == 2) {
|
ggml_set_name(lora_tensor, name.c_str());
|
||||||
lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
ggml_set_name(lora_tensor, "lora_tensor");
|
|
||||||
|
|
||||||
// load tensor data
|
// load tensor data
|
||||||
size_t offset = fin.tellg();
|
size_t offset = fin.tell();
|
||||||
size_t tensor_data_size = ggml_nbytes(lora_tensor);
|
size_t tensor_data_size = ggml_nbytes(lora_tensor);
|
||||||
offset = (offset + 31) & -32;
|
offset = (offset + 31) & -32;
|
||||||
fin.seekg(offset);
|
fin.seek(offset, SEEK_SET);
|
||||||
fin.read((char*)lora_tensor->data, tensor_data_size);
|
fin.read_raw(lora_tensor->data, tensor_data_size);
|
||||||
|
|
||||||
lora_tensors[name] = lora_tensor;
|
lora_tensors[name] = lora_tensor;
|
||||||
|
|
||||||
@ -9037,13 +9067,11 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
|
|
||||||
// load from base model
|
// load from base model
|
||||||
if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
|
if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) {
|
||||||
// TODO: throw
|
|
||||||
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: not tested!! maybe not working!
|
base_t = ml->create_tensor(base_ctx.get(), base_name, { dest_t->ne[0], dest_t->ne[1] }, GGML_BACKEND_CPU);
|
||||||
base_t = ml->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
|
|
||||||
ml->load_data_for(base_t);
|
ml->load_data_for(base_t);
|
||||||
} else {
|
} else {
|
||||||
base_t = dest_t;
|
base_t = dest_t;
|
||||||
@ -9072,43 +9100,45 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// w = w + BA*s
|
// w = w + BA*s
|
||||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
ggml_tensor * BA = ggml_mul_mat(lora_ctx.get(), loraA, loraB);
|
||||||
offload_func(BA);
|
offload_func(BA);
|
||||||
ggml_set_name(BA, "BA");
|
ggml_set_name(BA, "BA");
|
||||||
|
|
||||||
if (scaling != 1.0f) {
|
if (scaling != 1.0f) {
|
||||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling);
|
||||||
ggml_set_name(scale_tensor, "scale_tensor");
|
ggml_set_name(scale_tensor, "scale_tensor");
|
||||||
|
|
||||||
BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
|
BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor);
|
||||||
offload_func(BA);
|
offload_func(BA);
|
||||||
ggml_set_name(BA, "BA_scaled");
|
ggml_set_name(BA, "BA_scaled");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * r;
|
ggml_tensor * r;
|
||||||
if (base_t == dest_t) {
|
if (base_t == dest_t) {
|
||||||
r = ggml_add_inplace(lora_ctx, dest_t, BA);
|
r = ggml_add_inplace(lora_ctx.get(), dest_t, BA);
|
||||||
offload_func_force_inplace(r);
|
offload_func_force_inplace(r);
|
||||||
ggml_set_name(r, "r_add_inplace");
|
ggml_set_name(r, "r_add_inplace");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
r = ggml_add(lora_ctx, base_t, BA);
|
r = ggml_add(lora_ctx.get(), base_t, BA);
|
||||||
offload_func(r);
|
offload_func(r);
|
||||||
ggml_set_name(r, "r_add");
|
ggml_set_name(r, "r_add");
|
||||||
|
|
||||||
r = ggml_cpy(lora_ctx, r, dest_t);
|
r = ggml_cpy(lora_ctx.get(), r, dest_t);
|
||||||
offload_func(r);
|
offload_func(r);
|
||||||
ggml_set_name(r, "r_cpy");
|
ggml_set_name(r, "r_cpy");
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * gf = ggml_new_graph(lora_ctx);
|
struct ggml_cgraph * gf = ggml_new_graph(lora_ctx.get());
|
||||||
ggml_build_forward_expand(gf, r);
|
ggml_build_forward_expand(gf, r);
|
||||||
|
|
||||||
ggml_graph_compute_helper(work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(work_buffer, gf, n_threads);
|
||||||
|
|
||||||
|
// the tensors in the adapter must be sorted such that loraA and loraB of the same tensor are next to each other
|
||||||
|
GGML_ASSERT(lora_tensors.size() == 2);
|
||||||
|
|
||||||
// we won't need these tensors again, reset the context to save memory
|
// we won't need these tensors again, reset the context to save memory
|
||||||
ggml_free(lora_ctx);
|
lora_ctx.reset(ggml_init(params));
|
||||||
lora_ctx = ggml_init(params);
|
|
||||||
lora_tensors.clear();
|
lora_tensors.clear();
|
||||||
|
|
||||||
n_tensors++;
|
n_tensors++;
|
||||||
@ -9118,12 +9148,6 @@ static int llama_apply_lora_from_file_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this should be in a destructor, it will leak on failure
|
|
||||||
ggml_free(lora_ctx);
|
|
||||||
if (base_ctx) {
|
|
||||||
ggml_free(base_ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||||
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||||
|
|
||||||
@ -10288,6 +10312,7 @@ float * llama_get_logits(struct llama_context * ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
|
||||||
|
assert(ctx->logits_valid.at(i));
|
||||||
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
|
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
1
llama.h
1
llama.h
@ -39,6 +39,7 @@
|
|||||||
|
|
||||||
#define LLAMA_MAX_RNG_STATE (64*1024)
|
#define LLAMA_MAX_RNG_STATE (64*1024)
|
||||||
|
|
||||||
|
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
|
||||||
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
|
Loading…
Reference in New Issue
Block a user