// llama.cpp/ggml/include/ggml-backend.h

#pragma once
#include "ggml.h"
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend_event * ggml_backend_event_t;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
//
// Backend buffer
//
// buffer type
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
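/*
  Illustrative sketch (not ggml's allocator): a caller that packs several blocks
  into one buffer would typically pad each block to the value reported by
  ggml_backend_buft_get_alignment() and keep the total within
  ggml_backend_buft_get_max_size(). The helpers align_up and packed_size are
  hypothetical, and the power-of-two alignment is an assumption of this sketch.

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical helper (not part of ggml): round `size` up to the next
// multiple of `alignment`, which must be a non-zero power of two.
static size_t align_up(size_t size, size_t alignment) {
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    return (size + alignment - 1) & ~(alignment - 1);
}

// Hypothetical helper: total bytes needed to place `n` blocks of `block_size`
// bytes back to back, each padded to `alignment`. Returns 0 when the total
// would exceed `max_size`.
static size_t packed_size(size_t block_size, size_t n, size_t alignment, size_t max_size) {
    size_t total = 0;
    for (size_t i = 0; i < n; i++) {
        total += align_up(block_size, alignment);
    }
    return total <= max_size ? total : 0;
}
```

  In real code the alignment and size limit would come from the buffer type
  rather than from literal arguments.
*/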
// buffer
enum ggml_backend_buffer_usage {
    GGML_BACKEND_BUFFER_USAGE_ANY = 0,
    GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
    GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
};
GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
//
// Backend
//
GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend);
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// "offset" refers to the offset of the tensor data for setting/getting data
GGML_API GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
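/*
  Illustrative sketch of the offset/size convention used by the set/get/memset
  calls above: `offset` and `size` are byte counts relative to the start of the
  tensor data, and the range [offset, offset + size) has to stay inside it.
  struct demo_tensor and demo_tensor_set are hypothetical host-memory stand-ins,
  not ggml types or functions.

```c
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

// Hypothetical stand-in for a tensor whose data lives in host memory.
struct demo_tensor {
    unsigned char * data;   // start of the tensor data
    size_t          nbytes; // total size of the tensor data in bytes
};

// Copy `size` bytes from `src` into the tensor data, starting at byte `offset`.
// Returns false without writing when the range would leave the tensor data.
static bool demo_tensor_set(struct demo_tensor * t, const void * src, size_t offset, size_t size) {
    // written this way so that offset + size cannot overflow
    if (offset > t->nbytes || size > t->nbytes - offset) {
        return false;
    }
    memcpy(t->data + offset, src, size);
    return true;
}
```

  This sketch reports an out-of-range request through its return value; a real
  backend may handle that case differently.
*/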
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
// asynchronous copy
// the copy is performed after all the currently queued operations in backend_src
// backend_dst will wait for the copy to complete before performing other operations
// automatic fallback to sync copy if async is not supported
GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
// events
GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend);
GGML_API void ggml_backend_event_free (ggml_backend_event_t event);
GGML_API void ggml_backend_event_record (ggml_backend_event_t event);
GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event);
GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event);
//
// CPU backend
//
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
// Create a backend buffer from an existing pointer
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
#ifdef GGML_USE_CPU_HBM
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
//
// Backend registry
//
// The backend registry is a registry of all the available backends, and allows initializing backends in a generic way
GGML_API size_t ggml_backend_reg_get_count(void);
GGML_API size_t ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found
GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional)
GGML_API const char * ggml_backend_reg_get_name(size_t i);
GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific
GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i);
GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size);
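/*
  Illustrative sketch of the two conventions documented above: a
  "backend_name:params" string where params is optional, and a lookup that
  returns SIZE_MAX when the name is not found. demo_split_backend_str and
  demo_find_by_name are hypothetical helpers over a plain name table, not the
  registry implementation, and the strings in any usage are made up.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

// Hypothetical helper: split "name:params" at the first ':'.
// Copies the name into `name` (capacity `name_cap`, always NUL-terminated,
// truncated if too long) and returns a pointer to the params inside
// `backend_str`, or NULL when there is no ':'.
static const char * demo_split_backend_str(const char * backend_str, char * name, size_t name_cap) {
    assert(name_cap > 0);
    const char * colon = strchr(backend_str, ':');
    size_t len = colon ? (size_t)(colon - backend_str) : strlen(backend_str);
    if (len >= name_cap) {
        len = name_cap - 1;
    }
    memcpy(name, backend_str, len);
    name[len] = '\0';
    return colon ? colon + 1 : NULL;
}

// Hypothetical lookup over a plain table of names: returns the index of
// `name`, or SIZE_MAX if it is not present.
static size_t demo_find_by_name(const char * const * names, size_t count, const char * name) {
    for (size_t i = 0; i < count; i++) {
        if (strcmp(names[i], name) == 0) {
            return i;
        }
    }
    return SIZE_MAX;
}
```

  Callers of a SIZE_MAX-style lookup should compare against SIZE_MAX before
  using the result as an index.
*/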
//
// Backend scheduler
//
// The backend scheduler allows for multiple backends to be used together
// Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
// The backends are selected based on:
// - the backend that supports the operation
// - the location of the pre-allocated tensors (e.g. the weights)
/*
Example usage:
// operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
// preferably to run on the same backend as the buffer
ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
// initialize buffers from a max size graph (optional)
reserve_graph = build_graph(sched, max_batch_size);
// manually assign nodes to a backend (optional, should not be needed in most cases)
struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
ggml_backend_sched_reserve(sched, reserve_graph);
// compute
graph = build_graph(sched);
ggml_backend_sched_graph_compute(sched, graph);
// if there are graph inputs:
ggml_backend_sched_reset(sched);
ggml_backend_sched_alloc_graph(sched, graph);
ggml_backend_tensor_set(input_tensor, ...);
ggml_backend_sched_graph_compute(sched, graph);
*/
struct ggml_backend_sched;
typedef struct ggml_backend_sched * ggml_backend_sched_t;
// when ask == true, the scheduler wants to know if the user wants to observe this node
// this allows the scheduler to batch nodes together in order to evaluate them in a single call
//
// when ask == false, the scheduler is passing the node tensor to the user for observation
// if the user returns false, the scheduler will cancel the graph compute
//
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
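/*
Example callback (a minimal sketch of the ask/observe protocol described above; `observe_all` is a
hypothetical name, and printing the tensor name assumes <stdio.h> is included):
    static bool observe_all(struct ggml_tensor * t, bool ask, void * user_data) {
        (void) user_data;
        if (ask) {
            return true; // yes, we want to observe this node
        }
        printf("computed node: %s\n", t->name);
        return true; // returning false here would cancel the graph compute
    }
    ggml_backend_sched_set_eval_callback(sched, observe_all, NULL);
*/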
// Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
// Get the number of splits of the last graph
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
// Reset all assignments and allocators - must be called before changing the node backends
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
//
// Utils
//
struct ggml_backend_graph_copy {
ggml_backend_buffer_t buffer;
struct ggml_context * ctx_allocated;
struct ggml_context * ctx_unallocated;
struct ggml_cgraph * graph;
};
// Copy a graph to a different backend
GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
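/*
Example usage (a minimal sketch; treating a NULL `buffer` as an allocation failure is an assumption
about the implementation, not something this header states):
    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend, graph);
    if (copy.buffer != NULL) {
        // run the copied graph on the target backend
        ggml_backend_graph_compute(backend, copy.graph);
        // release the buffer and both contexts owned by the copy
        ggml_backend_graph_copy_free(copy);
    }
*/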
typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif