mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
8cc91dc63c
This change upstreams llamafile's CPU matrix multiplication kernels, which improve image and prompt evaluation speed. For starters, Q4_0 and Q8_0 weights should go ~40% faster on CPU. The biggest benefits are with data types like f16 / f32, which now process prompts 2x faster, making them faster than quantized data types for prompt evals. This change also introduces bona fide AVX512 support, since tinyBLAS is able to exploit the larger register file. For example, on my CPU llama.cpp llava-cli processes an image prompt at 305 tokens/second using the Q4_K and Q4_0 types, which has always been faster than using f16 LLaVA weights, which at HEAD go 188 tokens/second. With this change, f16 LLaVA performance leapfrogs to 464 tokens/second. On an Intel Core i9-14900K this change improves F16 prompt performance by 5x. For example, using llama.cpp at HEAD with Mistral 7b f16 to process a 215-token prompt runs at 13 tok/sec; with this change it runs at 52 tok/sec. This is mostly thanks to my vectorized outer product kernels, but also because I added support for correctly counting the number of cores on Alder Lake, so the default thread count discounts Intel's new efficiency cores. Right now, only Linux can count cores. This work was sponsored by Mozilla, who has given permission to change the license of this code from Apache 2.0 to MIT. To read more about what's improved, and how it works, see: https://justine.lol/matmul/
144 lines
6.9 KiB
Zig
144 lines
6.9 KiB
Zig
// Compatible with Zig Version 0.11.0
|
|
const std = @import("std");
|
|
const ArrayList = std.ArrayList;
|
|
const Compile = std.Build.Step.Compile;
|
|
const ConfigHeader = std.Build.Step.ConfigHeader;
|
|
const Mode = std.builtin.Mode;
|
|
const CrossTarget = std.zig.CrossTarget;
|
|
|
|
/// Collects the compiler flags and include directories shared by every
/// translation unit, and creates object / executable compile steps that use
/// them.
const Maker = struct {
    builder: *std.build.Builder,
    target: CrossTarget,
    optimize: Mode,
    /// Copied into `want_lto` of every step created by `obj` and `exe`.
    enable_lto: bool,

    /// Include paths added to every object and executable.
    include_dirs: ArrayList([]const u8),
    /// Flags used when `obj` compiles a `.c` source.
    cflags: ArrayList([]const u8),
    /// Flags used when `obj` compiles any non-`.c` source, and for the main
    /// source of every `exe`.
    cxxflags: ArrayList([]const u8),
    /// Extra objects linked into every executable created by `exe`.
    objs: ArrayList(*Compile),

    fn addInclude(m: *Maker, dir: []const u8) !void {
        try m.include_dirs.append(dir);
    }

    /// Adds an include directory given as path components relative to the
    /// build root (an empty slice means the build root itself).
    fn addProjectInclude(m: *Maker, path: []const []const u8) !void {
        try m.addInclude(try m.builder.build_root.join(m.builder.allocator, path));
    }

    fn addCFlag(m: *Maker, flag: []const u8) !void {
        try m.cflags.append(flag);
    }

    fn addCxxFlag(m: *Maker, flag: []const u8) !void {
        try m.cxxflags.append(flag);
    }

    /// Adds `flag` to both the C and the C++ flag lists.
    fn addFlag(m: *Maker, flag: []const u8) !void {
        try m.addCFlag(flag);
        try m.addCxxFlag(flag);
    }

    /// Returns the output of `git rev-parse HEAD` (run in the build root) with
    /// surrounding whitespace removed, or "unknown" when git cannot be run,
    /// exits unsuccessfully, or prints nothing (e.g. building from a source
    /// archive without a `.git` directory). Only `error.OutOfMemory` is
    /// propagated; any other failure just selects the fallback value.
    fn gitCommitHash(builder: *std.build.Builder) ![]const u8 {
        const res = std.ChildProcess.exec(.{
            .allocator = builder.allocator,
            .argv = &.{ "git", "rev-parse", "HEAD" },
            // null means the build root is the current directory.
            .cwd = builder.build_root.path,
        }) catch |err| switch (err) {
            error.OutOfMemory => return error.OutOfMemory,
            else => return "unknown",
        };

        const succeeded = switch (res.term) {
            .Exited => |code| code == 0,
            else => false,
        };
        // Trimming (instead of dropping exactly one trailing byte) cannot
        // underflow on empty output and also handles "\r\n" line endings.
        const hash = std.mem.trim(u8, res.stdout, " \t\r\n");
        if (!succeeded or hash.len == 0) return "unknown";
        return hash;
    }

    /// Links the runtime needed by C++ code.
    ///
    /// On the MSVC ABI only libc is linked (Windows SDK + CRT); elsewhere
    /// `linkLibCpp` is used, which also brings in libunwind and libc.
    fn linkCppRuntime(step: *Compile) void {
        if (step.target.getAbi() == .msvc) {
            step.linkLibC(); // need winsdk + crt
        } else {
            // linkLibCpp already add (libc++ + libunwind + libc)
            step.linkLibCpp();
        }
    }

    /// Creates a Maker from the standard `-Dtarget` / `-Doptimize` options,
    /// with C11 / C++11 language standards and the project root plus `common/`
    /// as include directories.
    ///
    /// Side effect: (re)writes `common/build-info.cpp` under the build root
    /// with the build number, commit hash, Zig version and target description.
    fn init(builder: *std.build.Builder) !Maker {
        const target = builder.standardTargetOptions(.{});
        const zig_version = @import("builtin").zig_version_string;
        const commit_hash = try gitCommitHash(builder);

        // Resolve against the build root, as addProjectInclude does, rather
        // than the process cwd, which may differ from it when `zig build` is
        // invoked from a subdirectory.
        try builder.build_root.handle.writeFile("common/build-info.cpp", builder.fmt(
            \\int LLAMA_BUILD_NUMBER = {};
            \\char const *LLAMA_COMMIT = "{s}";
            \\char const *LLAMA_COMPILER = "Zig {s}";
            \\char const *LLAMA_BUILD_TARGET = "{s}";
            \\
        , .{ 0, commit_hash, zig_version, try target.allocDescription(builder.allocator) }));

        var m = Maker{
            .builder = builder,
            .target = target,
            .optimize = builder.standardOptimizeOption(.{}),
            .enable_lto = false,
            .include_dirs = ArrayList([]const u8).init(builder.allocator),
            .cflags = ArrayList([]const u8).init(builder.allocator),
            .cxxflags = ArrayList([]const u8).init(builder.allocator),
            .objs = ArrayList(*Compile).init(builder.allocator),
        };

        try m.addCFlag("-std=c11");
        try m.addCxxFlag("-std=c++11");
        try m.addProjectInclude(&.{});
        try m.addProjectInclude(&.{"common"});
        return m;
    }

    /// Creates an object compile step for the single source file `src`.
    ///
    /// `.c` files are compiled with `cflags` and link libc; every other file is
    /// treated as C++ with `cxxflags` and the C++ runtime. `_GNU_SOURCE` is
    /// defined for every non-MSVC ABI.
    fn obj(m: *const Maker, name: []const u8, src: []const u8) *Compile {
        const o = m.builder.addObject(.{ .name = name, .target = m.target, .optimize = m.optimize });
        if (o.target.getAbi() != .msvc)
            o.defineCMacro("_GNU_SOURCE", null);

        if (std.mem.endsWith(u8, src, ".c")) {
            o.addCSourceFiles(&.{src}, m.cflags.items);
            o.linkLibC();
        } else {
            o.addCSourceFiles(&.{src}, m.cxxflags.items);
            linkCppRuntime(o);
        }

        for (m.include_dirs.items) |i| o.addIncludePath(.{ .path = i });
        o.want_lto = m.enable_lto;
        return o;
    }

    /// Creates and installs an executable compiled from the C++ source `src`,
    /// linking every object in `deps` followed by every object in `objs`.
    fn exe(m: *const Maker, name: []const u8, src: []const u8, deps: []const *Compile) *Compile {
        const e = m.builder.addExecutable(.{ .name = name, .target = m.target, .optimize = m.optimize });
        e.addCSourceFiles(&.{src}, m.cxxflags.items);
        for (deps) |d| e.addObject(d);
        for (m.objs.items) |o| e.addObject(o);
        for (m.include_dirs.items) |i| e.addIncludePath(.{ .path = i });

        // https://github.com/ziglang/zig/issues/15448
        linkCppRuntime(e);

        m.builder.installArtifact(e);
        e.want_lto = m.enable_lto;
        return e;
    }
};
|
|
|
|
/// Build entry point.
///
/// Compiles the ggml / llama core and the `common/` helpers as individual
/// objects, then links the example executables from them.
///
/// Options: `-Dlto=<bool>` enables link-time optimization (default: false).
pub fn build(b: *std.build.Builder) !void {
    var make = try Maker.init(b);
    make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;

    const ggml = make.obj("ggml", "ggml.c");
    const sgemm = make.obj("sgemm", "sgemm.cpp");
    const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
    const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
    const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
    const unicode = make.obj("unicode", "unicode.cpp");
    const unicode_data = make.obj("unicode-data", "unicode-data.cpp");
    const llama = make.obj("llama", "llama.cpp");
    // Previously also named "common", which made this step indistinguishable
    // from common/common.cpp below in step names and diagnostics.
    const buildinfo = make.obj("build-info", "common/build-info.cpp");
    const common = make.obj("common", "common/common.cpp");
    const console = make.obj("console", "common/console.cpp");
    const sampling = make.obj("sampling", "common/sampling.cpp");
    const grammar_parser = make.obj("grammar-parser", "common/grammar-parser.cpp");
    const json_schema_to_grammar = make.obj("json-schema-to-grammar", "common/json-schema-to-grammar.cpp");
    const train = make.obj("train", "common/train.cpp");
    const clip = make.obj("clip", "examples/llava/clip.cpp");
    const llava = make.obj("llava", "examples/llava/llava.cpp");

    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });

    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
    // The server uses sockets; on Windows that requires Winsock (ws2_32).
    if (server.target.isWindows()) {
        server.linkSystemLibrary("ws2_32");
    }
}
|