This commit is contained in:
Georgi Gerganov 2024-12-10 16:31:02 +02:00
parent 5e67008f38
commit 81472a3716
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
10 changed files with 295 additions and 23 deletions

View File

@ -2178,5 +2178,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
{"-mv", "--model-vocoder"}, "FNAME",
"vocoder model for audio generation (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.model = value;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}

View File

@ -80,6 +80,7 @@ enum llama_example {
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_COUNT,
};
@ -159,6 +160,7 @@ struct common_params_sampling {
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
@ -172,6 +174,10 @@ struct common_params_speculative {
std::string model = ""; // draft model for speculative decoding // NOLINT
};
struct common_params_vocoder {
std::string model = ""; // vocoder model for producing audio // NOLINT
};
struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
@ -214,8 +220,9 @@ struct common_params {
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
struct common_params_sampling sampling;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;
std::string model = ""; // model path // NOLINT
std::string model_alias = ""; // model alias // NOLINT

View File

@ -221,17 +221,17 @@ class Model:
self.gguf_writer.add_context_length(n_ctx)
logger.info(f"gguf: context length = {n_ctx}")
n_embd = self.find_hparam(["hidden_size", "n_embd"])
self.gguf_writer.add_embedding_length(n_embd)
logger.info(f"gguf: embedding length = {n_embd}")
if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
self.gguf_writer.add_embedding_length(n_embd)
logger.info(f"gguf: embedding length = {n_embd}")
if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
logger.info(f"gguf: feed forward length = {n_ff}")
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")
if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
@ -2050,7 +2050,8 @@ class OuteTTSVocoderModel(Model):
self._set_vocab_none()
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
super().set_gguf_parameters()
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
@Model.register("Qwen2MoeForCausalLM")

View File

@ -51,6 +51,7 @@ else()
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(gen-docs)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading

View File

@ -0,0 +1,5 @@
set(TARGET llama-tts)
add_executable(${TARGET} tts.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -84,6 +84,10 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
if match:
new_key = f"backbone.pos_net.{match.group(1)}.norm.{match.group(2)}"
# "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight"
if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed":
new_key = "backbone.embedding.weight"
size_mb = value.element_size() * value.nelement() / (1024 * 1024)
print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
@ -132,6 +136,9 @@ config = {
"architectures": [
"OuteTTSVocoder"
],
"hidden_size": 512,
"vocab_size": 4096,
"max_position_embeddings": 8192, # ?
"num_hidden_layers": 12
}

186
examples/tts/tts.cpp Normal file
View File

@ -0,0 +1,186 @@
#include "arg.h"
#include "common.h"
#include "sampling.h"
#include "log.h"
#include "llama.h"
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>
#include <fstream>
//
// Terminal utils
//
#define SQR(X) ((X) * (X))
#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40
/**
* Quantizes 24-bit RGB to xterm256 code range [16,256).
*/
static int rgb2xterm256(int r, int g, int b) {
unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};
int av, ir, ig, ib, il, qr, qg, qb, ql;
av = r * .299 + g * .587 + b * .114 + .5;
ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;
qr = cube[(ir = UNCUBE(r))];
qg = cube[(ig = UNCUBE(g))];
qb = cube[(ib = UNCUBE(b))];
if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
SQR(ql - r) + SQR(ql - g) + SQR(ql - b))
return ir * 36 + ig * 6 + ib + 020;
return il + 0350;
}
static std::string set_xterm256_foreground(int r, int g, int b) {
int x = rgb2xterm256(r, g, b);
std::ostringstream oss;
oss << "\033[38;5;" << x << "m";
return oss.str();
}
const std::vector<std::string> k_colors = {
set_xterm256_foreground(220, 5, 12),
set_xterm256_foreground(232, 96, 28),
set_xterm256_foreground(241, 147, 45),
set_xterm256_foreground(246, 193, 65),
set_xterm256_foreground(247, 240, 86),
set_xterm256_foreground(144, 201, 135),
set_xterm256_foreground( 78, 178, 101),
};
static void print_usage(int, char ** argv) {
LOG("\nexample usage:\n");
LOG("\n %s -m model.gguf -p \"Hello!\"\n", argv[0]);
LOG("\n");
}
int main(int argc, char ** argv) {
common_params params;
params.prompt = "";
params.n_predict = 1024;
params.n_batch = 8192;
params.n_ctx = 8192;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
return 1;
}
common_init();
// init LLM
llama_backend_init();
llama_numa_init(params.numa);
llama_model * model_ttc = NULL; // text-to-codes
llama_model * model_cts = NULL; // codes-to-speech
llama_context * ctx_ttc = NULL;
llama_context * ctx_cts = NULL;
common_init_result llama_init_ttc = common_init_from_params(params);
model_ttc = llama_init_ttc.model;
ctx_ttc = llama_init_ttc.context;
params.model = params.vocoder.model;
common_init_result llama_init_cts = common_init_from_params(params);
model_cts = llama_init_cts.model;
ctx_cts = llama_init_cts.context;
const auto t_main_start = ggml_time_us();
std::vector<llama_token> prompt_inp = {198, 88225, 155856, 151669, 152205,
153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695,
153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010,
153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286,
152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296,
153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690,
153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061,
153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670,
198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683,
152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908,
151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359,
153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424,
151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670,
198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729,
152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669,
153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670,
198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501,
152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242,
153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360,
153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055,
152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670,
198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441,
152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831,
153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133,
153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109,
152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055,
155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729,
151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337,
153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153,
153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365,
153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218,
152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464,
152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855,
152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418,
153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645};
{
const std::string inp_txt = common_detokenize(ctx_ttc, prompt_inp, true);
LOG_INF("prompt: '%s'\n", inp_txt.c_str());
LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size());
}
// remove all non-audio tokens (i.e. < 151672 || > 155772)
prompt_inp.erase(std::remove_if(prompt_inp.begin(), prompt_inp.end(), [](llama_token t) { return t < 151672 || t > 155772; }), prompt_inp.end());
{
const std::string inp_txt = common_detokenize(ctx_ttc, prompt_inp, true);
LOG_INF("prompt audio: '%s'\n", inp_txt.c_str());
LOG_INF("%s: prompt audio size: %d\n", __func__, (int) prompt_inp.size());
}
llama_batch batch = llama_batch_init(prompt_inp.size(), 0, 1);
// evaluate the initial prompt
for (size_t i = 0; i < prompt_inp.size(); ++i) {
common_batch_add(batch, prompt_inp[i], i, { 0 }, true); // TODO: all logits?
}
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
if (llama_decode(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
llama_synchronize(ctx_ttc);
LOG_INF("%s: time for prompt: %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
const float * embd = llama_get_embeddings(ctx_ttc);
LOG("result:\n");
for (int i = 0; i < 10; ++i) {
LOG("%8.3f ", embd[i]);
}
LOG("\n");
fprintf(stderr, "\n");
llama_free(ctx_ttc);
llama_free_model(model_ttc);
llama_free(ctx_cts);
llama_free_model(model_cts);
llama_backend_free();
return 0;
}

View File

@ -372,6 +372,7 @@ class MODEL_TENSOR(IntEnum):
ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_OUT = auto() # classifier output projection
CONV1D = auto()
CONV_NEXT_DW = auto()
CONV_NEXT_NORM = auto()
CONV_NEXT_SHIFT = auto()
@ -388,7 +389,6 @@ class MODEL_TENSOR(IntEnum):
POS_NET_ATTN_K = auto()
POS_NET_ATTN_V = auto()
POS_NET_ATTN_OUT = auto()
QNTZ_CBOOK_EMBD = auto()
HANN_WINDOW = auto()
@ -556,6 +556,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_OUT: "cls.output",
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONV_NEXT_DW: "conv_next.{bid}.dw",
MODEL_TENSOR.CONV_NEXT_NORM: "conv_next.{bid}.norm",
MODEL_TENSOR.CONV_NEXT_SHIFT: "conv_next.{bid}.shift",
@ -572,7 +573,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.POS_NET_ATTN_K: "pos_net.{bid}.attn_k",
MODEL_TENSOR.POS_NET_ATTN_V: "pos_net.{bid}.attn_v",
MODEL_TENSOR.POS_NET_ATTN_OUT: "pos_net.{bid}.attn_output",
MODEL_TENSOR.QNTZ_CBOOK_EMBD: "qntz.cbook.{bid}.embd",
MODEL_TENSOR.HANN_WINDOW: "hann_window",
}
@ -1416,6 +1416,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_EMBD_SHIFT,
MODEL_TENSOR.CONV1D,
MODEL_TENSOR.CONV_NEXT_DW,
MODEL_TENSOR.CONV_NEXT_NORM,
MODEL_TENSOR.CONV_NEXT_SHIFT,
@ -1434,7 +1435,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.POS_NET_ATTN_K,
MODEL_TENSOR.POS_NET_ATTN_V,
MODEL_TENSOR.POS_NET_ATTN_OUT,
MODEL_TENSOR.QNTZ_CBOOK_EMBD,
MODEL_TENSOR.HANN_WINDOW,
],
# TODO

View File

@ -28,7 +28,7 @@ class TensorNameMap:
"transformer.token_embeddings", # openelm
"shared", # t5
"rwkv.embeddings", # rwkv
"backbone.embed", # outetts
"feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" # outetts
),
# Token type embeddings
@ -102,6 +102,10 @@ class TensorNameMap:
MODEL_TENSOR.HANN_WINDOW: (
"head.istft.window", # outetts
),
MODEL_TENSOR.CONV1D: (
"backbone.embed", # roberta
),
}
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
@ -772,10 +776,6 @@ class TensorNameMap:
MODEL_TENSOR.POS_NET_ATTN_OUT: (
"backbone.pos_net.{bid}.proj_out", # outetts
),
MODEL_TENSOR.QNTZ_CBOOK_EMBD: (
"feature_extractor.encodec.quantizer.vq.layers.{bid}._codebook.embed", # outetts
),
}
# architecture-specific block mappings

View File

@ -197,6 +197,7 @@ enum llm_arch {
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_CHAMELEON,
LLM_ARCH_OUTETTS_VOC,
LLM_ARCH_UNKNOWN,
};
@ -253,6 +254,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_CHAMELEON, "chameleon" },
{ LLM_ARCH_OUTETTS_VOC, "outetts-voc" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -503,6 +505,7 @@ struct LLM_KV {
enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_TOKEN_EMBD_SHIFT,
LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_OUTPUT,
@ -609,6 +612,24 @@ enum llm_tensor {
LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONV_NEXT_DW,
LLM_TENSOR_CONV_NEXT_NORM,
LLM_TENSOR_CONV_NEXT_SHIFT,
LLM_TENSOR_CONV_NEXT_PW1,
LLM_TENSOR_CONV_NEXT_PW2,
LLM_TENSOR_CONV_NEXT_GAMMA,
LLM_TENSOR_POS_NET_CONV1,
LLM_TENSOR_POS_NET_CONV2,
LLM_TENSOR_POS_NET_NORM,
LLM_TENSOR_POS_NET_NORM1,
LLM_TENSOR_POS_NET_NORM2,
LLM_TENSOR_POS_NET_ATTN_NORM,
LLM_TENSOR_POS_NET_ATTN_Q,
LLM_TENSOR_POS_NET_ATTN_K,
LLM_TENSOR_POS_NET_ATTN_V,
LLM_TENSOR_POS_NET_ATTN_OUT,
LLM_TENSOR_HANN_WINDOW,
};
static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
@ -1593,6 +1614,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
},
},
{
LLM_ARCH_OUTETTS_VOC,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_EMBD_SHIFT, "token_embd_shift" },
{ LLM_TENSOR_CONV1D, "conv1d" },
{ LLM_TENSOR_CONV_NEXT_DW, "conv_next.dw" },
{ LLM_TENSOR_CONV_NEXT_NORM, "conv_next.norm" },
{ LLM_TENSOR_CONV_NEXT_SHIFT, "conv_next.shift" },
{ LLM_TENSOR_CONV_NEXT_PW1, "conv_next.pw1" },
{ LLM_TENSOR_CONV_NEXT_PW2, "conv_next.pw2" },
{ LLM_TENSOR_CONV_NEXT_GAMMA, "conv_next.gamma" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_POS_NET_CONV1, "pos_net.conv1" },
{ LLM_TENSOR_POS_NET_CONV2, "pos_net.conv2" },
{ LLM_TENSOR_POS_NET_NORM, "pos_net.norm" },
{ LLM_TENSOR_POS_NET_NORM1, "pos_net.norm1" },
{ LLM_TENSOR_POS_NET_NORM2, "pos_net.norm2" },
{ LLM_TENSOR_POS_NET_ATTN_NORM, "pos_net.attn_norm" },
{ LLM_TENSOR_POS_NET_ATTN_Q, "pos_net.attn_q" },
{ LLM_TENSOR_POS_NET_ATTN_K, "pos_net.attn_k" },
{ LLM_TENSOR_POS_NET_ATTN_V, "pos_net.attn_v" },
{ LLM_TENSOR_POS_NET_ATTN_OUT, "pos_net.attn_output" },
{ LLM_TENSOR_HANN_WINDOW, "hann_window" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -2489,7 +2538,7 @@ struct llama_hparams {
bool use_par_res;
bool swin_norm;
uint32_t n_vocab;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd;
uint32_t n_layer;
@ -3005,6 +3054,9 @@ struct llama_model {
struct ggml_tensor * cls_out = nullptr;
struct ggml_tensor * cls_out_b = nullptr;
// quantizer
struct ggml_tensor * qntz_cbook_embd = nullptr;
std::vector<llama_layer> layers;
// gguf metadata
@ -5519,7 +5571,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
// get hparams kv
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
// everything past this point is not vocab-related
if (hparams.vocab_only) {
@ -5545,8 +5597,8 @@ static void llm_load_hparams(
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
// n_head_kv is optional, default to n_head
hparams.n_head_kv_arr = hparams.n_head_arr;
@ -6320,7 +6372,7 @@ static void llm_load_vocab(
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
if (tokenizer_model == "no_vocab") {
if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
vocab.type = LLAMA_VOCAB_TYPE_NONE;
// default special tokens
@ -9336,9 +9388,9 @@ static bool llm_load_tensors(
} break;
case LLM_ARCH_CHAMELEON:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
// output
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
@ -9367,6 +9419,10 @@ static bool llm_load_tensors(
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_OUTETTS_VOC:
{
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -20383,6 +20439,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_JAIS:
case LLM_ARCH_RWKV6:
case LLM_ARCH_OUTETTS_VOC:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values