mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
Compare commits
3 Commits
87fdddabf3
...
5e85e7b095
Author | SHA1 | Date | |
---|---|---|---|
|
5e85e7b095 | ||
|
1329c0a75e | ||
|
c76851eeb0 |
@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
|
||||
kompute-shaders/op_mul_mat_q8_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_0.comp
|
||||
kompute-shaders/op_mul_mat_q4_1.comp
|
||||
kompute-shaders/op_mul_mat_q4_k.comp
|
||||
kompute-shaders/op_mul_mat_q6_k.comp
|
||||
kompute-shaders/op_getrows_f32.comp
|
||||
kompute-shaders/op_getrows_f16.comp
|
||||
@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
|
||||
shaderop_mul_mat_q8_0.h
|
||||
shaderop_mul_mat_q4_0.h
|
||||
shaderop_mul_mat_q4_1.h
|
||||
shaderop_mul_mat_q4_k.h
|
||||
shaderop_mul_mat_q6_k.h
|
||||
shaderop_getrows_f32.h
|
||||
shaderop_getrows_f16.h
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "shaderop_mul_mat_q8_0.h"
|
||||
#include "shaderop_mul_mat_q4_0.h"
|
||||
#include "shaderop_mul_mat_q4_1.h"
|
||||
#include "shaderop_mul_mat_q4_k.h"
|
||||
#include "shaderop_mul_mat_q6_k.h"
|
||||
#include "shaderop_mul_mat_mat_f32.h"
|
||||
#include "shaderop_getrows_f32.h"
|
||||
@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
|
||||
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q4_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
const std::shared_ptr<kp::Tensor>& inB,
|
||||
const std::shared_ptr<kp::Tensor>& out,
|
||||
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
|
||||
int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
|
||||
int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
|
||||
int32_t ne1, int32_t r2, int32_t r3
|
||||
) {
|
||||
const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
|
||||
kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
|
||||
|
||||
struct PushConstants {
|
||||
uint32_t inAOff, inBOff, outOff;
|
||||
int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
|
||||
} pushConsts {
|
||||
0, 0, 0,
|
||||
ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
|
||||
};
|
||||
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
|
||||
s_algo->setPushConstants<PushConstants>({pushConsts});
|
||||
s_algo->updateDescriptors(s_kompute_context->pool.get());
|
||||
}
|
||||
seq.record<kp::OpAlgoDispatch>(s_algo);
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q6_k(
|
||||
kp::Sequence& seq,
|
||||
const std::shared_ptr<kp::Tensor>& inA,
|
||||
@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return true;
|
||||
default:
|
||||
;
|
||||
@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
ggml_vk_mul_mat_q4_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
|
||||
);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
ggml_vk_mul_mat_q6_k(
|
||||
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
|
||||
|
@ -15,6 +15,7 @@
|
||||
#define TWOPI_F 6.283185307179586f
|
||||
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
|
||||
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
|
||||
@ -64,6 +65,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
|
||||
return reg;
|
||||
}
|
||||
|
||||
#define sizeof_block_q4_k 144
|
||||
struct block_q4_k {
|
||||
float16_t d;
|
||||
float16_t dmin;
|
||||
uint8_t scales[K_SCALE_SIZE];
|
||||
uint8_t qs[QK_K/2];
|
||||
};
|
||||
|
||||
#define sizeof_block_q6_k 210
|
||||
struct block_q6_k {
|
||||
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
||||
|
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
133
ggml/src/kompute-shaders/op_mul_mat_q4_k.comp
Normal file
@ -0,0 +1,133 @@
|
||||
#version 450
|
||||
|
||||
#include "common.comp"
|
||||
|
||||
#define N_DST 4
|
||||
#define SIZE_OF_BLOCK sizeof_block_q4_k
|
||||
|
||||
layout(local_size_x = 4) in;
|
||||
layout(local_size_y = 8) in;
|
||||
layout(local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
|
||||
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
|
||||
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint inAOff;
|
||||
uint inBOff;
|
||||
uint outOff;
|
||||
int ne00;
|
||||
int ne10;
|
||||
int ne0;
|
||||
int ne1;
|
||||
int ne01;
|
||||
int ne02;
|
||||
int ne12;
|
||||
int r2;
|
||||
int r3;
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
const uint16_t kmask1 = uint16_t(0x3f3f);
|
||||
const uint16_t kmask2 = uint16_t(0x0f0f);
|
||||
const uint16_t kmask3 = uint16_t(0xc0c0);
|
||||
|
||||
const uint ix = gl_SubgroupInvocationID/8; // 0...3
|
||||
const uint it = gl_SubgroupInvocationID%8; // 0...7
|
||||
const uint iq = it/4; // 0 or 1
|
||||
const uint ir = it%4; // 0...3
|
||||
|
||||
const uint nb = pcs.ne00/QK_K;
|
||||
|
||||
const uint r0 = gl_WorkGroupID.x;
|
||||
const uint r1 = gl_WorkGroupID.y;
|
||||
const uint im = gl_WorkGroupID.z;
|
||||
|
||||
const uint first_row = r0 * N_DST;
|
||||
const uint ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%pcs.ne12;
|
||||
const uint i13 = im/pcs.ne12;
|
||||
|
||||
const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
|
||||
|
||||
const uint xblk = ib_row + offset0 + pcs.inAOff;
|
||||
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
|
||||
float all_sum = 0.f;
|
||||
|
||||
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
for (uint ib = ix; ib < nb; ib += 4) {
|
||||
const uint blk_idx = ib + xblk;
|
||||
|
||||
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
|
||||
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
|
||||
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
|
||||
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
uint row_idx = row * nb;
|
||||
|
||||
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
|
||||
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
|
||||
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
|
||||
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
|
||||
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
|
||||
|
||||
uint16_t sc16[4];
|
||||
sc16[0] = sc_0 & kmask1;
|
||||
sc16[1] = sc_2 & kmask1;
|
||||
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
|
||||
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
|
||||
|
||||
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
|
||||
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
|
||||
acc1[0] += yl[i+0] * (q1 & 0x000F);
|
||||
acc1[1] += yl[i+1] * (q1 & 0x0F00);
|
||||
acc1[2] += yl[i+8] * (q1 & 0x00F0);
|
||||
acc1[3] += yl[i+9] * (q1 & 0xF000);
|
||||
acc2[0] += yh[i+0] * (q2 & 0x000F);
|
||||
acc2[1] += yh[i+1] * (q2 & 0x0F00);
|
||||
acc2[2] += yh[i+8] * (q2 & 0x00F0);
|
||||
acc2[3] += yh[i+9] * (q2 & 0xF000);
|
||||
}
|
||||
|
||||
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
|
||||
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
|
||||
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
|
||||
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
|
||||
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
|
||||
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
|
||||
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
|
||||
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
|
||||
|
||||
float dall = float(inA[blk_idx + row_idx].d);
|
||||
float dmin = float(inA[blk_idx + row_idx].dmin);
|
||||
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
|
||||
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
|
||||
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
|
||||
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
|
||||
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
|
||||
}
|
||||
|
||||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = subgroupAdd(sumf[row]);
|
||||
if (subgroupElect()) {
|
||||
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
@ -3547,10 +3547,10 @@ static bool llama_kv_cache_init(
|
||||
// to the first cell of the slot.
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_ubatch & batch) {
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
const uint32_t n_seqs = batch.n_seqs;
|
||||
const uint32_t n_seq_tokens = batch.n_seq_tokens;
|
||||
const struct llama_ubatch & ubatch) {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
if (cache.recurrent) {
|
||||
// For recurrent state architectures (like Mamba or RWKV),
|
||||
@ -3558,16 +3558,16 @@ static bool llama_kv_cache_find_slot(
|
||||
// A slot should be always be contiguous.
|
||||
|
||||
// can only process batches with an equal number of new tokens in each sequence
|
||||
GGML_ASSERT(batch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
|
||||
int32_t min = cache.size - 1;
|
||||
int32_t max = 0;
|
||||
|
||||
// everything should fit if all seq_ids are smaller than the max
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const uint32_t n_seq_id = batch.n_seq_id[s];
|
||||
const uint32_t n_seq_id = ubatch.n_seq_id[s];
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[s][j];
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
|
||||
// too big seq_id
|
||||
@ -3626,7 +3626,7 @@ static bool llama_kv_cache_find_slot(
|
||||
|
||||
// find usable cell range
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = batch.seq_id[s][0];
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
llama_kv_cell & seq_meta = cache.cells[seq_id];
|
||||
bool has_cell = false;
|
||||
if (seq_meta.tail >= 0) {
|
||||
@ -3665,7 +3665,7 @@ static bool llama_kv_cache_find_slot(
|
||||
// gather and re-order
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
int32_t dst_id = s + min;
|
||||
int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
|
||||
int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
|
||||
if (dst_id != src_id) {
|
||||
llama_kv_cell & dst_cell = cache.cells[dst_id];
|
||||
llama_kv_cell & src_cell = cache.cells[src_id];
|
||||
@ -3686,7 +3686,7 @@ static bool llama_kv_cache_find_slot(
|
||||
|
||||
// update the pos of the used seqs
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
int32_t cell_id = s + min;
|
||||
llama_kv_cell & cell = cache.cells[cell_id];
|
||||
|
||||
@ -3694,12 +3694,12 @@ static bool llama_kv_cache_find_slot(
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
||||
__func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
|
||||
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
|
||||
}
|
||||
cell.pos = last_pos;
|
||||
cell.seq_id.clear();
|
||||
for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
|
||||
const llama_seq_id seq_id = batch.seq_id[s][j];
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
cell.seq_id.insert(seq_id);
|
||||
cache.cells[seq_id].tail = cell_id;
|
||||
}
|
||||
@ -3751,10 +3751,10 @@ static bool llama_kv_cache_find_slot(
|
||||
for (uint32_t s = 0; s < n_seqs; s++) {
|
||||
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
||||
uint32_t k = s*n_seq_tokens + i;
|
||||
cache.cells[cache.head + k].pos = batch.pos[k];
|
||||
cache.cells[cache.head + k].pos = ubatch.pos[k];
|
||||
|
||||
for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
|
||||
cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
|
||||
cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -9275,21 +9275,21 @@ static struct ggml_tensor * llm_build_inp_embd(
|
||||
struct ggml_context * ctx,
|
||||
struct llama_context & lctx,
|
||||
const llama_hparams & hparams,
|
||||
const llama_ubatch & batch,
|
||||
const llama_ubatch & ubatch,
|
||||
struct ggml_tensor * tok_embd,
|
||||
const llm_build_cb & cb) {
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (batch.token) {
|
||||
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
|
||||
if (ubatch.token) {
|
||||
lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(lctx.inp_tokens, "inp_tokens", -1);
|
||||
ggml_set_input(lctx.inp_tokens);
|
||||
|
||||
inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
|
||||
} else {
|
||||
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
|
||||
lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
inpL = lctx.inp_embd;
|
||||
ggml_set_input(lctx.inp_embd);
|
||||
}
|
||||
@ -9856,7 +9856,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
|
||||
static struct ggml_tensor * llm_build_mamba(
|
||||
struct ggml_context * ctx,
|
||||
struct llama_context & lctx,
|
||||
const llama_ubatch & batch,
|
||||
const llama_ubatch & ubatch,
|
||||
struct ggml_cgraph * graph,
|
||||
struct ggml_tensor * cur,
|
||||
struct ggml_tensor * state_copy,
|
||||
@ -9872,17 +9872,17 @@ static struct ggml_tensor * llm_build_mamba(
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t dt_rank = hparams.ssm_dt_rank;
|
||||
const int64_t n_seqs = batch.n_seqs;
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
// Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
|
||||
const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
|
||||
// Use the same RMS norm as the final layer norm
|
||||
const float norm_rms_eps = hparams.f_norm_rms_eps;
|
||||
|
||||
const int64_t n_seq_tokens = batch.n_seq_tokens;
|
||||
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(batch.equal_seqs);
|
||||
GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
struct ggml_tensor * conv_states_all = kv.k_l[il];
|
||||
struct ggml_tensor * ssm_states_all = kv.v_l[il];
|
||||
@ -20545,10 +20545,10 @@ struct llama_data_read {
|
||||
|
||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||
|
||||
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
batch.n_tokens = cell_count;
|
||||
batch.n_seq_tokens = cell_count;
|
||||
batch.n_seqs = 1;
|
||||
llama_ubatch ubatch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
ubatch.n_tokens = cell_count;
|
||||
ubatch.n_seq_tokens = cell_count;
|
||||
ubatch.n_seqs = 1;
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
@ -20562,11 +20562,11 @@ struct llama_data_read {
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
ubatch.pos[i] = pos;
|
||||
}
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.seq_id[0] = &dest_seq_id;
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
ubatch.n_seq_id[0] = 1;
|
||||
ubatch.seq_id[0] = &dest_seq_id;
|
||||
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@ -20574,8 +20574,8 @@ struct llama_data_read {
|
||||
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].pos == ubatch.pos[0]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user