mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
Merge branch 'master' into gg/flash-attn
This commit is contained in:
commit
e307882c34
@ -19,7 +19,12 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(GIT_INDEX "${GIT_DIR}/index")
|
if(EXISTS "${GIT_DIR}/index")
|
||||||
|
set(GIT_INDEX "${GIT_DIR}/index")
|
||||||
|
else()
|
||||||
|
message(WARNING "Git index not found in git repository.")
|
||||||
|
set(GIT_INDEX "")
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.")
|
message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.")
|
||||||
set(GIT_INDEX "")
|
set(GIT_INDEX "")
|
||||||
|
@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_sequences = std::stoi(argv[i]);
|
params.n_sequences = std::stoi(argv[i]);
|
||||||
} else if (arg == "--p-accept" || arg == "-pa") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params.p_accept = std::stof(argv[i]);
|
|
||||||
} else if (arg == "--p-split" || arg == "-ps") {
|
} else if (arg == "--p-split" || arg == "-ps") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||||
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
||||||
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
||||||
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
|
|
||||||
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
|
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
|
||||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||||
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
||||||
|
@ -43,7 +43,7 @@ extern char const *LLAMA_BUILD_TARGET;
|
|||||||
int32_t get_num_physical_cores();
|
int32_t get_num_physical_cores();
|
||||||
|
|
||||||
struct gpt_params {
|
struct gpt_params {
|
||||||
uint32_t seed = -1; // RNG seed
|
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
|
||||||
|
|
||||||
int32_t n_threads = get_num_physical_cores();
|
int32_t n_threads = get_num_physical_cores();
|
||||||
int32_t n_threads_draft = -1;
|
int32_t n_threads_draft = -1;
|
||||||
@ -53,11 +53,10 @@ struct gpt_params {
|
|||||||
int32_t n_ctx = 512; // context size
|
int32_t n_ctx = 512; // context size
|
||||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
|
int32_t n_draft = 5; // number of tokens to draft during speculative decoding
|
||||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||||
int32_t n_parallel = 1; // number of parallel sequences to decode
|
int32_t n_parallel = 1; // number of parallel sequences to decode
|
||||||
int32_t n_sequences = 1; // number of sequences to decode
|
int32_t n_sequences = 1; // number of sequences to decode
|
||||||
float p_accept = 0.5f; // speculative decoding accept probability
|
|
||||||
float p_split = 0.1f; // speculative decoding split probability
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
|
@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static llama_token_data_array llama_sample_probability_distribution_impl(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
struct llama_context * ctx_cfg,
|
||||||
|
const int idx) {
|
||||||
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||||
|
const float penalty_repeat = params.penalty_repeat;
|
||||||
|
const float penalty_freq = params.penalty_freq;
|
||||||
|
const float penalty_present = params.penalty_present;
|
||||||
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
|
||||||
|
auto & prev = ctx_sampling->prev;
|
||||||
|
auto & cur = ctx_sampling->cur;
|
||||||
|
|
||||||
|
// Get a pointer to the logits
|
||||||
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
|
// Declare original_logits at the beginning of the function scope
|
||||||
|
std::vector<float> original_logits;
|
||||||
|
|
||||||
|
// apply params.logit_bias map
|
||||||
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
|
logits[it->first] += it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx_cfg) {
|
||||||
|
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
|
||||||
|
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur.clear();
|
||||||
|
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
|
// apply penalties
|
||||||
|
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
|
||||||
|
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
|
||||||
|
if (penalty_tokens_used_size) {
|
||||||
|
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
|
||||||
|
|
||||||
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
|
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
|
||||||
|
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
|
||||||
|
|
||||||
|
if (!penalize_nl) {
|
||||||
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
|
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
|
||||||
|
cur_p.data[idx].logit = nl_logit;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply grammar checks
|
||||||
|
if (ctx_sampling->grammar != NULL) {
|
||||||
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sample_softmax(ctx_main, &cur_p);
|
||||||
|
return cur_p;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
|
|||||||
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token_data_array llama_sampling_probability_distribution(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
struct llama_context * ctx_cfg,
|
||||||
|
const int idx) {
|
||||||
|
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
|
|||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
int idx = 0);
|
int idx = 0);
|
||||||
|
|
||||||
|
// returns the probability that token of given id will be sampled
|
||||||
|
llama_token_data_array llama_sampling_probability_distribution(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
struct llama_context * ctx_cfg,
|
||||||
|
int idx = 0);
|
||||||
|
|
||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
|
@ -511,6 +511,14 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
std::vector<llama_token> embd_guidance;
|
std::vector<llama_token> embd_guidance;
|
||||||
|
|
||||||
|
// tokenized antiprompts
|
||||||
|
std::vector<std::vector<llama_token>> antiprompt_ids;
|
||||||
|
|
||||||
|
antiprompt_ids.reserve(params.antiprompt.size());
|
||||||
|
for (const std::string & antiprompt : params.antiprompt) {
|
||||||
|
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
@ -769,6 +777,18 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check for reverse prompt using special tokens
|
||||||
|
llama_token last_token = llama_sampling_last(ctx_sampling);
|
||||||
|
for (std::vector<llama_token> ids : antiprompt_ids) {
|
||||||
|
if (ids.size() == 1 && last_token == ids[0]) {
|
||||||
|
if (params.interactive) {
|
||||||
|
is_interacting = true;
|
||||||
|
}
|
||||||
|
is_antiprompt = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (is_antiprompt) {
|
if (is_antiprompt) {
|
||||||
LOG("found antiprompt: %s\n", last_output.c_str());
|
LOG("found antiprompt: %s\n", last_output.c_str());
|
||||||
}
|
}
|
||||||
|
@ -413,7 +413,7 @@ struct llama_server_context
|
|||||||
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||||
sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template
|
sparams.chat_template = "chatml";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,3 +6,4 @@ More info:
|
|||||||
|
|
||||||
- https://github.com/ggerganov/llama.cpp/pull/2926
|
- https://github.com/ggerganov/llama.cpp/pull/2926
|
||||||
- https://github.com/ggerganov/llama.cpp/pull/3624
|
- https://github.com/ggerganov/llama.cpp/pull/3624
|
||||||
|
- https://github.com/ggerganov/llama.cpp/pull/5625
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
|
||||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||||
@ -18,6 +19,7 @@ struct seq_draft {
|
|||||||
std::vector<int> i_batch_tgt;
|
std::vector<int> i_batch_tgt;
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
|
std::vector<std::vector<llama_token_data>> dists;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling;
|
struct llama_sampling_context * ctx_sampling;
|
||||||
};
|
};
|
||||||
@ -37,12 +39,15 @@ int main(int argc, char ** argv) {
|
|||||||
// max number of parallel drafting sequences (i.e. tree branches)
|
// max number of parallel drafting sequences (i.e. tree branches)
|
||||||
const int n_seq_dft = params.n_parallel;
|
const int n_seq_dft = params.n_parallel;
|
||||||
|
|
||||||
// probability threshold for accepting a token from the draft model
|
|
||||||
const float p_accept = params.p_accept;
|
|
||||||
|
|
||||||
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||||
const float p_split = params.p_split;
|
const float p_split = params.p_split;
|
||||||
|
|
||||||
|
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||||
|
params.seed = time(NULL);
|
||||||
|
}
|
||||||
|
std::default_random_engine rng(params.seed);
|
||||||
|
std::uniform_real_distribution<> u_dist;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("speculative", "log"));
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
@ -166,7 +171,9 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||||
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
if (params.sparams.temp == 0) {
|
||||||
|
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
||||||
|
}
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
@ -182,12 +189,15 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[0].i_batch_tgt[0] = 0;
|
drafts[0].i_batch_tgt[0] = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
std::set<int> active_seqs = {};
|
||||||
|
|
||||||
// print current draft sequences
|
// print current draft sequences
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
if (!drafts[s].active) {
|
if (!drafts[s].active) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
active_seqs.insert(s);
|
||||||
const auto & tokens = drafts[s].tokens;
|
const auto & tokens = drafts[s].tokens;
|
||||||
|
|
||||||
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
||||||
@ -196,48 +206,156 @@ int main(int argc, char ** argv) {
|
|||||||
int i_dft = 0;
|
int i_dft = 0;
|
||||||
int s_keep = 0;
|
int s_keep = 0;
|
||||||
|
|
||||||
|
llama_token token_id;
|
||||||
|
std::string token_str;
|
||||||
|
|
||||||
|
// loop until we fail to accept a drafted token or we run out of drafted tokens
|
||||||
while (true) {
|
while (true) {
|
||||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
||||||
|
|
||||||
// sample from the target model
|
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
|
|
||||||
|
|
||||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
|
||||||
|
|
||||||
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
|
||||||
|
|
||||||
if (!params.use_color) {
|
|
||||||
printf("%s", token_str.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (id == llama_token_eos(model_tgt)) {
|
|
||||||
has_eos = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
++n_predict;
|
|
||||||
|
|
||||||
// check if the target token matches any of the drafts
|
// check if the target token matches any of the drafts
|
||||||
|
// for stochastic sampling, attempt to match the token with the drafted tokens
|
||||||
{
|
{
|
||||||
bool matches = false;
|
bool accept = false;
|
||||||
|
if (params.sparams.temp > 0) {
|
||||||
|
// stochastic verification
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
if (!drafts[s].active) {
|
float p_tgt = 0, p_dft = 0;
|
||||||
continue;
|
|
||||||
|
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||||
|
|
||||||
|
while (active_seqs.size() > 0) {
|
||||||
|
// randomly select a sequence to verify from active sequences
|
||||||
|
std::uniform_int_distribution<u_int> u_int_dist(0, active_seqs.size() - 1);
|
||||||
|
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
|
||||||
|
if (i_dft >= (int) drafts[s].tokens.size()) {
|
||||||
|
drafts[s].active = false;
|
||||||
|
active_seqs.erase(s);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (accept) {
|
||||||
|
// if we already accepted a token, we can skip the rest
|
||||||
|
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
||||||
|
drafts[s].active = false;
|
||||||
|
active_seqs.erase(s);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
||||||
|
float r = u_dist(rng);
|
||||||
|
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
|
||||||
|
// acquire the token probabilities assigned by the draft and target models
|
||||||
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||||
|
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
||||||
|
p_tgt = dist_tgt.data[i].p;
|
||||||
|
}
|
||||||
|
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
||||||
|
p_dft = dist_dft.data[i].p;
|
||||||
|
}
|
||||||
|
if (p_tgt && p_dft) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
||||||
|
if (r <= p_tgt / p_dft) {
|
||||||
|
s_keep = s;
|
||||||
|
accept = true;
|
||||||
|
token_id = drafts[s].tokens[i_dft];
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
||||||
|
drafts[s].active = false;
|
||||||
|
|
||||||
|
// calculate residual probability
|
||||||
|
GGML_ASSERT(dist_tgt.sorted);
|
||||||
|
GGML_ASSERT(dist_dft.sorted);
|
||||||
|
float sum_probs = 0.0f;
|
||||||
|
|
||||||
|
// sort dist by id
|
||||||
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.id < b.id;
|
||||||
|
});
|
||||||
|
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.id < b.id;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||||
|
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
||||||
|
sum_probs += dist_tgt.data[i].p;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||||
|
dist_tgt.data[i].p /= sum_probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort dist_tgt by p desc
|
||||||
|
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||||
|
return a.p > b.p;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
active_seqs.erase(s);
|
||||||
|
for(int i = 0; i < n_seq_dft; i++) {
|
||||||
|
if (i == s) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
||||||
|
// synchronize active status for sequences with the same drafted token
|
||||||
|
drafts[i].active = drafts[i].active && accept;
|
||||||
|
if (!drafts[i].active) {
|
||||||
|
active_seqs.erase(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
|
if (!accept) {
|
||||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
|
// all drafted tokens were rejected
|
||||||
|
// sample from the target model
|
||||||
|
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
||||||
|
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
}
|
||||||
|
|
||||||
s_keep = s;
|
} else {
|
||||||
matches = true;
|
// greedy verification
|
||||||
} else {
|
|
||||||
drafts[s].active = false;
|
// sample from the target model
|
||||||
|
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||||
|
|
||||||
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
|
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||||
|
|
||||||
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
|
if (!drafts[s].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
||||||
|
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||||
|
|
||||||
|
s_keep = s;
|
||||||
|
accept = true;
|
||||||
|
} else {
|
||||||
|
drafts[s].active = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (matches) {
|
if (token_id == llama_token_eos(model_tgt)) {
|
||||||
|
has_eos = true;
|
||||||
|
}
|
||||||
|
++n_predict;
|
||||||
|
|
||||||
|
if (accept) {
|
||||||
++n_accept;
|
++n_accept;
|
||||||
++n_past_tgt;
|
++n_past_tgt;
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
@ -245,17 +363,21 @@ int main(int argc, char ** argv) {
|
|||||||
if (params.use_color) {
|
if (params.use_color) {
|
||||||
// Color token according to its origin sequence
|
// Color token according to its origin sequence
|
||||||
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
||||||
fflush(stdout);
|
} else {
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
}
|
}
|
||||||
|
fflush(stdout);
|
||||||
continue;
|
continue;
|
||||||
|
} else {
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params.use_color) {
|
}
|
||||||
printf("%s", token_str.c_str());
|
|
||||||
}
|
|
||||||
fflush(stdout);
|
|
||||||
|
|
||||||
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
|
{
|
||||||
|
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
||||||
|
|
||||||
// TODO: simplify
|
// TODO: simplify
|
||||||
{
|
{
|
||||||
@ -275,21 +397,21 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[s].active = false;
|
drafts[s].active = false;
|
||||||
drafts[s].tokens.clear();
|
drafts[s].tokens.clear();
|
||||||
drafts[s].i_batch_tgt.clear();
|
drafts[s].i_batch_tgt.clear();
|
||||||
|
drafts[s].dists.clear();
|
||||||
}
|
}
|
||||||
// note: will be erased after the speculation phase
|
// note: will be erased after the speculation phase
|
||||||
drafts[0].tokens.push_back(id);
|
drafts[0].tokens.push_back(token_id);
|
||||||
|
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
||||||
drafts[0].i_batch_tgt.push_back(0);
|
drafts[0].i_batch_tgt.push_back(0);
|
||||||
|
|
||||||
llama_batch_clear(batch_dft);
|
llama_batch_clear(batch_dft);
|
||||||
llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true);
|
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||||
llama_decode (ctx_dft, batch_dft);
|
llama_decode(ctx_dft, batch_dft);
|
||||||
|
|
||||||
++n_past_dft;
|
++n_past_dft;
|
||||||
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_predict > params.n_predict || has_eos) {
|
if (n_predict > params.n_predict || has_eos) {
|
||||||
@ -334,12 +456,6 @@ int main(int argc, char ** argv) {
|
|||||||
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cur_p[0].p < p_accept) {
|
|
||||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
|
|
||||||
drafts[s].drafting = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int> sa(1, s);
|
std::vector<int> sa(1, s);
|
||||||
|
|
||||||
// attempt to split the branch if the probability is high enough
|
// attempt to split the branch if the probability is high enough
|
||||||
@ -367,6 +483,7 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[n_seq_cur].skip = true;
|
drafts[n_seq_cur].skip = true;
|
||||||
|
|
||||||
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
||||||
|
drafts[n_seq_cur].dists = drafts[s].dists;
|
||||||
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
||||||
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
||||||
|
|
||||||
@ -389,6 +506,8 @@ int main(int argc, char ** argv) {
|
|||||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||||
|
|
||||||
drafts[s].tokens.push_back(id);
|
drafts[s].tokens.push_back(id);
|
||||||
|
// save cur_p.data into drafts[s].dists
|
||||||
|
drafts[s].dists.push_back(cur_p);
|
||||||
|
|
||||||
// add unique drafted tokens to the target batch
|
// add unique drafted tokens to the target batch
|
||||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||||
@ -440,6 +559,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
||||||
|
drafts[s].dists.erase(drafts[s].dists.begin());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
228
ggml-cuda.cu
228
ggml-cuda.cu
@ -617,6 +617,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
|
|||||||
#define CUDA_UPSCALE_BLOCK_SIZE 256
|
#define CUDA_UPSCALE_BLOCK_SIZE 256
|
||||||
#define CUDA_CONCAT_BLOCK_SIZE 256
|
#define CUDA_CONCAT_BLOCK_SIZE 256
|
||||||
#define CUDA_PAD_BLOCK_SIZE 256
|
#define CUDA_PAD_BLOCK_SIZE 256
|
||||||
|
#define CUDA_ARANGE_BLOCK_SIZE 256
|
||||||
|
#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
||||||
#define CUDA_ACC_BLOCK_SIZE 256
|
#define CUDA_ACC_BLOCK_SIZE 256
|
||||||
#define CUDA_IM2COL_BLOCK_SIZE 256
|
#define CUDA_IM2COL_BLOCK_SIZE 256
|
||||||
#define CUDA_POOL2D_BLOCK_SIZE 256
|
#define CUDA_POOL2D_BLOCK_SIZE 256
|
||||||
@ -1014,17 +1016,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
|
|||||||
nidx +
|
nidx +
|
||||||
blockIdx.y * ne0 +
|
blockIdx.y * ne0 +
|
||||||
blockIdx.z * ne0 * gridDim.y;
|
blockIdx.z * ne0 * gridDim.y;
|
||||||
dst[offset_dst] = x[offset_src];
|
dst[offset_dst] = x[offset_src];
|
||||||
} else {
|
} else {
|
||||||
int offset_src =
|
int offset_src =
|
||||||
nidx +
|
nidx +
|
||||||
blockIdx.y * ne0 +
|
blockIdx.y * ne0 +
|
||||||
(blockIdx.z - ne02) * ne0 * gridDim.y;
|
(blockIdx.z - ne02) * ne0 * gridDim.y;
|
||||||
dst[offset_dst] = y[offset_src];
|
dst[offset_dst] = y[offset_src];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
|
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
|
||||||
|
// blockIdx.z: idx of ne02*ne03
|
||||||
|
// blockIdx.y: idx of ne01*scale_factor, aka ne1
|
||||||
|
// blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
|
||||||
|
// ne00xne01: ne00 * ne01
|
||||||
int ne0 = ne00 * scale_factor;
|
int ne0 = ne00 * scale_factor;
|
||||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (nidx >= ne0) {
|
if (nidx >= ne0) {
|
||||||
@ -1036,7 +1042,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
|
|||||||
int offset_src =
|
int offset_src =
|
||||||
i00 +
|
i00 +
|
||||||
i01 * ne00 +
|
i01 * ne00 +
|
||||||
blockIdx.z * nb02;
|
blockIdx.z * ne00xne01;
|
||||||
int offset_dst =
|
int offset_dst =
|
||||||
nidx +
|
nidx +
|
||||||
blockIdx.y * ne0 +
|
blockIdx.y * ne0 +
|
||||||
@ -1044,7 +1050,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
|
|||||||
dst[offset_dst] = x[offset_src];
|
dst[offset_dst] = x[offset_src];
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
|
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
|
||||||
|
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
|
||||||
|
// blockIdx.y: idx of ne1
|
||||||
|
// blockIDx.x: idx of ne0 / BLOCK_SIZE
|
||||||
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (nidx >= ne0) {
|
if (nidx >= ne0) {
|
||||||
return;
|
return;
|
||||||
@ -1055,19 +1064,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
|
|||||||
nidx +
|
nidx +
|
||||||
blockIdx.y * ne0 +
|
blockIdx.y * ne0 +
|
||||||
blockIdx.z * ne0 * gridDim.y;
|
blockIdx.z * ne0 * gridDim.y;
|
||||||
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
|
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
|
||||||
int offset_src =
|
int offset_src =
|
||||||
nidx +
|
nidx +
|
||||||
blockIdx.y * ne00 +
|
blockIdx.y * ne00 +
|
||||||
blockIdx.z * ne00 * ne01;
|
blockIdx.z * ne00 * ne01;
|
||||||
dst[offset_dst] = x[offset_src];
|
dst[offset_dst] = x[offset_src];
|
||||||
} else {
|
} else {
|
||||||
dst[offset_dst] = 0.0f;
|
dst[offset_dst] = 0.0f;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
|
||||||
|
// blockIDx.x: idx of ne0 / BLOCK_SIZE
|
||||||
|
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
if (nidx >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dst[nidx] = start + step * nidx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
|
||||||
|
// blockIDx.y: idx of timesteps->ne[0]
|
||||||
|
// blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
|
||||||
|
int i = blockIdx.y;
|
||||||
|
int j = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
float * embed_data = (float *)((char *)dst + i*nb1);
|
||||||
|
|
||||||
|
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
|
||||||
|
embed_data[dim] = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
int half = dim / 2;
|
||||||
|
if (j >= half) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float timestep = timesteps[i];
|
||||||
|
float freq = (float)expf(-logf(max_period) * j / half);
|
||||||
|
float arg = timestep * freq;
|
||||||
|
embed_data[j] = cosf(arg);
|
||||||
|
embed_data[j + half] = sinf(arg);
|
||||||
|
}
|
||||||
|
|
||||||
template <int block_size>
|
template <int block_size>
|
||||||
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
|
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
|
||||||
|
// blockIdx.x: num_groups idx
|
||||||
|
// threadIdx.x: block_size idx
|
||||||
int start = blockIdx.x * group_size;
|
int start = blockIdx.x * group_size;
|
||||||
int end = start + group_size;
|
int end = start + group_size;
|
||||||
|
|
||||||
@ -6473,7 +6516,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
|||||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||||
const int nb12, const int nb13) {
|
const int nb12, const int nb13) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= ne) {
|
if (i >= ne) {
|
||||||
return;
|
return;
|
||||||
@ -6481,17 +6524,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
|
|||||||
|
|
||||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||||
const int i03 = i/(ne00 * ne01 * ne02);
|
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||||
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
||||||
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
||||||
|
|
||||||
const int i13 = i/(ne10 * ne11 * ne12);
|
const int64_t i13 = i/(ne10 * ne11 * ne12);
|
||||||
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
||||||
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
||||||
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||||
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
|
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
|
||||||
|
|
||||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||||
}
|
}
|
||||||
@ -6929,6 +6972,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
|
|||||||
// find the sum of exps in the block
|
// find the sum of exps in the block
|
||||||
tmp = warp_reduce_sum(tmp);
|
tmp = warp_reduce_sum(tmp);
|
||||||
if (block_size > WARP_SIZE) {
|
if (block_size > WARP_SIZE) {
|
||||||
|
__syncthreads();
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
buf_iw[lane_id] = 0.0f;
|
buf_iw[lane_id] = 0.0f;
|
||||||
}
|
}
|
||||||
@ -6980,23 +7024,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void im2col_kernel(
|
static __global__ void im2col_kernel(
|
||||||
const float * x, T * dst, int batch_offset,
|
const float * x, T * dst, int64_t batch_offset,
|
||||||
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
|
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
|
||||||
int s0, int s1, int p0, int p1, int d0, int d1) {
|
int s0, int s1, int p0, int p1, int d0, int d1) {
|
||||||
const int i = threadIdx.x + blockIdx.x * blockDim.x;
|
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
if (i >= pelements) {
|
if (i >= pelements) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ksize = OW * (KH > 1 ? KW : 1);
|
const int64_t ksize = OW * (KH > 1 ? KW : 1);
|
||||||
const int kx = i / ksize;
|
const int64_t kx = i / ksize;
|
||||||
const int kd = kx * ksize;
|
const int64_t kd = kx * ksize;
|
||||||
const int ky = (i - kd) / OW;
|
const int64_t ky = (i - kd) / OW;
|
||||||
const int ix = i % OW;
|
const int64_t ix = i % OW;
|
||||||
|
|
||||||
const int oh = blockIdx.y;
|
const int64_t oh = blockIdx.y;
|
||||||
const int batch = blockIdx.z / IC;
|
const int64_t batch = blockIdx.z / IC;
|
||||||
const int ic = blockIdx.z % IC;
|
const int64_t ic = blockIdx.z % IC;
|
||||||
|
|
||||||
const int64_t iiw = ix * s0 + kx * d0 - p0;
|
const int64_t iiw = ix * s0 + kx * d0 - p0;
|
||||||
const int64_t iih = oh * s1 + ky * d1 - p1;
|
const int64_t iih = oh * s1 + ky * d1 - p1;
|
||||||
@ -7852,19 +7896,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
|
|||||||
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
|
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
|
||||||
|
const int scale_factor, cudaStream_t stream) {
|
||||||
int ne0 = (ne00 * scale_factor);
|
int ne0 = (ne00 * scale_factor);
|
||||||
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
|
||||||
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
|
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
|
||||||
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
|
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void pad_f32_cuda(const float * x, float * dst,
|
static void pad_f32_cuda(const float * x, float * dst,
|
||||||
const int ne00, const int ne01, const int ne02,
|
const int ne00, const int ne01, const int ne02, const int ne03,
|
||||||
const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
|
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
|
||||||
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
|
||||||
dim3 gridDim(num_blocks, ne1, ne2);
|
dim3 gridDim(num_blocks, ne1, ne2*ne3);
|
||||||
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
|
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
|
||||||
|
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
|
||||||
|
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
|
||||||
|
const int dim, const int max_period, cudaStream_t stream) {
|
||||||
|
int half_ceil = (dim + 1) / 2;
|
||||||
|
int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
||||||
|
dim3 gridDim(num_blocks, ne00, 1);
|
||||||
|
timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||||
@ -8997,8 +9055,8 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void im2col_cuda(const float* x, T* dst,
|
static void im2col_cuda(const float* x, T* dst,
|
||||||
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
|
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
|
||||||
int batch, int batch_offset, int offset_delta,
|
int64_t batch, int64_t batch_offset, int64_t offset_delta,
|
||||||
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
||||||
const int parallel_elements = OW * KW * KH;
|
const int parallel_elements = OW * KW * KH;
|
||||||
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
|
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
|
||||||
@ -9684,7 +9742,7 @@ static void ggml_cuda_op_group_norm(
|
|||||||
|
|
||||||
int num_groups = dst->op_params[0];
|
int num_groups = dst->op_params[0];
|
||||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||||
group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
|
group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
@ -9717,7 +9775,7 @@ static void ggml_cuda_op_upscale(
|
|||||||
|
|
||||||
const int scale_factor = dst->op_params[0];
|
const int scale_factor = dst->op_params[0];
|
||||||
|
|
||||||
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
|
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
@ -9733,8 +9791,49 @@ static void ggml_cuda_op_pad(
|
|||||||
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||||
|
|
||||||
pad_f32_cuda(src0_dd, dst_dd,
|
pad_f32_cuda(src0_dd, dst_dd,
|
||||||
src0->ne[0], src0->ne[1], src0->ne[2],
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||||
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
|
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
|
||||||
|
|
||||||
|
(void) src1;
|
||||||
|
(void) dst;
|
||||||
|
(void) src1_dd;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cuda_op_arange(
|
||||||
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
float start;
|
||||||
|
float stop;
|
||||||
|
float step;
|
||||||
|
memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
|
||||||
|
memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
|
||||||
|
memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
int64_t steps = (int64_t)ceil((stop - start) / step);
|
||||||
|
GGML_ASSERT(ggml_nelements(dst) == steps);
|
||||||
|
|
||||||
|
arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
|
||||||
|
|
||||||
|
(void) src0;
|
||||||
|
(void) src1;
|
||||||
|
(void) src0_dd;
|
||||||
|
(void) src1_dd;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cuda_op_timestep_embedding(
|
||||||
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
|
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const int dim = dst->op_params[0];
|
||||||
|
const int max_period = dst->op_params[1];
|
||||||
|
|
||||||
|
timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
@ -11019,6 +11118,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
|
|||||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||||
|
|
||||||
|
const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
|
||||||
|
|
||||||
|
// dd = data device
|
||||||
|
float * src0_ddf = nullptr;
|
||||||
|
float * src1_ddf = nullptr;
|
||||||
|
float * dst_ddf = nullptr;
|
||||||
|
|
||||||
|
cuda_pool_alloc<float> dst_f;
|
||||||
|
|
||||||
|
ggml_cuda_set_device(g_main_device);
|
||||||
|
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||||
|
|
||||||
|
if (dst_on_device) {
|
||||||
|
dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||||
|
} else {
|
||||||
|
dst_ddf = dst_f.alloc(ggml_nelements(dst));
|
||||||
|
}
|
||||||
|
|
||||||
|
// do the computation
|
||||||
|
ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
// copy dst to host if necessary
|
||||||
|
if (!dst_on_device) {
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dst->backend == GGML_BACKEND_TYPE_CPU) {
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
|
||||||
}
|
}
|
||||||
@ -12124,6 +12262,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
|
|||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
func = ggml_cuda_pad;
|
func = ggml_cuda_pad;
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
func = ggml_cuda_arange;
|
||||||
|
break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
func = ggml_cuda_timestep_embedding;
|
||||||
|
break;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
func = ggml_cuda_leaky_relu;
|
func = ggml_cuda_leaky_relu;
|
||||||
break;
|
break;
|
||||||
@ -13029,6 +13173,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
return true;
|
return true;
|
||||||
|
360
ggml-metal.m
360
ggml-metal.m
@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
||||||
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
||||||
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
||||||
@ -444,161 +446,163 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
|
|
||||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||||
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
[metal_library release];
|
[metal_library release];
|
||||||
@ -712,6 +716,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||||||
return false;
|
return false;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
@ -1107,7 +1113,8 @@ static bool ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
const float scale = *(const float *) dst->op_params;
|
float scale;
|
||||||
|
memcpy(&scale, dst->op_params, sizeof(scale));
|
||||||
|
|
||||||
int64_t n = ggml_nelements(dst);
|
int64_t n = ggml_nelements(dst);
|
||||||
|
|
||||||
@ -1268,11 +1275,15 @@ static bool ggml_metal_graph_compute(
|
|||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float scale = ((float *) dst->op_params)[0];
|
float scale;
|
||||||
const float max_bias = ((float *) dst->op_params)[1];
|
float max_bias;
|
||||||
|
|
||||||
|
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||||
|
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||||
|
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
const int64_t nrows_y = src0->ne[1];
|
const int64_t nrows_y = src0->ne[1];
|
||||||
|
|
||||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||||
|
|
||||||
@ -2104,6 +2115,7 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||||
|
|
||||||
float max_bias;
|
float max_bias;
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
@ -2318,6 +2330,50 @@ static bool ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
float start;
|
||||||
|
float step;
|
||||||
|
|
||||||
|
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
||||||
|
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
||||||
|
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
||||||
|
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
||||||
|
|
||||||
|
const int nth = MIN(1024, ne0);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const int dim = dst->op_params[0];
|
||||||
|
const int max_period = dst->op_params[1];
|
||||||
|
|
||||||
|
const int half = dim / 2;
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
||||||
|
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
||||||
|
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
||||||
|
|
||||||
|
const int nth = MIN(1024, half);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_arange_f32(
|
||||||
|
device char * dst,
|
||||||
|
constant int64_t & ne0,
|
||||||
|
constant float & start,
|
||||||
|
constant float & step,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
device float * dst_ptr = (device float *) dst;
|
||||||
|
|
||||||
|
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||||
|
dst_ptr[i0] = start + step * i0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_timestep_embedding_f32(
|
||||||
|
device const char * src0,
|
||||||
|
device char * dst,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
constant int & dim,
|
||||||
|
constant int & max_period,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
|
||||||
|
int i = tgpig.x;
|
||||||
|
device float * embed_data = (device float *)(dst + i*nb1);
|
||||||
|
|
||||||
|
int half_ = dim / 2;
|
||||||
|
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
||||||
|
float timestep = ((device float *)src0)[i];
|
||||||
|
float freq = (float)exp(-log((float)max_period) * j / half_);
|
||||||
|
float arg = timestep * freq;
|
||||||
|
embed_data[j ] = cos(arg);
|
||||||
|
embed_data[j + half_] = sin(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dim % 2 != 0 && tpitg.x == 0) {
|
||||||
|
embed_data[dim] = 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// bitonic sort implementation following the CUDA kernels as reference
|
// bitonic sort implementation following the CUDA kernels as reference
|
||||||
typedef void (argsort_t)(
|
typedef void (argsort_t)(
|
||||||
device const float * x,
|
device const float * x,
|
||||||
|
207
ggml.c
207
ggml.c
@ -1882,6 +1882,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"POOL_2D",
|
"POOL_2D",
|
||||||
"UPSCALE",
|
"UPSCALE",
|
||||||
"PAD",
|
"PAD",
|
||||||
|
"ARANGE",
|
||||||
|
"TIMESTEP_EMBEDDING",
|
||||||
"ARGSORT",
|
"ARGSORT",
|
||||||
"LEAKY_RELU",
|
"LEAKY_RELU",
|
||||||
|
|
||||||
@ -1911,7 +1913,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
@ -1969,6 +1971,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"pool_2d(x)",
|
"pool_2d(x)",
|
||||||
"upscale(x)",
|
"upscale(x)",
|
||||||
"pad(x)",
|
"pad(x)",
|
||||||
|
"arange(start, stop, step)",
|
||||||
|
"timestep_embedding(timesteps, dim, max_period)",
|
||||||
"argsort(x)",
|
"argsort(x)",
|
||||||
"leaky_relu(x)",
|
"leaky_relu(x)",
|
||||||
|
|
||||||
@ -1998,7 +2002,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");
|
static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
@ -2957,11 +2961,21 @@ static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_
|
|||||||
return ((const int32_t *)(tensor->op_params))[i];
|
return ((const int32_t *)(tensor->op_params))[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
|
||||||
|
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
|
||||||
|
return ((const float *)(tensor->op_params))[i];
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
|
static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
|
||||||
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
|
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
|
||||||
((int32_t *)(tensor->op_params))[i] = value;
|
((int32_t *)(tensor->op_params))[i] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
|
||||||
|
assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
|
||||||
|
((float *)(tensor->op_params))[i] = value;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
||||||
memset(tensor->data, 0, ggml_nbytes(tensor));
|
memset(tensor->data, 0, ggml_nbytes(tensor));
|
||||||
return tensor;
|
return tensor;
|
||||||
@ -5963,6 +5977,55 @@ struct ggml_tensor * ggml_upscale(
|
|||||||
return ggml_upscale_impl(ctx, a, scale_factor);
|
return ggml_upscale_impl(ctx, a, scale_factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_arange(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
float start,
|
||||||
|
float stop,
|
||||||
|
float step) {
|
||||||
|
|
||||||
|
GGML_ASSERT(stop > start);
|
||||||
|
|
||||||
|
const int64_t steps = (int64_t) ceilf((stop - start) / step);
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
|
||||||
|
|
||||||
|
result->op = GGML_OP_ARANGE;
|
||||||
|
ggml_set_op_params_f32(result, 0, start);
|
||||||
|
ggml_set_op_params_f32(result, 1, stop);
|
||||||
|
ggml_set_op_params_f32(result, 2, step);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_timestep_embedding(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * timesteps,
|
||||||
|
int dim,
|
||||||
|
int max_period) {
|
||||||
|
bool is_node = false;
|
||||||
|
|
||||||
|
if (timesteps->grad) {
|
||||||
|
GGML_ASSERT(false); // TODO: implement backward
|
||||||
|
is_node = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int actual_dim = dim;
|
||||||
|
if (dim % 2 != 0) {
|
||||||
|
actual_dim = dim + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
|
||||||
|
|
||||||
|
result->op = GGML_OP_TIMESTEP_EMBEDDING;
|
||||||
|
ggml_set_op_params_i32(result, 0, dim);
|
||||||
|
ggml_set_op_params_i32(result, 1, max_period);
|
||||||
|
|
||||||
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
result->src[0] = timesteps;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_argsort
|
// ggml_argsort
|
||||||
|
|
||||||
struct ggml_tensor * ggml_argsort(
|
struct ggml_tensor * ggml_argsort(
|
||||||
@ -10349,7 +10412,7 @@ static void ggml_compute_forward_group_norm_f32(
|
|||||||
int n_channels = src0->ne[2];
|
int n_channels = src0->ne[2];
|
||||||
int n_groups = dst->op_params[0];
|
int n_groups = dst->op_params[0];
|
||||||
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
|
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
|
||||||
for (int i = ith; i < n_groups; i+=nth) {
|
for (int i = ith; i < n_groups; i += nth) {
|
||||||
int start = i * n_channels_per_group;
|
int start = i * n_channels_per_group;
|
||||||
int end = start + n_channels_per_group;
|
int end = start + n_channels_per_group;
|
||||||
if (end > n_channels) {
|
if (end > n_channels) {
|
||||||
@ -10363,28 +10426,32 @@ static void ggml_compute_forward_group_norm_f32(
|
|||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||||
|
|
||||||
|
ggml_float sumr = 0.0;
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
sum += (ggml_float)x[i00];
|
sumr += (ggml_float)x[i00];
|
||||||
}
|
}
|
||||||
|
sum += sumr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
float mean = sum / (ne00 * ne01 * step);
|
const float mean = sum / (ne00 * ne01 * step);
|
||||||
ggml_float sum2 = 0.0;
|
|
||||||
|
|
||||||
|
ggml_float sum2 = 0.0;
|
||||||
for (int64_t i02 = start; i02 < end; i02++) {
|
for (int64_t i02 = start; i02 < end; i02++) {
|
||||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||||
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||||
|
|
||||||
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||||
|
|
||||||
|
ggml_float sumr = 0.0;
|
||||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
float v = x[i00] - mean;
|
float v = x[i00] - mean;
|
||||||
y[i00] = v;
|
y[i00] = v;
|
||||||
sum2 += (ggml_float)(v * v);
|
sumr += (ggml_float)(v * v);
|
||||||
}
|
}
|
||||||
|
sum2 += sumr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
float variance = sum2 / (ne00 * ne01 * step);
|
const float variance = sum2 / (ne00 * ne01 * step);
|
||||||
const float scale = 1.0f / sqrtf(variance + eps);
|
const float scale = 1.0f / sqrtf(variance + eps);
|
||||||
|
|
||||||
for (int64_t i02 = start; i02 < end; i02++) {
|
for (int64_t i02 = start; i02 < end; i02++) {
|
||||||
@ -13667,6 +13734,106 @@ static void ggml_compute_forward_pad(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ggml_compute_forward_arange
|
||||||
|
|
||||||
|
static void ggml_compute_forward_arange_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->nb[0] == sizeof(float));
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const float start = ggml_get_op_params_f32(dst, 0);
|
||||||
|
const float stop = ggml_get_op_params_f32(dst, 1);
|
||||||
|
const float step = ggml_get_op_params_f32(dst, 2);
|
||||||
|
|
||||||
|
const int64_t steps = (int64_t) ceilf((stop - start) / step);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_nelements(dst) == steps);
|
||||||
|
|
||||||
|
for (int64_t i = ith; i < steps; i+= nth) {
|
||||||
|
float value = start + step * i;
|
||||||
|
((float *)dst->data)[i] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_arange(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
switch (dst->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_arange_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_timestep_embedding_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
|
const int dim = ggml_get_op_params_i32(dst, 0);
|
||||||
|
const int max_period = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
int half = dim / 2;
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < ne00; i++) {
|
||||||
|
float * embed_data = (float *)((char *) dst->data + i*nb1);
|
||||||
|
for (int64_t j = ith; j < half; j += nth) {
|
||||||
|
float timestep = ((float *)src0->data)[i];
|
||||||
|
float freq = (float)expf(-logf(max_period) * j / half);
|
||||||
|
float arg = timestep * freq;
|
||||||
|
embed_data[j] = cosf(arg);
|
||||||
|
embed_data[j + half] = sinf(arg);
|
||||||
|
}
|
||||||
|
if (dim % 2 != 0 && ith == 0) {
|
||||||
|
embed_data[dim] = 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_timestep_embedding(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_timestep_embedding_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_argsort
|
// ggml_compute_forward_argsort
|
||||||
|
|
||||||
static void ggml_compute_forward_argsort_f32(
|
static void ggml_compute_forward_argsort_f32(
|
||||||
@ -15926,6 +16093,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||||||
{
|
{
|
||||||
ggml_compute_forward_pad(params, tensor);
|
ggml_compute_forward_pad(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_arange(params, tensor);
|
||||||
|
} break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_timestep_embedding(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_argsort(params, tensor);
|
ggml_compute_forward_argsort(params, tensor);
|
||||||
@ -16932,6 +17107,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
|
} break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
|
} break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false); // TODO: not implemented
|
GGML_ASSERT(false); // TODO: not implemented
|
||||||
@ -17684,6 +17867,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARANGE:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
} break;
|
||||||
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
} break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
|
17
ggml.h
17
ggml.h
@ -454,6 +454,8 @@ extern "C" {
|
|||||||
GGML_OP_POOL_2D,
|
GGML_OP_POOL_2D,
|
||||||
GGML_OP_UPSCALE, // nearest interpolate
|
GGML_OP_UPSCALE, // nearest interpolate
|
||||||
GGML_OP_PAD,
|
GGML_OP_PAD,
|
||||||
|
GGML_OP_ARANGE,
|
||||||
|
GGML_OP_TIMESTEP_EMBEDDING,
|
||||||
GGML_OP_ARGSORT,
|
GGML_OP_ARGSORT,
|
||||||
GGML_OP_LEAKY_RELU,
|
GGML_OP_LEAKY_RELU,
|
||||||
|
|
||||||
@ -1662,6 +1664,15 @@ extern "C" {
|
|||||||
int p2,
|
int p2,
|
||||||
int p3);
|
int p3);
|
||||||
|
|
||||||
|
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||||
|
// timesteps: [N,]
|
||||||
|
// return: [N, dim]
|
||||||
|
GGML_API struct ggml_tensor * ggml_timestep_embedding(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * timesteps,
|
||||||
|
int dim,
|
||||||
|
int max_period);
|
||||||
|
|
||||||
// sort rows
|
// sort rows
|
||||||
enum ggml_sort_order {
|
enum ggml_sort_order {
|
||||||
GGML_SORT_ORDER_ASC,
|
GGML_SORT_ORDER_ASC,
|
||||||
@ -1673,6 +1684,12 @@ extern "C" {
|
|||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_sort_order order);
|
enum ggml_sort_order order);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_arange(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
float start,
|
||||||
|
float stop,
|
||||||
|
float step);
|
||||||
|
|
||||||
// top k elements per row
|
// top k elements per row
|
||||||
GGML_API struct ggml_tensor * ggml_top_k(
|
GGML_API struct ggml_tensor * ggml_top_k(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
12
llama.cpp
12
llama.cpp
@ -13330,7 +13330,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
std::string & dest, bool add_ass) {
|
std::string & dest, bool add_ass) {
|
||||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
if (tmpl.find("<|im_start|>") != std::string::npos) {
|
if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
|
||||||
// chatml template
|
// chatml template
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||||
@ -13338,7 +13338,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|im_start|>assistant\n";
|
ss << "<|im_start|>assistant\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl.find("[INST]") != std::string::npos) {
|
} else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
|
||||||
// llama2 template and its variants
|
// llama2 template and its variants
|
||||||
// [variant] support system message
|
// [variant] support system message
|
||||||
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
|
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
|
||||||
@ -13373,7 +13373,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// llama2 templates seem to not care about "add_generation_prompt"
|
// llama2 templates seem to not care about "add_generation_prompt"
|
||||||
} else if (tmpl.find("<|user|>") != std::string::npos) {
|
} else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
|
||||||
// zephyr template
|
// zephyr template
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||||
@ -13381,7 +13381,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<|assistant|>\n";
|
ss << "<|assistant|>\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl.find("bos_token + message['role']") != std::string::npos) {
|
} else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
|
||||||
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
|
// mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
|
std::string bos = (message == chat.front()) ? "" : "<s>"; // skip BOS for first message
|
||||||
@ -13390,7 +13390,7 @@ static int32_t llama_chat_apply_template_internal(
|
|||||||
if (add_ass) {
|
if (add_ass) {
|
||||||
ss << "<s>assistant\n";
|
ss << "<s>assistant\n";
|
||||||
}
|
}
|
||||||
} else if (tmpl.find("<start_of_turn>") != std::string::npos) {
|
} else if (tmpl == "gemma" || tmpl.find("<start_of_turn>") != std::string::npos) {
|
||||||
// google/gemma-7b-it
|
// google/gemma-7b-it
|
||||||
std::string system_prompt = "";
|
std::string system_prompt = "";
|
||||||
for (auto message : chat) {
|
for (auto message : chat) {
|
||||||
@ -13437,7 +13437,7 @@ LLAMA_API int32_t llama_chat_apply_template(
|
|||||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
||||||
if (res < 0) {
|
if (res < 0) {
|
||||||
// worst case: there is no information about template, we will use chatml by default
|
// worst case: there is no information about template, we will use chatml by default
|
||||||
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
|
curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
|
||||||
} else {
|
} else {
|
||||||
curr_tmpl = std::string(model_template.data(), model_template.size());
|
curr_tmpl = std::string(model_template.data(), model_template.size());
|
||||||
}
|
}
|
||||||
|
@ -1 +1 @@
|
|||||||
b458250b736a7473f7ff3560d47c93f1644f3290
|
274680868e12427373bab4bec87554431b954704
|
||||||
|
@ -1422,6 +1422,50 @@ struct test_pad : public test_case {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_ARANGE
|
||||||
|
struct test_arange : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
const float start;
|
||||||
|
const float stop;
|
||||||
|
const float step;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(type, start, stop, step);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_arange(ggml_type type = GGML_TYPE_F32,
|
||||||
|
float start = 0.f, float stop = 10.f, float step = 1.f)
|
||||||
|
: type(type), start(start), stop(stop), step(step) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * out = ggml_arange(ctx, start, stop, step);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// GGML_OP_TIMESTEP_EMBEDDING
|
||||||
|
struct test_timestep_embedding : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
const std::array<int64_t, 4> ne_a;
|
||||||
|
const int dim;
|
||||||
|
const int max_period;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(type, ne_a, dim, max_period);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
|
||||||
|
std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
|
||||||
|
int dim = 320, int max_period=10000)
|
||||||
|
: type(type), ne_a(ne_a), dim(dim), max_period(max_period) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
|
ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_LEAKY_RELU
|
// GGML_OP_LEAKY_RELU
|
||||||
struct test_leaky_relu : public test_case {
|
struct test_leaky_relu : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
@ -2206,6 +2250,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_group_norm());
|
test_cases.emplace_back(new test_group_norm());
|
||||||
test_cases.emplace_back(new test_acc());
|
test_cases.emplace_back(new test_acc());
|
||||||
test_cases.emplace_back(new test_pad());
|
test_cases.emplace_back(new test_pad());
|
||||||
|
test_cases.emplace_back(new test_arange());
|
||||||
|
test_cases.emplace_back(new test_timestep_embedding());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
|
Loading…
Reference in New Issue
Block a user