ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)

This commit is contained in:
Johannes Gäßler 2024-10-03 17:29:59 +02:00 committed by Georgi Gerganov
parent eee39bdc96
commit fabdc3bda3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
11 changed files with 389 additions and 8 deletions

View File

@ -456,6 +456,7 @@ extern "C" {
GGML_OP_SUM_ROWS, GGML_OP_SUM_ROWS,
GGML_OP_MEAN, GGML_OP_MEAN,
GGML_OP_ARGMAX, GGML_OP_ARGMAX,
GGML_OP_COUNT_EQUAL,
GGML_OP_REPEAT, GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK, GGML_OP_REPEAT_BACK,
GGML_OP_CONCAT, GGML_OP_CONCAT,
@ -994,6 +995,12 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// count number of equal elements in a and b
GGML_API struct ggml_tensor * ggml_count_equal(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// if a is the same shape as b, and a is not parameter, return a // if a is the same shape as b, and a is not parameter, return a
// otherwise, return a new tensor: repeat(a) to fit in b // otherwise, return a new tensor: repeat(a) to fit in b
GGML_API struct ggml_tensor * ggml_repeat( GGML_API struct ggml_tensor * ggml_repeat(

View File

@ -5,12 +5,14 @@
#include "ggml-cuda/common.cuh" #include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh" #include "ggml-cuda/acc.cuh"
#include "ggml-cuda/arange.cuh" #include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh" #include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh" #include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh" #include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/convert.cuh" #include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diagmask.cuh"
@ -2143,6 +2145,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
} }
switch (dst->op) { switch (dst->op) {
case GGML_OP_ARGMAX:
ggml_cuda_argmax(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_cuda_count_equal(ctx, dst);
break;
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
ggml_cuda_op_repeat(ctx, dst); ggml_cuda_op_repeat(ctx, dst);
break; break;
@ -3073,6 +3081,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false; return false;
} break; } break;
case GGML_OP_DUP: case GGML_OP_DUP:
{
ggml_type src0_type = op->src[0]->type;
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
} break;
case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
{
return true;
} break;
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
{ {
ggml_type src0_type = op->src[0]->type; ggml_type src0_type = op->src[0]->type;

View File

@ -0,0 +1,79 @@
#include "common.cuh"
#include "argmax.cuh"
#include "sum.cuh"
#include <cstdint>
static __global__ void argmax_f32(
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
int argmax_thread = 0;
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
#pragma unroll
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
const int64_t row = row0 + row1;
if (row >= nrows) {
break;
}
float maxval = -FLT_MAX;
int argmax = -1;
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
const float val = x[row*ncols + col];
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;
maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
}
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
const int bigger = val > maxval;
const int not_bigger = bigger ^ 0x00000001;
maxval = maxval*not_bigger + val*bigger;
argmax = argmax*not_bigger + col*bigger;
}
const int store = row1 == threadIdx.x;
argmax_thread += store*argmax;
}
const int row = row0 + threadIdx.x;
if (row >= nrows) {
return;
}
dst[row] = argmax_thread;
}
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const float * src0_d = (const float *) src0->data;
int32_t * dst_d = (int32_t *) dst->data;
cudaStream_t stream = ctx.stream();
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(num_blocks, 1, 1);
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -175,6 +175,18 @@ static __device__ void no_device_code(
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
}
static __device__ __forceinline__ float warp_reduce_sum(float x) { static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) { for (int mask = 16; mask > 0; mask >>= 1) {

View File

@ -0,0 +1,64 @@
#include "common.cuh"
#include "count-equal.cuh"
#include <cstdint>
template <typename T>
static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
const int64_t i0 = (int64_t) blockIdx.x*dk;
const int64_t i1 = min(i0 + dk, k);
int nequal = 0;
for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
const T xi = x[i];
const T yi = y[i];
nequal += xi == yi;
}
nequal = warp_reduce_sum(nequal);
if (threadIdx.x != 0) {
return;
}
atomicAdd((int *) dst, nequal);
}
void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == src1->type);
GGML_ASSERT( dst->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
GGML_ASSERT(ggml_is_contiguous(dst));
int64_t * dst_d = (int64_t *) dst->data;
cudaStream_t stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
const dim3 blocks_dim(WARP_SIZE, 1, 1);
const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
switch (src0->type) {
case GGML_TYPE_I32: {
const int * src0_d = (const int *) src0->data;
const int * src1_d = (const int *) src1->data;
count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
} break;
default:
GGML_ASSERT(false);
break;
}
}

View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
} }
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum(kqsum_j); kqsum_j = warp_reduce_sum((float)kqsum_j);
#pragma unroll #pragma unroll
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {

View File

@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum); sum = warp_reduce_sum((float)sum);
if (use_logit_softcap) { if (use_logit_softcap) {
sum = logit_softcap*tanhf(sum); sum = logit_softcap*tanhf(sum);
@ -265,7 +265,7 @@ static __global__ void flash_attn_vec_ext_f16(
#pragma unroll #pragma unroll
for (int j = 0; j < ncols; ++j) { for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]); kqsum[j] = warp_reduce_sum((float)kqsum[j]);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
kqsum_shared[j][threadIdx.y] = kqsum[j]; kqsum_shared[j][threadIdx.y] = kqsum[j];
} }
@ -280,7 +280,7 @@ static __global__ void flash_attn_vec_ext_f16(
} }
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
if (parallel_blocks == 1) { if (parallel_blocks == 1) {

View File

@ -2994,6 +2994,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"SUM_ROWS", "SUM_ROWS",
"MEAN", "MEAN",
"ARGMAX", "ARGMAX",
"COUNT_EQUAL",
"REPEAT", "REPEAT",
"REPEAT_BACK", "REPEAT_BACK",
"CONCAT", "CONCAT",
@ -3067,7 +3068,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW", "OPT_STEP_ADAMW",
}; };
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -3088,6 +3089,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"Σx_k", "Σx_k",
"Σx/n", "Σx/n",
"argmax(x)", "argmax(x)",
"count_equal(x)",
"repeat(x)", "repeat(x)",
"repeat_back(x)", "repeat_back(x)",
"concat(x, y)", "concat(x, y)",
@ -3161,7 +3163,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)", "adamw(x)",
}; };
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -5222,6 +5224,23 @@ struct ggml_tensor * ggml_argmax(
return result; return result;
} }
// ggml_count_equal
struct ggml_tensor * ggml_count_equal(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_are_same_shape(a, b));
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
result->op = GGML_OP_COUNT_EQUAL;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_repeat // ggml_repeat
struct ggml_tensor * ggml_repeat( struct ggml_tensor * ggml_repeat(
@ -10809,6 +10828,86 @@ static void ggml_compute_forward_argmax(
} }
} }
// ggml_compute_forward_count_equal
static void ggml_compute_forward_count_equal_i32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS;
GGML_ASSERT(src0->type == GGML_TYPE_I32);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_are_same_shape(src0, src1));
GGML_ASSERT(ggml_is_scalar(dst));
GGML_ASSERT(dst->type == GGML_TYPE_I64);
const int64_t nr = ggml_nrows(src0);
const int ith = params->ith;
const int nth = params->nth;
int64_t * sums = (int64_t *) params->wdata;
int64_t sum_thread = 0;
// rows per thread
const int64_t dr = (nr + nth - 1)/nth;
// row range for this thread
const int64_t ir0 = dr*ith;
const int64_t ir1 = MIN(ir0 + dr, nr);
for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir / (ne02*ne01);
const int64_t i02 = (ir - i03*ne03) / ne01;
const int64_t i01 = ir - i03*ne03 - i02*ne02;
const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
for (int64_t i00 = 0; i00 < ne00; ++i00) {
const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
sum_thread += val0 == val1;
}
}
if (ith != 0) {
sums[ith] = sum_thread;
}
ggml_barrier(params->threadpool);
if (ith != 0) {
return;
}
for (int ith_other = 1; ith_other < nth; ++ith_other) {
sum_thread += sums[ith_other];
}
*((int64_t *) dst->data) = sum_thread;
}
static void ggml_compute_forward_count_equal(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_I32:
{
ggml_compute_forward_count_equal_i32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_repeat // ggml_compute_forward_repeat
static void ggml_compute_forward_repeat_f32( static void ggml_compute_forward_repeat_f32(
@ -17187,6 +17286,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_argmax(params, tensor); ggml_compute_forward_argmax(params, tensor);
} break; } break;
case GGML_OP_COUNT_EQUAL:
{
ggml_compute_forward_count_equal(params, tensor);
} break;
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
{ {
ggml_compute_forward_repeat(params, tensor); ggml_compute_forward_repeat(params, tensor);
@ -17937,6 +18040,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break; } break;
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL:
{ {
GGML_ABORT("fatal error"); // TODO: implement GGML_ABORT("fatal error"); // TODO: implement
} }
@ -18710,6 +18814,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
for (int i = 0; i < gf->n_nodes; ++i) { for (int i = 0; i < gf->n_nodes; ++i) {
struct ggml_tensor * node = gf->nodes[i]; struct ggml_tensor * node = gf->nodes[i];
if (node->type == GGML_TYPE_I32) {
continue;
}
bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
bool ignore_src[GGML_MAX_SRC] = {false}; bool ignore_src[GGML_MAX_SRC] = {false};
switch (node->op) { switch (node->op) {
@ -19113,6 +19221,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
{
n_tasks = 1;
} break;
case GGML_OP_COUNT_EQUAL:
{
n_tasks = n_threads;
} break;
case GGML_OP_REPEAT: case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK: case GGML_OP_REPEAT_BACK:
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
@ -19611,6 +19726,10 @@ struct ggml_cplan ggml_graph_plan(
cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
} }
} break; } break;
case GGML_OP_COUNT_EQUAL:
{
cur = ggml_type_size(node->type)*n_tasks;
} break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {
const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;

View File

@ -116,6 +116,11 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
} else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
// This is going to create some weird integers though. // This is going to create some weird integers though.
ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor)); ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
} else if (tensor->type == GGML_TYPE_I64) {
// Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful.
const size_t nbytes_half = ggml_nbytes(tensor)/2;
ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);
ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -145,6 +150,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_F32) { } else if (t->type == GGML_TYPE_F32) {
tv.push_back(*(float *) &buf[i]); tv.push_back(*(float *) &buf[i]);
} else if (t->type == GGML_TYPE_I64) {
tv.push_back((float)*(int64_t *) &buf[i]);
} else if (t->type == GGML_TYPE_I32) { } else if (t->type == GGML_TYPE_I32) {
tv.push_back((float)*(int32_t *) &buf[i]); tv.push_back((float)*(int32_t *) &buf[i]);
} else if (t->type == GGML_TYPE_I16) { } else if (t->type == GGML_TYPE_I16) {
@ -1116,6 +1123,71 @@ struct test_get_rows : public test_case {
} }
}; };
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_argmax(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 100, 1, 1})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_argmax(ctx, a);
ggml_set_name(out, "out");
return out;
}
double max_nmse_err() override {
return 0.0;
}
};
// GGML_OP_COUNT_EQUAL
struct test_count_equal : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
std::string vars() override {
return VARS_TO_STR2(type, ne);
}
test_count_equal(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {4, 500, 1, 1})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * a_argmax = ggml_argmax(ctx, a);
ggml_set_name(a_argmax, "a_argmax");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * b_argmax = ggml_argmax(ctx, a);
ggml_set_name(b_argmax, "b_argmax");
ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
ggml_set_name(out, "out");
return out;
}
double max_nmse_err() override {
return 0.0;
}
};
// GGML_OP_REPEAT // GGML_OP_REPEAT
struct test_repeat : public test_case { struct test_repeat : public test_case {
const ggml_type type; const ggml_type type;
@ -3260,6 +3332,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
test_cases.emplace_back(new test_argmax());
test_cases.emplace_back(new test_count_equal());
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));