From 9ac4b04aa221decd070d0c0c3d4a0b0ce9b6769b Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 00:34:07 +0100 Subject: [PATCH] `tool-call`: add fs_list_files to common, w/ win32 impl for msys2 build --- common/common.cpp | 38 ++++++++++++++++++++++++++++++++++++ common/common.h | 1 + tests/test-chat-template.cpp | 26 ++---------------------- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e247a2eb4..78263da85 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -44,6 +44,7 @@ #include #include #else +#include #include #include #include @@ -777,6 +778,43 @@ bool fs_create_directory_with_parents(const std::string & path) { #endif // _WIN32 } + +std::vector fs_list_files(const std::string & folder, const std::string & ext) { + std::vector files; + // Note: once we can use C++17 this becomes: + // for (const auto & entry : std::filesystem::directory_iterator(folder)) + // if (entry.path().extension() == ext) files.push_back(entry.path().string()); +#ifdef _WIN32 + std::string search_path = folder + "\\*" + ext; + WIN32_FIND_DATA fd; + HANDLE hFind = ::FindFirstFile(search_path.c_str(), &fd); + if (hFind != INVALID_HANDLE_VALUE) { + do { + if (!(fd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { + files.push_back(folder + "\\" + fd.cFileName); + } + } while (::FindNextFile(hFind, &fd)); + ::FindClose(hFind); + } +#else + DIR* dir = opendir(folder.c_str()); + if (dir != nullptr) { + struct dirent* entry; + while ((entry = readdir(dir)) != nullptr) { + if (entry->d_type == DT_REG) { // If it's a regular file + std::string filename = entry->d_name; + if (filename.length() >= ext.length() && + filename.compare(filename.length() - ext.length(), ext.length(), ext) == 0) { + files.push_back(folder + "/" + filename); + } + } + } + closedir(dir); + } +#endif + return files; +} + std::string fs_get_cache_directory() { std::string cache_directory = ""; auto ensure_trailing_slash = [](std::string p) { diff --git a/common/common.h b/common/common.h index 64192a9eb..8681899ce 100644 --- a/common/common.h +++ b/common/common.h @@ -397,6 +397,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); +std::vector fs_list_files(const std::string & path, const std::string & ext); std::string fs_get_cache_directory(); std::string fs_get_cache_file(const std::string & filename); diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 5781ecb71..64fb5b3c4 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -13,7 +13,6 @@ #include #include #include -#include using json = nlohmann::ordered_json; @@ -39,30 +38,9 @@ static void assert_equals(const T & expected, const T & actual) { } static std::vector find_files(const std::string & folder, const std::string & ext) { - auto do_find = [&](const std::string & folder) { - std::vector files; - // Note: once we can use C++17 this becomes: - // for (const auto & entry : std::filesystem::directory_iterator(folder)) - // if (entry.path().extension() == ext) files.push_back(entry.path().string()); - DIR* dir = opendir(folder.c_str()); - if (dir != nullptr) { - struct dirent* entry; - while ((entry = readdir(dir)) != nullptr) { - if (entry->d_type == DT_REG) { // If it's a regular file - std::string filename = entry->d_name; - if (filename.length() >= ext.length() && - filename.compare(filename.length() - ext.length(), ext.length(), ext) == 0) { - files.push_back(folder + "/" + filename); - } - } - } - closedir(dir); - } - return files; - }; - auto files = do_find(folder); + auto files = fs_list_files(folder, ext); if (files.empty()) { - files = do_find("../" + folder); + files = fs_list_files("../" + folder, ext); } return files; }