mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 17:51:45 +00:00
397 lines
10 KiB
C++
397 lines
10 KiB
C++
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
/**
|
|
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
|
|
*
|
|
* This software is licensed under the terms of the Software for Open Models License (SOM),
|
|
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
|
|
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
|
|
*/
|
|
|
|
#include "kompute/Sequence.hpp"
|
|
|
|
namespace kp {
|
|
|
|
Sequence::Sequence(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
|
std::shared_ptr<vk::Device> device,
|
|
std::shared_ptr<vk::Queue> computeQueue,
|
|
uint32_t queueIndex,
|
|
uint32_t totalTimestamps)
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence Constructor with existing device & queue");
|
|
|
|
this->mPhysicalDevice = physicalDevice;
|
|
this->mDevice = device;
|
|
this->mComputeQueue = computeQueue;
|
|
this->mQueueIndex = queueIndex;
|
|
|
|
this->createCommandPool();
|
|
this->createCommandBuffer();
|
|
if (totalTimestamps > 0)
|
|
this->createTimestampQueryPool(totalTimestamps +
|
|
1); //+1 for the first one
|
|
}
|
|
|
|
Sequence::~Sequence()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence Destructor started");
|
|
|
|
if (this->mDevice) {
|
|
this->destroy();
|
|
}
|
|
}
|
|
|
|
void
|
|
Sequence::begin()
|
|
{
|
|
KP_LOG_DEBUG("Kompute sequence called BEGIN");
|
|
|
|
if (this->isRecording()) {
|
|
KP_LOG_DEBUG("Kompute Sequence begin called when already recording");
|
|
return;
|
|
}
|
|
|
|
if (this->isRunning()) {
|
|
throw std::runtime_error(
|
|
"Kompute Sequence begin called when sequence still running");
|
|
}
|
|
|
|
KP_LOG_INFO("Kompute Sequence command now started recording");
|
|
this->mCommandBuffer->begin(vk::CommandBufferBeginInfo());
|
|
this->mRecording = true;
|
|
|
|
// latch the first timestamp before any commands are submitted
|
|
if (this->timestampQueryPool)
|
|
this->mCommandBuffer->writeTimestamp(
|
|
vk::PipelineStageFlagBits::eAllCommands,
|
|
*this->timestampQueryPool,
|
|
0);
|
|
}
|
|
|
|
void
|
|
Sequence::end()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence calling END");
|
|
|
|
if (this->isRunning()) {
|
|
throw std::runtime_error(
|
|
"Kompute Sequence begin called when sequence still running");
|
|
}
|
|
|
|
if (!this->isRecording()) {
|
|
KP_LOG_WARN("Kompute Sequence end called when not recording");
|
|
return;
|
|
} else {
|
|
KP_LOG_INFO("Kompute Sequence command recording END");
|
|
this->mCommandBuffer->end();
|
|
this->mRecording = false;
|
|
}
|
|
}
|
|
|
|
void
|
|
Sequence::clear()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence calling clear");
|
|
if (this->isRecording()) {
|
|
this->end();
|
|
}
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::eval()
|
|
{
|
|
KP_LOG_DEBUG("Kompute sequence EVAL BEGIN");
|
|
|
|
return this->evalAsync()->evalAwait();
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::eval(std::shared_ptr<OpBase> op)
|
|
{
|
|
this->clear();
|
|
return this->record(op)->eval();
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::evalAsync()
|
|
{
|
|
if (this->isRecording()) {
|
|
this->end();
|
|
}
|
|
|
|
if (this->mIsRunning) {
|
|
throw std::runtime_error(
|
|
"Kompute Sequence evalAsync called when an eval async was "
|
|
"called without successful wait");
|
|
}
|
|
|
|
this->mIsRunning = true;
|
|
|
|
for (size_t i = 0; i < this->mOperations.size(); i++) {
|
|
this->mOperations[i]->preEval(*this->mCommandBuffer);
|
|
}
|
|
|
|
vk::SubmitInfo submitInfo(
|
|
0, nullptr, nullptr, 1, this->mCommandBuffer.get());
|
|
|
|
this->mFence = this->mDevice->createFence(vk::FenceCreateInfo());
|
|
|
|
KP_LOG_DEBUG(
|
|
"Kompute sequence submitting command buffer into compute queue");
|
|
|
|
this->mComputeQueue->submit(1, &submitInfo, this->mFence);
|
|
|
|
return shared_from_this();
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::evalAsync(std::shared_ptr<OpBase> op)
|
|
{
|
|
this->clear();
|
|
this->record(op);
|
|
this->evalAsync();
|
|
return shared_from_this();
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::evalAwait(uint64_t waitFor)
|
|
{
|
|
if (!this->mIsRunning) {
|
|
KP_LOG_WARN("Kompute Sequence evalAwait called without existing eval");
|
|
return shared_from_this();
|
|
}
|
|
|
|
vk::Result result =
|
|
this->mDevice->waitForFences(1, &this->mFence, VK_TRUE, waitFor);
|
|
this->mDevice->destroy(
|
|
this->mFence, (vk::Optional<const vk::AllocationCallbacks>)nullptr);
|
|
|
|
this->mIsRunning = false;
|
|
|
|
if (result == vk::Result::eTimeout) {
|
|
KP_LOG_WARN("Kompute Sequence evalAwait reached timeout of {}",
|
|
waitFor);
|
|
return shared_from_this();
|
|
}
|
|
|
|
for (size_t i = 0; i < this->mOperations.size(); i++) {
|
|
this->mOperations[i]->postEval(*this->mCommandBuffer);
|
|
}
|
|
|
|
return shared_from_this();
|
|
}
|
|
|
|
bool
|
|
Sequence::isRunning() const
|
|
{
|
|
return this->mIsRunning;
|
|
}
|
|
|
|
bool
|
|
Sequence::isRecording() const
|
|
{
|
|
return this->mRecording;
|
|
}
|
|
|
|
bool
|
|
Sequence::isInit() const
|
|
{
|
|
return this->mDevice && this->mCommandPool && this->mCommandBuffer &&
|
|
this->mComputeQueue;
|
|
}
|
|
|
|
void
|
|
Sequence::rerecord()
|
|
{
|
|
this->end();
|
|
std::vector<std::shared_ptr<OpBase>> ops = this->mOperations;
|
|
this->mOperations.clear();
|
|
for (const std::shared_ptr<kp::OpBase>& op : ops) {
|
|
this->record(op);
|
|
}
|
|
}
|
|
|
|
void
|
|
Sequence::destroy()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence destroy called");
|
|
|
|
if (!this->mDevice) {
|
|
KP_LOG_WARN("Kompute Sequence destroy called "
|
|
"with null Device pointer");
|
|
return;
|
|
}
|
|
|
|
if (this->mFreeCommandBuffer) {
|
|
KP_LOG_INFO("Freeing CommandBuffer");
|
|
if (!this->mCommandBuffer) {
|
|
KP_LOG_WARN("Kompute Sequence destroy called with null "
|
|
"CommandPool pointer");
|
|
return;
|
|
}
|
|
this->mDevice->freeCommandBuffers(
|
|
*this->mCommandPool, 1, this->mCommandBuffer.get());
|
|
|
|
this->mCommandBuffer = nullptr;
|
|
this->mFreeCommandBuffer = false;
|
|
|
|
KP_LOG_DEBUG("Kompute Sequence Freed CommandBuffer");
|
|
}
|
|
|
|
if (this->mFreeCommandPool) {
|
|
KP_LOG_INFO("Destroying CommandPool");
|
|
if (this->mCommandPool == nullptr) {
|
|
KP_LOG_WARN("Kompute Sequence destroy called with null "
|
|
"CommandPool pointer");
|
|
return;
|
|
}
|
|
this->mDevice->destroy(
|
|
*this->mCommandPool,
|
|
(vk::Optional<const vk::AllocationCallbacks>)nullptr);
|
|
|
|
this->mCommandPool = nullptr;
|
|
this->mFreeCommandPool = false;
|
|
|
|
KP_LOG_DEBUG("Kompute Sequence Destroyed CommandPool");
|
|
}
|
|
|
|
if (this->mOperations.size()) {
|
|
KP_LOG_INFO("Kompute Sequence clearing operations buffer");
|
|
this->mOperations.clear();
|
|
}
|
|
|
|
if (this->timestampQueryPool) {
|
|
KP_LOG_INFO("Destroying QueryPool");
|
|
this->mDevice->destroy(
|
|
*this->timestampQueryPool,
|
|
(vk::Optional<const vk::AllocationCallbacks>)nullptr);
|
|
|
|
this->timestampQueryPool = nullptr;
|
|
KP_LOG_DEBUG("Kompute Sequence Destroyed QueryPool");
|
|
}
|
|
|
|
if (this->mDevice) {
|
|
this->mDevice = nullptr;
|
|
}
|
|
if (this->mPhysicalDevice) {
|
|
this->mPhysicalDevice = nullptr;
|
|
}
|
|
if (this->mComputeQueue) {
|
|
this->mComputeQueue = nullptr;
|
|
}
|
|
}
|
|
|
|
std::shared_ptr<Sequence>
|
|
Sequence::record(std::shared_ptr<OpBase> op)
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence record function started");
|
|
|
|
this->begin();
|
|
|
|
KP_LOG_DEBUG(
|
|
"Kompute Sequence running record on OpBase derived class instance");
|
|
|
|
op->record(*this->mCommandBuffer);
|
|
|
|
this->mOperations.push_back(op);
|
|
|
|
if (this->timestampQueryPool)
|
|
this->mCommandBuffer->writeTimestamp(
|
|
vk::PipelineStageFlagBits::eAllCommands,
|
|
*this->timestampQueryPool,
|
|
this->mOperations.size());
|
|
|
|
return shared_from_this();
|
|
}
|
|
|
|
void
|
|
Sequence::createCommandPool()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence creating command pool");
|
|
|
|
if (!this->mDevice) {
|
|
throw std::runtime_error("Kompute Sequence device is null");
|
|
}
|
|
|
|
this->mFreeCommandPool = true;
|
|
|
|
vk::CommandPoolCreateInfo commandPoolInfo(vk::CommandPoolCreateFlags(),
|
|
this->mQueueIndex);
|
|
this->mCommandPool = std::make_shared<vk::CommandPool>();
|
|
this->mDevice->createCommandPool(
|
|
&commandPoolInfo, nullptr, this->mCommandPool.get());
|
|
KP_LOG_DEBUG("Kompute Sequence Command Pool Created");
|
|
}
|
|
|
|
void
|
|
Sequence::createCommandBuffer()
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence creating command buffer");
|
|
if (!this->mDevice) {
|
|
throw std::runtime_error("Kompute Sequence device is null");
|
|
}
|
|
if (!this->mCommandPool) {
|
|
throw std::runtime_error("Kompute Sequence command pool is null");
|
|
}
|
|
|
|
this->mFreeCommandBuffer = true;
|
|
|
|
vk::CommandBufferAllocateInfo commandBufferAllocateInfo(
|
|
*this->mCommandPool, vk::CommandBufferLevel::ePrimary, 1);
|
|
|
|
this->mCommandBuffer = std::make_shared<vk::CommandBuffer>();
|
|
this->mDevice->allocateCommandBuffers(&commandBufferAllocateInfo,
|
|
this->mCommandBuffer.get());
|
|
KP_LOG_DEBUG("Kompute Sequence Command Buffer Created");
|
|
}
|
|
|
|
void
|
|
Sequence::createTimestampQueryPool(uint32_t totalTimestamps)
|
|
{
|
|
KP_LOG_DEBUG("Kompute Sequence creating query pool");
|
|
if (!this->isInit()) {
|
|
throw std::runtime_error(
|
|
"createTimestampQueryPool() called on uninitialized Sequence");
|
|
}
|
|
if (!this->mPhysicalDevice) {
|
|
throw std::runtime_error("Kompute Sequence physical device is null");
|
|
}
|
|
|
|
vk::PhysicalDeviceProperties physicalDeviceProperties =
|
|
this->mPhysicalDevice->getProperties();
|
|
|
|
if (physicalDeviceProperties.limits.timestampComputeAndGraphics) {
|
|
vk::QueryPoolCreateInfo queryPoolInfo;
|
|
queryPoolInfo.setQueryCount(totalTimestamps);
|
|
queryPoolInfo.setQueryType(vk::QueryType::eTimestamp);
|
|
this->timestampQueryPool = std::make_shared<vk::QueryPool>(
|
|
this->mDevice->createQueryPool(queryPoolInfo));
|
|
|
|
KP_LOG_DEBUG("Query pool for timestamps created");
|
|
} else {
|
|
throw std::runtime_error("Device does not support timestamps");
|
|
}
|
|
}
|
|
|
|
std::vector<std::uint64_t>
|
|
Sequence::getTimestamps()
|
|
{
|
|
if (!this->timestampQueryPool)
|
|
throw std::runtime_error("Timestamp latching not enabled");
|
|
|
|
const auto n = this->mOperations.size() + 1;
|
|
std::vector<std::uint64_t> timestamps(n, 0);
|
|
this->mDevice->getQueryPoolResults(
|
|
*this->timestampQueryPool,
|
|
0,
|
|
n,
|
|
timestamps.size() * sizeof(std::uint64_t),
|
|
timestamps.data(),
|
|
sizeof(uint64_t),
|
|
vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait);
|
|
|
|
return timestamps;
|
|
}
|
|
|
|
}
|