mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
llama: string_split fix (#10022)
* llama: Refactor string_split to use template specialization, fixes parsing strings with spaces * llama: Add static_assert in the string_split template to ensure the correct template specialization is used for std::string
This commit is contained in:
parent
2f8bd2b901
commit
d80fb71f8b
@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) {
|
|||||||
}
|
}
|
||||||
params.hf_file = params.model;
|
params.hf_file = params.model;
|
||||||
} else if (params.model.empty()) {
|
} else if (params.model.empty()) {
|
||||||
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
|
params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
|
||||||
}
|
}
|
||||||
} else if (!params.model_url.empty()) {
|
} else if (!params.model_url.empty()) {
|
||||||
if (params.model.empty()) {
|
if (params.model.empty()) {
|
||||||
auto f = string_split(params.model_url, '#').front();
|
auto f = string_split<std::string>(params.model_url, '#').front();
|
||||||
f = string_split(f, '?').front();
|
f = string_split<std::string>(f, '?').front();
|
||||||
params.model = fs_get_cache_file(string_split(f, '/').back());
|
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||||
}
|
}
|
||||||
} else if (params.model.empty()) {
|
} else if (params.model.empty()) {
|
||||||
params.model = DEFAULT_MODEL_PATH;
|
params.model = DEFAULT_MODEL_PATH;
|
||||||
@ -879,7 +879,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
{"--samplers"}, "SAMPLERS",
|
{"--samplers"}, "SAMPLERS",
|
||||||
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
|
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
const auto sampler_names = string_split(value, ';');
|
const auto sampler_names = string_split<std::string>(value, ';');
|
||||||
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
|
@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) {
|
|||||||
return std::string(buf.data(), size);
|
return std::string(buf.data(), size);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> string_split(std::string input, char separator) {
|
|
||||||
std::vector<std::string> parts;
|
|
||||||
size_t separator_pos = input.find(separator);
|
|
||||||
while (separator_pos != std::string::npos) {
|
|
||||||
std::string part = input.substr(0, separator_pos);
|
|
||||||
parts.emplace_back(part);
|
|
||||||
input = input.substr(separator_pos + 1);
|
|
||||||
separator_pos = input.find(separator);
|
|
||||||
}
|
|
||||||
parts.emplace_back(input);
|
|
||||||
return parts;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string string_strip(const std::string & str) {
|
std::string string_strip(const std::string & str) {
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
size_t end = str.size();
|
size_t end = str.size();
|
||||||
|
@ -380,8 +380,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
|
|||||||
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
||||||
std::string string_format(const char * fmt, ...);
|
std::string string_format(const char * fmt, ...);
|
||||||
|
|
||||||
std::vector<std::string> string_split(std::string input, char separator);
|
|
||||||
|
|
||||||
std::string string_strip(const std::string & str);
|
std::string string_strip(const std::string & str);
|
||||||
std::string string_get_sortable_timestamp();
|
std::string string_get_sortable_timestamp();
|
||||||
|
|
||||||
@ -389,6 +387,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
|||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
static std::vector<T> string_split(const std::string & str, char delim) {
|
static std::vector<T> string_split(const std::string & str, char delim) {
|
||||||
|
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
||||||
std::vector<T> values;
|
std::vector<T> values;
|
||||||
std::istringstream str_stream(str);
|
std::istringstream str_stream(str);
|
||||||
std::string token;
|
std::string token;
|
||||||
@ -401,6 +400,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
|||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||||
|
{
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
size_t begin_pos = 0;
|
||||||
|
size_t separator_pos = input.find(separator);
|
||||||
|
while (separator_pos != std::string::npos) {
|
||||||
|
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||||
|
parts.emplace_back(part);
|
||||||
|
begin_pos = separator_pos + 1;
|
||||||
|
separator_pos = input.find(separator, begin_pos);
|
||||||
|
}
|
||||||
|
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||||
void string_process_escapes(std::string & input);
|
void string_process_escapes(std::string & input);
|
||||||
|
|
||||||
|
@ -2380,7 +2380,7 @@ int main(int argc, char ** argv) {
|
|||||||
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
|
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
|
||||||
server_state current_state = state.load();
|
server_state current_state = state.load();
|
||||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||||
auto tmp = string_split(req.path, '.');
|
auto tmp = string_split<std::string>(req.path, '.');
|
||||||
if (req.path == "/" || tmp.back() == "html") {
|
if (req.path == "/" || tmp.back() == "html") {
|
||||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||||
res.status = 503;
|
res.status = 503;
|
||||||
|
Loading…
Reference in New Issue
Block a user