Mirror of https://github.com/ggerganov/llama.cpp.git
Synced 2025-01-12 11:40:17 +00:00
57 lines · 1.6 KiB · Python
|
import matplotlib.pyplot as plt
|
||
|
import sys, os
|
||
|
import csv
|
||
|
|
||
|
# Module-level state that calculatecorrect() updates through `global`.
rows = []        # every row read from the Q&A sheet CSV
labels = []      # one entry per graded .txt results file (name minus ".txt")
numbers = []     # correct-answer count per results file, parallel to labels
numEntries = 1   # bumped once per results file, so it ends at file count + 1
|
||
|
|
||
|
def bar_chart(numbers, labels, pos):
    """Draw a blue bar chart of questions answered correctly and show it.

    Args:
        numbers: bar heights, expected to line up with ``pos``.
        labels: x-axis tick label for each bar, in the same order as ``pos``.
        pos: x coordinate of each bar; also used as the tick locations.
    """
    plt.bar(pos, numbers, color='blue')
    plt.xticks(ticks=pos, labels=labels)
    # Title and axis captions, applied in a fixed order.
    chart_text = (
        (plt.title, "Jeopardy Results by Model"),
        (plt.xlabel, "Model"),
        (plt.ylabel, "Questions Correct"),
    )
    for apply_text, text in chart_text:
        apply_text(text)
    plt.show()
|
||
|
|
||
|
def calculatecorrect(results_dir="./examples/jeopardy/results/",
                     qasheet="./examples/jeopardy/qasheet.csv"):
    """Interactively grade each model's Jeopardy answers against the Q&A sheet.

    Appends every row of ``qasheet`` to the module-level ``rows`` list. Then,
    for each ``*.txt`` file in ``results_dir`` (in sorted file-name order), it:

    * appends the file name without its ``.txt`` suffix to ``labels`` and
      increments ``numEntries``;
    * prints the file line by line. Each line whose stripped text is
      ``------`` ends one answer. At that point the expected answer is shown,
      taken from the third column of the next sheet row (starting at
      ``rows[1]``), and the user is asked whether the model got it right;
    * appends the number of answers the user accepted to ``numbers``.

    Args:
        results_dir: directory holding one ``.txt`` results file per model.
        qasheet: CSV file whose third column holds the correct answers. Its
            first row is skipped (presumably a header -- confirm).

    Raises:
        IndexError: if a results file contains more ``------`` separators
            than the sheet has rows after the first.
    """
    global rows, labels, numbers, numEntries

    # The csv module documents that files must be opened with newline='' so
    # quoted fields containing line breaks are parsed correctly. The `with`
    # statement also guarantees the handle is closed.
    with open(qasheet, 'rt', newline='') as sheet:
        rows.extend(csv.reader(sheet, delimiter=','))

    # os.listdir order is arbitrary; sort it so the chart order is reproducible.
    for filename in sorted(os.listdir(results_dir)):
        if not filename.endswith(".txt"):
            continue
        with open(os.path.join(results_dir, filename), 'rt') as results:
            labels.append(filename[:-4])
            numEntries += 1

            question = 1  # rows[0] is skipped
            totalcorrect = 0
            for line in results:
                if line.strip() != "------":
                    print(line)
                    continue
                print("Correct answer: " + rows[question][2] + "\n")
                question += 1
                print("Did the AI get the question right? (y/n)")
                # Accept "Y" or a "y" surrounded by whitespace, not only "y".
                if input().strip().lower() == "y":
                    totalcorrect += 1
        numbers.append(totalcorrect)
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    calculatecorrect()

    # Extra reference bar labelled "Human". NOTE(review): 48.11 is hard-coded
    # with no visible source; presumably an average human score -- confirm.
    labels.append("Human")
    numbers.append(48.11)

    # One x position per plotted bar. This is derived from the data itself
    # rather than a separately maintained counter, so positions cannot drift
    # out of step with `labels` / `numbers`.
    pos = list(range(len(labels)))
    bar_chart(numbers, labels, pos)
    print(labels)
    print(numbers)
|