CUDA: fix scratch malloced on non-main device (#3220)

This commit is contained in:
Johannes Gäßler 2023-09-17 14:16:22 +02:00 committed by GitHub
parent b541b4f0b1
commit 578d8c8f5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6970,6 +6970,7 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
return; return;
} }
if (g_scratch_buffer == nullptr) { if (g_scratch_buffer == nullptr) {
ggml_cuda_set_device(g_main_device);
CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size)); CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
} }