mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
Compare commits
4 Commits
6d40322020
...
66619e259b
Author | SHA1 | Date | |
---|---|---|---|
|
66619e259b | ||
|
2cd43f4900 | ||
|
09fe2e7613 | ||
|
d70f5fca74 |
@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to
|
||||
|
||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||
|
||||
`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
|
||||
|
||||
**Response format**
|
||||
|
||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||
|
@ -92,6 +92,7 @@ struct slot_params {
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
bool ignore_eos = false;
|
||||
@ -209,6 +210,7 @@ struct server_task {
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||
@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
bool post_sampling_probs;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
std::vector<std::string> response_fields;
|
||||
|
||||
slot_params generation_params;
|
||||
|
||||
@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return res;
|
||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
@ -2066,6 +2069,7 @@ struct server_context {
|
||||
res->tokens = slot.generated_tokens;
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->response_fields = slot.params.response_fields;
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
|
@ -95,7 +95,7 @@ def test_consistent_result_same_seed(n_slots: int):
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"seed": 42,
|
||||
"temperature": 1.0,
|
||||
"temperature": 0.0,
|
||||
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||
})
|
||||
if last_res is not None:
|
||||
@ -120,9 +120,10 @@ def test_different_result_different_seed(n_slots: int):
|
||||
assert res.body["content"] != last_res.body["content"]
|
||||
last_res = res
|
||||
|
||||
|
||||
# TODO figure why it don't work with temperature = 1
|
||||
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
|
||||
@pytest.mark.parametrize("n_batch", [16, 32])
|
||||
@pytest.mark.parametrize("temperature", [0.0, 1.0])
|
||||
@pytest.mark.parametrize("temperature", [0.0])
|
||||
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
|
||||
global server
|
||||
server.n_batch = n_batch
|
||||
@ -257,6 +258,40 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,n_predict,response_fields",
|
||||
[
|
||||
("I believe the meaning of life is", 8, []),
|
||||
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
|
||||
],
|
||||
)
|
||||
def test_completion_response_fields(
|
||||
prompt: str, n_predict: int, response_fields: list[str]
|
||||
):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request(
|
||||
"POST",
|
||||
"/completion",
|
||||
data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"response_fields": response_fields,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert "content" in res.body
|
||||
assert len(res.body["content"])
|
||||
if len(response_fields):
|
||||
assert res.body["generation_settings/n_predict"] == n_predict
|
||||
assert res.body["prompt"] == "<s> " + prompt
|
||||
assert isinstance(res.body["content"], str)
|
||||
assert len(res.body) == len(response_fields)
|
||||
else:
|
||||
assert len(res.body)
|
||||
assert "generation_settings" in res.body
|
||||
|
||||
|
||||
def test_n_probs():
|
||||
global server
|
||||
server.start()
|
||||
|
@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// get value by path(key1 / key2)
|
||||
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
|
||||
json result = json::object();
|
||||
|
||||
for (const std::string & path : paths) {
|
||||
json current = js;
|
||||
const auto keys = string_split<std::string>(path, /*separator*/ '/');
|
||||
bool valid_path = true;
|
||||
for (const std::string & k : keys) {
|
||||
if (valid_path && current.is_object() && current.contains(k)) {
|
||||
current = current[k];
|
||||
} else {
|
||||
valid_path = false;
|
||||
}
|
||||
}
|
||||
if (valid_path) {
|
||||
result[path] = current;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* this handles 2 cases:
|
||||
* - only string, example: "string"
|
||||
|
@ -7419,14 +7419,14 @@ static void ggml_compute_forward_mul_mat(
|
||||
if (src1_cont) {
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||
if (!llamafile_sgemm(params,
|
||||
ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
||||
nb01/ggml_type_size(src0->type),
|
||||
(const char *)src1->data + i12*nb12 + i13*nb13,
|
||||
nb11/ggml_type_size(src1->type),
|
||||
(char *)dst->data + i12*nb2 + i13*nb3,
|
||||
nb1/ggml_type_size(dst->type),
|
||||
ith, nth,
|
||||
src0->type,
|
||||
src1->type,
|
||||
dst->type))
|
||||
@ -7471,14 +7471,14 @@ UseGgmlGemm1:;
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||
if (!llamafile_sgemm(params,
|
||||
ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
||||
nb01/ggml_type_size(src0->type),
|
||||
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
||||
row_size/ggml_type_size(vec_dot_type),
|
||||
(char *)dst->data + i12*nb2 + i13*nb3,
|
||||
nb1/ggml_type_size(dst->type),
|
||||
ith, nth,
|
||||
src0->type,
|
||||
vec_dot_type,
|
||||
dst->type))
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -5,8 +5,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
|
||||
const void *, int64_t, void *, int64_t, int, int,
|
||||
bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
|
||||
const void *, int64_t, const void *, int64_t, void *, int64_t,
|
||||
int, int, int);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -126,6 +126,8 @@ connection = sqlite3.connect(input_file)
|
||||
cursor = connection.cursor()
|
||||
builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
|
||||
|
||||
commit_short_len = len(builds[0][0])
|
||||
|
||||
try:
|
||||
repo = git.Repo(".", search_parent_directories=True)
|
||||
except git.InvalidGitRepositoryError:
|
||||
@ -138,11 +140,11 @@ def find_parent_in_data(commit: git.Commit):
|
||||
seen_hexsha8 = set()
|
||||
while heap:
|
||||
depth, current_commit = heapq.heappop(heap)
|
||||
current_hexsha8 = commit.hexsha[:8]
|
||||
current_hexsha8 = commit.hexsha[:commit_short_len]
|
||||
if (current_hexsha8,) in builds:
|
||||
return current_hexsha8
|
||||
for parent in commit.parents:
|
||||
parent_hexsha8 = parent.hexsha[:8]
|
||||
parent_hexsha8 = parent.hexsha[:commit_short_len]
|
||||
if parent_hexsha8 not in seen_hexsha8:
|
||||
seen_hexsha8.add(parent_hexsha8)
|
||||
heapq.heappush(heap, (depth + 1, parent))
|
||||
@ -156,9 +158,9 @@ def get_all_parent_hexsha8s(commit: git.Commit):
|
||||
|
||||
while unvisited:
|
||||
current_commit = unvisited.pop(0)
|
||||
visited.append(current_commit.hexsha[:8])
|
||||
visited.append(current_commit.hexsha[:commit_short_len])
|
||||
for parent in current_commit.parents:
|
||||
if parent.hexsha[:8] not in visited:
|
||||
if parent.hexsha[:commit_short_len] not in visited:
|
||||
unvisited.append(parent)
|
||||
|
||||
return visited
|
||||
@ -169,10 +171,10 @@ def get_commit_name(hexsha8):
|
||||
if repo is None:
|
||||
return hexsha8
|
||||
for h in repo.heads:
|
||||
if h.commit.hexsha[:8] == hexsha8:
|
||||
if h.commit.hexsha[:commit_short_len] == hexsha8:
|
||||
return h.name
|
||||
for t in repo.tags:
|
||||
if t.commit.hexsha[:8] == hexsha8:
|
||||
if t.commit.hexsha[:commit_short_len] == hexsha8:
|
||||
return t.name
|
||||
return hexsha8
|
||||
|
||||
@ -183,13 +185,13 @@ def get_commit_hexsha8(name):
|
||||
return None
|
||||
for h in repo.heads:
|
||||
if h.name == name:
|
||||
return h.commit.hexsha[:8]
|
||||
return h.commit.hexsha[:commit_short_len]
|
||||
for t in repo.tags:
|
||||
if t.name == name:
|
||||
return t.commit.hexsha[:8]
|
||||
return t.commit.hexsha[:commit_short_len]
|
||||
for c in repo.iter_commits("--all"):
|
||||
if c.hexsha[:8] == name[:8]:
|
||||
return c.hexsha[:8]
|
||||
if c.hexsha[:commit_short_len] == name[:commit_short_len]:
|
||||
return c.hexsha[:commit_short_len]
|
||||
return None
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ function has_cmd {
|
||||
}
|
||||
|
||||
if has_cmd wget; then
|
||||
cmd="wget -q --show-progress -c -O %s/%s %s"
|
||||
cmd="wget -q -c -O %s/%s %s"
|
||||
elif has_cmd curl; then
|
||||
cmd="curl -C - -f --output-dir %s -o %s -L %s"
|
||||
else
|
||||
|
Loading…
Reference in New Issue
Block a user