Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-26 03:14:35 +00:00)
Merge branch 'master' into speculative-tree
ggml-ci
commit ad2727d091

README.md (27 changed lines)
@@ -10,7 +10,7 @@
 Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

 ### Hot topics
+- ‼️ BPE tokenizer update: existing Falcon and Starcoder `.gguf` models will need to be reconverted: [#3252](https://github.com/ggerganov/llama.cpp/pull/3252)
 - ‼️ Breaking change: `rope_freq_base` and `rope_freq_scale` must be set to zero to use the model default values: [#3401](https://github.com/ggerganov/llama.cpp/pull/3401)
 - Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \
   **Devs should become familiar with the new API**
@@ -89,16 +89,17 @@ as the main playground for developing new features for the [ggml](https://github
 - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894)
 - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/)
 - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy)
-- [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b)
+- [X] [Pygmalion/Metharme](#using-pygmalion-7b--metharme-7b)
 - [X] [WizardLM](https://github.com/nlpxucan/WizardLM)
-- [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) and its derivations (such as [baichuan-7b-sft](https://huggingface.co/hiyouga/baichuan-7b-sft))
+- [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft)
-- [X] [Aquila-7B](https://huggingface.co/BAAI/Aquila-7B) / [AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
+- [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila)
-- [X] [Aquila2-7B](https://huggingface.co/BAAI/Aquila2-7B) / [AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B) / [AquilaChat2-34B](https://huggingface.co/BAAI/AquilaChat2-34B) / [Aquila2-34B](https://huggingface.co/BAAI/Aquila2-34B)
 - [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187)
 - [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim)
-- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
+- [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
 - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
+- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)

 **Bindings:**

@@ -207,7 +208,7 @@ https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8

 ## Usage

-Here are the steps for the LLaMA-7B model.
+Here are the end-to-end binary build and model conversion steps for the LLaMA-7B model.

 ### Get the Code

@@ -574,6 +575,18 @@ python3 convert.py models/7B/

 When running the larger models, make sure you have enough disk space to store all the intermediate files.

+### Running on Windows with prebuilt binaries
+
+You will find prebuilt Windows binaries on the release page.
+
+Simply download and extract the latest zip package of choice: (e.g. `llama-b1380-bin-win-avx2-x64.zip`)
+
+From the unzipped folder, open a terminal/cmd window here and place a pre-converted `.gguf` model file. Test out the main example like so:
+
+```
+.\main -m llama-2-7b.Q4_0.gguf -n 128
+```
+
 ### Memory/Disk Requirements

 As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same.
@@ -769,7 +769,7 @@ int main(int argc, char ** argv) {
         }

         const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
         const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
         const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
         LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());

@@ -106,25 +106,25 @@ node index.js

 ## API Endpoints

-- **POST** `/completion`: Given a prompt, it returns the predicted completion.
+- **POST** `/completion`: Given a `prompt`, it returns the predicted completion.

 *Options:*

+`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. If the prompt is a string or an array with the first element given as a string, a `bos` token is inserted in the front like `main` does.
+
 `temperature`: Adjust the randomness of the generated text (default: 0.8).

 `top_k`: Limit the next token selection to the K most probable tokens (default: 40).

 `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.95).

-`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: -1, -1 = infinity).
+`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: -1, -1 = infinity).

-`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
+`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded.
-By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
+By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the prompt.

 `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.

-`prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. If the prompt is a string, or an array with the first element given as a string, a space is inserted in the front like main.cpp does.
-
 `stop`: Specify a JSON array of stopping strings.
 These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).

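To make the option list above concrete: a minimal client sketch in C++, assuming the server example is running at its default `localhost:8080` and that the `httplib.h` (cpp-httplib) and `json.hpp` (nlohmann/json) headers bundled with the server example are available on the include path. Parameter values are illustrative, not recommendations.

```cpp
// Hypothetical client sketch for POST /completion (not part of the diff).
#include "httplib.h"   // cpp-httplib, bundled with the server example (assumption)
#include "json.hpp"    // nlohmann/json, bundled with the server example (assumption)
#include <cstdio>

int main() {
    nlohmann::json body = {
        {"prompt",      "Building a website can be done in 10 simple steps:"},
        {"temperature", 0.8},   // README default
        {"top_k",       40},    // README default
        {"top_p",       0.95},  // README default
        {"n_predict",   64},    // illustrative value; -1 means no limit
    };

    httplib::Client cli("localhost", 8080);  // assumed default server address
    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "request failed\n");
        return 1;
    }
    printf("%s\n", res->body.c_str());
    return 0;
}
```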
@@ -158,6 +158,36 @@ node index.js

 `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)

+*Result JSON:*
+
+Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion.
+
+`content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
+
+`stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
+
+`generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`
+
+`model`: The path to the model loaded with `-m`
+
+`prompt`: The provided `prompt`
+
+`stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token
+
+`stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered
+
+`stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided
+
+`stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word)
+
+`timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second`
+
+`tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`)
+
+`tokens_evaluated`: Number of tokens evaluated in total from the prompt
+
+`truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
+
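A companion sketch for the field list above: reading a few of the documented fields from a non-streaming response body with nlohmann/json. Which fields are present in any given build is an assumption; the names simply mirror the list above.

```cpp
// Hypothetical sketch: inspect a (non-streaming) /completion response body.
#include "json.hpp"   // nlohmann/json, as bundled with the server example (assumption)
#include <cstdio>
#include <string>

void print_completion_summary(const std::string & response_body) {
    const nlohmann::json j = nlohmann::json::parse(response_body);

    const std::string content   = j.value("content", "");
    const bool stopped_eos      = j.value("stopped_eos", false);
    const bool truncated        = j.value("truncated", false);
    const int  tokens_evaluated = j.value("tokens_evaluated", 0);

    printf("content: %s\n", content.c_str());
    printf("stopped_eos=%d truncated=%d tokens_evaluated=%d\n",
           stopped_eos ? 1 : 0, truncated ? 1 : 0, tokens_evaluated);
}
```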
 - **POST** `/tokenize`: Tokenize a given text.

 *Options:*

@@ -253,13 +253,14 @@ static void init_model(struct my_llama_model * model) {
     set_param_model(model);

     // measure data size
-    struct ggml_allocr * alloc = NULL;
+    size_t size = 0;
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-    alloc_model(alloc, model);
+        size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
+    }

     // allocate data
-    model->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
+    struct ggml_allocr * alloc = NULL;
-    ggml_allocr_free(alloc);
+    model->data.resize(size + tensor_alignment);
     alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
     alloc_model(alloc, model);
     ggml_allocr_free(alloc);
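The hunk above replaces the measure allocator with a manual sum of per-tensor sizes rounded up to `tensor_alignment`. Below is a minimal standalone sketch of that accounting; `pad_to` reimplements the assumed behaviour of `GGML_PAD` (round up to a multiple of the alignment) and the tensor sizes are hypothetical.

```cpp
// Standalone sketch of the padded-size accounting used above (assumptions noted).
#include <cstddef>
#include <cstdio>

// Assumed behaviour of GGML_PAD: round x up to the next multiple of n.
static size_t pad_to(size_t x, size_t n) {
    return (x + n - 1) / n * n;
}

int main() {
    const size_t tensor_alignment = 32;              // illustrative alignment
    const size_t tensor_nbytes[]  = {4000, 12, 640}; // hypothetical tensor sizes in bytes

    size_t size = 0;
    for (size_t nbytes : tensor_nbytes) {
        size += pad_to(nbytes, tensor_alignment);    // each tensor starts at an aligned offset
    }
    size += tensor_alignment;                        // extra slack, as the diff adds tensor_alignment

    printf("buffer size: %zu bytes\n", size);        // 4000 + 32 + 640 + 32 = 4704
    return 0;
}
```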
@@ -1094,11 +1095,9 @@ int main(int argc, char ** argv) {
    struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);

    // measure required memory for input tensors
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
-    ggml_allocr_alloc(alloc, tokens_input);
+                            GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
-    ggml_allocr_alloc(alloc, target_probs);
+                            tensor_alignment;
-    size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
-    ggml_allocr_free(alloc);
    printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));

    // allocate input tensors
|
|||||||
ggml_cl_pool_free(d_D, d_size);
|
ggml_cl_pool_free(d_D, d_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
|
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
|
||||||
GGML_ASSERT(fp16_support);
|
GGML_ASSERT(fp16_support);
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
@@ -1598,6 +1598,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
    const int y_ne = ne11 * ne10;
    const int d_ne = ne11 * ne01;

+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+    ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
    size_t x_size;
    size_t y_size;
    size_t d_size;
@@ -1634,7 +1638,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

            // convert src1 to fp16
            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
            char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
            if (src1_cont_rows) {
                if (src1_cont_cols) {
@@ -1897,8 +1900,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }

 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+    if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+        return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
    }
    return 0;
 }
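Taken together, the OpenCL hunks above change the worker-buffer contract for the f16 path: `wdata` no longer has to hold all of `src1` converted to fp16, only the larger of one `src1` slice (`ne10*ne11`) or one `dst` slice (`ne01*ne11`), which the new asserts check up front. A hedged sketch of that sizing rule with stand-in types (not the ggml-opencl code itself):

```cpp
// Hedged sketch of the new scratch ("wsize") rule for the fp16 path (assumptions noted).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

typedef uint16_t fp16_t;  // stand-in for ggml_fp16_t (2 bytes)

// Enough room for either one src1 slice converted to fp16 (y_ne = ne11*ne10)
// or one dst slice staged as fp16 (d_ne = ne11*ne01), matching the new asserts.
static size_t mul_mat_f16_wsize(int64_t ne10, int64_t ne11, int64_t ne01) {
    const int64_t y_ne = ne11 * ne10;
    const int64_t d_ne = ne11 * ne01;
    return sizeof(fp16_t) * (size_t) std::max(y_ne, d_ne);
}

int main() {
    // Hypothetical shapes: src0 is 4096x4096, src1 is 4096x32, so dst is 4096x32.
    printf("wsize = %zu bytes\n", mul_mat_f16_wsize(/*ne10=*/4096, /*ne11=*/32, /*ne01=*/4096));
    return 0;
}
```

Compared with the old rule (all of `src1` as fp16), this caps the temporary at a single (i12, i13) slice, which is why the per-slice offset into `wdata` in the inner loop could be dropped.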
llama.cpp (23 changed lines)
@@ -2333,13 +2333,13 @@ static void llm_load_vocab(
        }

        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
-            fprintf(stderr, "%s: warning: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+            LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
                __func__,
                special_tokens_count_from_verification, vocab.id_to_token.size(),
                special_tokens_count_by_type, vocab.id_to_token.size()
            );
        } else {
-            fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
+            LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
                __func__,
                special_tokens_count_from_verification, vocab.id_to_token.size()
            );
@@ -5918,6 +5918,13 @@ static int llama_decode_internal(

    ggml_allocr_alloc_graph(lctx.alloc, gf);

+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
+    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
+    GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+
+
 #ifdef GGML_USE_CUBLAS
    for (int i = 0; i < gf->n_leafs; i++) {
        ggml_tensor * node = gf->leafs[i];
@@ -5935,6 +5942,12 @@ static int llama_decode_internal(
    }

    ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);

+    // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
+    if (!lctx.embedding.empty()) {
+        embeddings->backend = GGML_BACKEND_CPU;
+    }
+    res->backend = GGML_BACKEND_CPU;
 #endif

    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -5959,12 +5972,6 @@ static int llama_decode_internal(
        n_threads = 1;
    }

-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
-    GGML_ASSERT(strcmp(res->name, "result_output") == 0);
-    GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-
 #if GGML_USE_MPI
    const int64_t n_layer = hparams.n_layer;
    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);