compute hann window

Georgi Gerganov 2024-12-11 12:35:47 +02:00
parent 86d0ad5ef4
commit 5aaf4a8aa6
6 changed files with 27 additions and 30 deletions
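
Summary (inferred from the diff below, not part of the original commit message): the 1280-sample Hann window used by the iSTFT head is no longer converted and stored as a model tensor ("hann_window", previously mapped from the outetts key "head.istft.window"). Instead, the example program computes it at runtime with a new fill_hann_window() helper. Accordingly, the HANN_WINDOW tensor definitions are removed from gguf-py and from llama.cpp's tensor tables, and llama_get_model_tensor() is dropped from the public llama.h API, surviving only as a static helper inside llama.cpp. The periodic Hann window of length N produced by the helper is w[i] = 0.5 * (1 - cos(2*pi*i / N)) for i = 0..N-1; the symmetric variant divides by N - 1 instead.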

View File

@@ -70,7 +70,7 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
         # keep only what we need for inference
         if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \
            not key.startswith('state_dict.backbone.') and \
-           not key.startswith('state_dict.head.'):
+           not key.startswith('state_dict.head.out'):
             print('Skipping key: ', key)
             continue
@@ -101,9 +101,6 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
         if new_key.endswith("gamma"):
             new_key = new_key.replace("gamma", "gamma.weight")
 
-        if new_key == "head.istft.window":
-            new_key = "head.istft.window.weight"
-
         size_mb = value.element_size() * value.nelement() / (1024 * 1024)
         print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")

View File

@@ -57,6 +57,16 @@ static void print_usage(int, char ** argv) {
     LOG("\n");
 }
 
+void fill_hann_window(int length, bool periodic, float * output) {
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+    }
+}
+
 int main(int argc, char ** argv) {
     common_params params;
@@ -171,6 +181,11 @@ int main(int argc, char ** argv) {
     const int n_embd = llama_n_embd(model_cts);
     const float * embd = llama_get_embeddings(ctx_cts);
 
+    const int w = 1280;
+
+    std::vector<float> hann(w);
+    fill_hann_window(hann.size(), true, hann.data());
+
     int n = n_embd*261;
 
     LOG("result:\n");

View File

@@ -387,7 +387,6 @@ class MODEL_TENSOR(IntEnum):
     POS_NET_ATTN_K = auto()
     POS_NET_ATTN_V = auto()
     POS_NET_ATTN_OUT = auto()
-    HANN_WINDOW = auto()
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -569,7 +568,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.POS_NET_ATTN_K: "pos_net.{bid}.attn_k",
     MODEL_TENSOR.POS_NET_ATTN_V: "pos_net.{bid}.attn_v",
     MODEL_TENSOR.POS_NET_ATTN_OUT: "pos_net.{bid}.attn_output",
-    MODEL_TENSOR.HANN_WINDOW: "hann_window",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1429,7 +1427,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.POS_NET_ATTN_K,
         MODEL_TENSOR.POS_NET_ATTN_V,
         MODEL_TENSOR.POS_NET_ATTN_OUT,
-        MODEL_TENSOR.HANN_WINDOW,
     ],
     # TODO
 }

View File

@@ -94,10 +94,6 @@ class TensorNameMap:
         MODEL_TENSOR.ROPE_FACTORS_LONG: (),
         MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
 
-        MODEL_TENSOR.HANN_WINDOW: (
-            "head.istft.window",  # outetts
-        ),
-
         MODEL_TENSOR.CONV1D: (
             "backbone.embed",  # roberta
         ),

View File

@@ -482,9 +482,6 @@ extern "C" {
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
-    // Get a llama model tensor
-    LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
-
     // Returns true if the model contains an encoder that requires llama_encode() call
     LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);

View File

@@ -627,7 +627,6 @@ enum llm_tensor {
     LLM_TENSOR_POS_NET_ATTN_K,
     LLM_TENSOR_POS_NET_ATTN_V,
     LLM_TENSOR_POS_NET_ATTN_OUT,
-    LLM_TENSOR_HANN_WINDOW,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
@@ -1635,7 +1634,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_POS_NET_ATTN_K, "pos_net.%d.attn_k" },
             { LLM_TENSOR_POS_NET_ATTN_V, "pos_net.%d.attn_v" },
             { LLM_TENSOR_POS_NET_ATTN_OUT, "pos_net.%d.attn_output" },
-            { LLM_TENSOR_HANN_WINDOW, "hann_window" },
         },
     },
     {
@@ -3648,6 +3646,17 @@ static int llama_get_device_count(const llama_model & model) {
     return (int) model.devices.size();
 }
 
+static struct ggml_tensor * llama_get_model_tensor(const struct llama_model * model, const char * name) {
+    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
+            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+                return it.first == name;
+            });
+    if (it == model->tensors_by_name.end()) {
+        return nullptr;
+    }
+    return it->second;
+}
+
 template<typename F>
 static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
     ggml_init_params params = {
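
The lookup helper now exists only as a static function inside llama.cpp: it linearly scans model->tensors_by_name and returns nullptr when the name is absent. A hypothetical internal call site (the tensor name below is purely illustrative, not something this diff touches) could look like:

    // hypothetical internal usage of the now-static helper
    struct ggml_tensor * cur = llama_get_model_tensor(&model, "token_embd.weight");
    if (cur == nullptr) {
        // the model does not contain a tensor with this name
    }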
@@ -7462,7 +7471,6 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_CONV_NEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONV_NEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONV_NEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
-    {LLM_TENSOR_HANN_WINDOW, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
 };
 
 // checks if the weight tensor can be used with the specified buffer type and device
@@ -9638,8 +9646,6 @@ static bool llm_load_tensors(
                     model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {768, n_embd}, 0);
                     model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
-
-                    model.hann_window = create_tensor(tn(LLM_TENSOR_HANN_WINDOW, "weight"), {1280}, 0);
                 } break;
             default:
                 throw std::runtime_error("unknown architecture");
@@ -21021,17 +21027,6 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
     return model->n_elements;
 }
 
-struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
-            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
-                return it.first == name;
-            });
-    if (it == model->tensors_by_name.end()) {
-        return nullptr;
-    }
-    return it->second;
-}
-
 bool llama_model_has_encoder(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_T5: return true;