diff --git a/common/minja.hpp b/common/minja.hpp index 77d0ca450..a6e0bfcd4 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -236,7 +236,7 @@ public: if (it == object_->end()) return Value(); return it->second; } - throw std::runtime_error("Value is not an array or object: " + dump()); + return Value(); } void set(const Value& key, const Value& value) { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); @@ -1092,15 +1092,24 @@ public: if (!index) throw std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array"); - - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : target_value.size(); - auto result = Value::array(); - for (auto i = start; i < end; ++i) { - result.push_back(target_value.at(i)); + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); } - return result; } else { auto index_value = index->evaluate(context); if (target_value.is_null()) { @@ -1247,6 +1256,9 @@ public: if (!object) throw std::runtime_error("MethodCallExpr.object is null"); if (!method) throw std::runtime_error("MethodCallExpr.method is null"); auto obj = object->evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } if (obj.is_array()) { if (method->get_name() == "append") { args.expectArgs("append method", {1, 1}, {0, 0}); @@ -2403,6 +2415,10 @@ inline std::shared_ptr Context::builtins() { globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { return args.at("value"); })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("value"); + return items.to_str(); + })); globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { auto & items = args.at("items"); if (!items.is_array()) throw std::runtime_error("object is not iterable"); diff --git a/scripts/update_jinja_goldens.py b/scripts/update_jinja_goldens.py index 3570c5243..a90adf942 100644 --- a/scripts/update_jinja_goldens.py +++ b/scripts/update_jinja_goldens.py @@ -60,6 +60,7 @@ model_ids = [ # Gated models: "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", + "mistralai/Mistral-Nemo-Instruct-2407", "google/gemma-7b-it", "google/gemma-2-2b-it", "mistralai/Mistral-7B-Instruct-v0.2", diff --git a/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-simple.txt b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-simple.txt new file mode 100644 index 000000000..6119fde30 --- /dev/null +++ b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-simple.txt @@ -0,0 +1 @@ +<|startoftext|>[INST]What's your favourite LLM framework?[/INST]llama.cpp!<|endoftext|> \ No newline at end of file diff --git a/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-system.txt b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-system.txt new file mode 100644 index 000000000..6119fde30 --- /dev/null +++ b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-system.txt @@ -0,0 +1 @@ +<|startoftext|>[INST]What's your favourite LLM framework?[/INST]llama.cpp!<|endoftext|> \ No newline at end of file diff --git a/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-tool_use.txt b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-tool_use.txt new file mode 100644 index 000000000..d92e446c0 --- /dev/null +++ b/tests/chat/goldens/mistralai-Mistral-Nemo-Instruct-2407-tool_use.txt @@ -0,0 +1 @@ +<|startoftext|>[INST]Print a hello world message with python.[/INST][TOOL_CALLS][{"arguments": "{\"code\": \"print('Hello, World!')\"}", "name": "ipython", "id": "call_1___"}]<|endoftext|>[TOOL_RESULTS]{"content": {"stdout": "Hello, World!"}, "call_id": "call_1___"}[/TOOL_RESULTS]Anything else?<|endoftext|>[INST]Test a tautology.[/INST][TOOL_CALLS][{"arguments": "{\"condition\":true}", "name": "test", "id": "call_2___"}]<|endoftext|>[TOOL_RESULTS]{"content": true, "call_id": "call_2___"}[/TOOL_RESULTS]Truth is definitely true.<|endoftext|>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "ipython", "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The code to run in the ipython interpreter."}}, "required": ["code"]}}}, {"type": "function", "function": {"name": "brave_search", "description": "Executes a web search with Brave.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to search for."}}, "required": ["query"]}}}, {"type": "function", "function": {"name": "wolfram_alpha", "description": "Executes a query with Wolfram Alpha.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "The query to execute."}}, "required": ["query"]}}}, {"type": "function", "function": {"name": "test", "description": "Runs a test.", "parameters": {"type": "object", "properties": {"condition": {"type": "boolean", "description": "The condition to test."}}, "required": ["condition"]}}}][/AVAILABLE_TOOLS][INST]Check it on the web.[/INST][TOOL_CALLS][{"arguments": "{\"query\": \"what is truth anyway am I right?\"}", "name": "brave_search", "id": "call_3___"}]<|endoftext|>[TOOL_RESULTS]{"content": {"title":"Truth: don't ask the web, ask an LLM instead!","url":"https://en.wikipedia.org/wiki/Truth"}, "call_id": "call_3___"}[/TOOL_RESULTS]I don't need the web to answer you but I did check, as you asked. What now?<|endoftext|> \ No newline at end of file diff --git a/tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja b/tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja new file mode 100644 index 000000000..9c21a3f13 --- /dev/null +++ b/tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja @@ -0,0 +1,87 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{#- This block checks for alternating user/assistant messages, skipping tool calling messages #} +{%- set ns = namespace() %} +{%- set ns.index = 0 %} +{%- for message in loop_messages %} + {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %} + {%- if (message["role"] == "user") != (ns.index % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} + {%- set ns.index = ns.index + 1 %} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS][" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST]" + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST]" + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif (message.tool_calls is defined and message.tool_calls is not none) %} + {{- "[TOOL_CALLS][" }} + {%- for tool_call in message.tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- ', "id": "' + tool_call.id + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- message["content"] + eos_token}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS]{"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/tests/test-minja.cpp b/tests/test-minja.cpp index 2a8e92848..d0bc342b1 100644 --- a/tests/test-minja.cpp +++ b/tests/test-minja.cpp @@ -141,6 +141,9 @@ int main() { lstrip_trim_blocks, " 1" ); + test_render(R"({{ "abcd"[1:-1] }})", {}, {}, "bc"); + test_render(R"({{ [0, 1, 2, 3][1:-1] }})", {}, {}, "[1, 2]"); + test_render(R"({{ "123456789" | length }})", {}, {}, "9"); test_render(R"( {{- 'a' -}}{{ ' ' }}{{- 'b' -}} )", {}, {}, "a b"); test_render(R"( {%- if True %}{%- endif %}{{ ' ' }}{%- for x in [] %}foo{% endfor %}end)", {}, {}, " end"); test_render(R"({% set ns = namespace(is_first=false, nottool=false, and_or=true, delme='') %}{{ ns.is_first }})", {}, {}, "False");