android : module (#7502)

* move ndk code to a new library

* add gradle file
This commit is contained in:
Elton Kola 2024-05-25 04:11:33 -04:00 committed by GitHub
parent 902184dd3a
commit 9791f40258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 213 additions and 49 deletions

View File

@ -7,8 +7,6 @@ android {
namespace = "com.example.llama" namespace = "com.example.llama"
compileSdk = 34 compileSdk = 34
ndkVersion = "26.1.10909125"
defaultConfig { defaultConfig {
applicationId = "com.example.llama" applicationId = "com.example.llama"
minSdk = 33 minSdk = 33
@ -20,17 +18,6 @@ android {
vectorDrawables { vectorDrawables {
useSupportLibrary = true useSupportLibrary = true
} }
ndk {
// Add NDK properties if wanted, e.g.
// abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
cppFlags += listOf()
arguments += listOf()
}
}
} }
buildTypes { buildTypes {
@ -55,17 +42,6 @@ android {
composeOptions { composeOptions {
kotlinCompilerExtensionVersion = "1.5.1" kotlinCompilerExtensionVersion = "1.5.1"
} }
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
externalNativeBuild {
cmake {
path = file("src/main/cpp/CMakeLists.txt")
version = "3.22.1"
}
}
} }
dependencies { dependencies {
@ -78,6 +54,7 @@ dependencies {
implementation("androidx.compose.ui:ui-graphics") implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview") implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3") implementation("androidx.compose.material3:material3")
implementation(project(":llama"))
testImplementation("junit:junit:4.13.2") testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

View File

@ -1,5 +1,6 @@
package com.example.llama package com.example.llama
import android.llama.cpp.LLamaAndroid
import android.util.Log import android.util.Log
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
@ -9,7 +10,7 @@ import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object { companion object {
@JvmStatic @JvmStatic
private val NanosPerSecond = 1_000_000_000.0 private val NanosPerSecond = 1_000_000_000.0
@ -28,7 +29,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
viewModelScope.launch { viewModelScope.launch {
try { try {
llm.unload() llamaAndroid.unload()
} catch (exc: IllegalStateException) { } catch (exc: IllegalStateException) {
messages += exc.message!! messages += exc.message!!
} }
@ -44,7 +45,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
messages += "" messages += ""
viewModelScope.launch { viewModelScope.launch {
llm.send(text) llamaAndroid.send(text)
.catch { .catch {
Log.e(tag, "send() failed", it) Log.e(tag, "send() failed", it)
messages += it.message!! messages += it.message!!
@ -57,7 +58,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
viewModelScope.launch { viewModelScope.launch {
try { try {
val start = System.nanoTime() val start = System.nanoTime()
val warmupResult = llm.bench(pp, tg, pl, nr) val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime() val end = System.nanoTime()
messages += warmupResult messages += warmupResult
@ -70,7 +71,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
return@launch return@launch
} }
messages += llm.bench(512, 128, 1, 3) messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) { } catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc) Log.e(tag, "bench() failed", exc)
messages += exc.message!! messages += exc.message!!
@ -81,7 +82,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
fun load(pathToModel: String) { fun load(pathToModel: String) {
viewModelScope.launch { viewModelScope.launch {
try { try {
llm.load(pathToModel) llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel" messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) { } catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc) Log.e(tag, "load() failed", exc)

View File

@ -2,4 +2,5 @@
plugins { plugins {
id("com.android.application") version "8.2.0" apply false id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false id("org.jetbrains.kotlin.android") version "1.9.0" apply false
id("com.android.library") version "8.2.0" apply false
} }

View File

@ -0,0 +1 @@
/build

View File

@ -0,0 +1,68 @@
plugins {
id("com.android.library")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "android.llama.cpp"
compileSdk = 34
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
// Add NDK properties if wanted, e.g.
// abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
cppFlags += listOf()
arguments += listOf()
cppFlags("")
}
}
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.22.1"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
}
dependencies {
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.appcompat:appcompat:1.6.1")
implementation("com.google.android.material:material:1.11.0")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}

View File

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@ -0,0 +1,24 @@
package android.llama.cpp
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("android.llama.cpp.test", appContext.packageName)
}
}

View File

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
</manifest>

View File

@ -0,0 +1,49 @@
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html.
# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)
# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("llama-android")
include(FetchContent)
FetchContent_Declare(
llama
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp
GIT_TAG master
)
# Also provides "common"
FetchContent_MakeAvailable(llama)
# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.
#
# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
# is preferred for the same purpose.
#
# In order to load a library into your app from Java/Kotlin, you must call
# System.loadLibrary() and pass the name of the library defined here;
# for GameActivity/NativeActivity derived applications, the same library name must be
# used in the AndroidManifest.xml file.
add_library(${CMAKE_PROJECT_NAME} SHARED
# List C/C++ source files with relative paths to this CMakeLists.txt.
llama-android.cpp)
# Specifies libraries CMake should link to your target library. You
# can link libraries from various origins, such as libraries defined in this
# build script, prebuilt third-party libraries, or Android system libraries.
target_link_libraries(${CMAKE_PROJECT_NAME}
# List libraries link to the target library
llama
common
android
log)

View File

@ -81,7 +81,7 @@ static void log_callback(ggml_log_level level, const char * fmt, void * data) {
extern "C" extern "C"
JNIEXPORT jlong JNICALL JNIEXPORT jlong JNICALL
Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) { Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
auto path_to_model = env->GetStringUTFChars(filename, 0); auto path_to_model = env->GetStringUTFChars(filename, 0);
@ -101,13 +101,13 @@ Java_com_example_llama_Llm_load_1model(JNIEnv *env, jobject, jstring filename) {
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_free_1model(JNIEnv *, jobject, jlong model) { Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_free_model(reinterpret_cast<llama_model *>(model)); llama_free_model(reinterpret_cast<llama_model *>(model));
} }
extern "C" extern "C"
JNIEXPORT jlong JNICALL JNIEXPORT jlong JNICALL
Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) { Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
auto model = reinterpret_cast<llama_model *>(jmodel); auto model = reinterpret_cast<llama_model *>(jmodel);
if (!model) { if (!model) {
@ -139,25 +139,25 @@ Java_com_example_llama_Llm_new_1context(JNIEnv *env, jobject, jlong jmodel) {
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_free_1context(JNIEnv *, jobject, jlong context) { Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
llama_free(reinterpret_cast<llama_context *>(context)); llama_free(reinterpret_cast<llama_context *>(context));
} }
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_backend_1free(JNIEnv *, jobject) { Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
llama_backend_free(); llama_backend_free();
} }
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_log_1to_1android(JNIEnv *, jobject) { Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
llama_log_set(log_callback, NULL); llama_log_set(log_callback, NULL);
} }
extern "C" extern "C"
JNIEXPORT jstring JNICALL JNIEXPORT jstring JNICALL
Java_com_example_llama_Llm_bench_1model( Java_android_llama_cpp_LLamaAndroid_bench_1model(
JNIEnv *env, JNIEnv *env,
jobject, jobject,
jlong context_pointer, jlong context_pointer,
@ -271,13 +271,13 @@ Java_com_example_llama_Llm_bench_1model(
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer)); llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
} }
extern "C" extern "C"
JNIEXPORT jlong JNICALL JNIEXPORT jlong JNICALL
Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated. // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
@ -313,19 +313,19 @@ Java_com_example_llama_Llm_new_1batch(JNIEnv *, jobject, jint n_tokens, jint emb
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_backend_1init(JNIEnv *, jobject) { Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
llama_backend_init(); llama_backend_init();
} }
extern "C" extern "C"
JNIEXPORT jstring JNICALL JNIEXPORT jstring JNICALL
Java_com_example_llama_Llm_system_1info(JNIEnv *env, jobject) { Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
return env->NewStringUTF(llama_print_system_info()); return env->NewStringUTF(llama_print_system_info());
} }
extern "C" extern "C"
JNIEXPORT jint JNICALL JNIEXPORT jint JNICALL
Java_com_example_llama_Llm_completion_1init( Java_android_llama_cpp_LLamaAndroid_completion_1init(
JNIEnv *env, JNIEnv *env,
jobject, jobject,
jlong context_pointer, jlong context_pointer,
@ -376,7 +376,7 @@ Java_com_example_llama_Llm_completion_1init(
extern "C" extern "C"
JNIEXPORT jstring JNICALL JNIEXPORT jstring JNICALL
Java_com_example_llama_Llm_completion_1loop( Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env, JNIEnv * env,
jobject, jobject,
jlong context_pointer, jlong context_pointer,
@ -438,6 +438,6 @@ Java_com_example_llama_Llm_completion_1loop(
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_example_llama_Llm_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context)); llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
} }

View File

@ -1,4 +1,4 @@
package com.example.llama package android.llama.cpp
import android.util.Log import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineDispatcher
@ -10,7 +10,7 @@ import kotlinx.coroutines.withContext
import java.util.concurrent.Executors import java.util.concurrent.Executors
import kotlin.concurrent.thread import kotlin.concurrent.thread
class Llm { class LLamaAndroid {
private val tag: String? = this::class.simpleName private val tag: String? = this::class.simpleName
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle } private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
@ -165,8 +165,8 @@ class Llm {
} }
// Enforce only one instance of Llm. // Enforce only one instance of Llm.
private val _instance: Llm = Llm() private val _instance: LLamaAndroid = LLamaAndroid()
fun instance(): Llm = _instance fun instance(): LLamaAndroid = _instance
} }
} }

View File

@ -0,0 +1,17 @@
package android.llama.cpp
import org.junit.Test
import org.junit.Assert.*
/**
* Example local unit test, which will execute on the development machine (host).
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
class ExampleUnitTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
}

View File

@ -15,3 +15,4 @@ dependencyResolutionManagement {
rootProject.name = "LlamaAndroid" rootProject.name = "LlamaAndroid"
include(":app") include(":app")
include(":llama")