[wip] chat now has parameters and CFG

Henri Vasserman 2023-07-20 15:37:31 +03:00
parent 082dd81286
commit e4db70720d
2 changed files with 79 additions and 33 deletions


@@ -15,7 +15,8 @@
   min-width: 300px;
   line-height: 1.2;
   margin: 0 auto;
-  padding: 0 0.5em; }
+  padding: 0 0.5em;
+}
 #container {
   margin: 0em auto;
@@ -50,6 +51,7 @@
   background-color: #161616;
   color: #d6d6d6;
   margin-left: 20%;
+  border-bottom-right-radius: 0;
 }
 .asst {
@@ -57,6 +59,7 @@
   color: #161616;
   text-align: left;
   margin-right: 20%;
+  border-top-left-radius: 0;
 }
 .typing {
@@ -113,7 +116,7 @@
 fieldset label {
   margin: 0.5em 0;
-  display: block;
+  /*display: block;*/
 }
 header, footer {
@@ -135,22 +138,30 @@
 const session = signal({
   system: "A chat between a curious user and a pirate.",
+  system_cfg: "A chat between a curious user and an artificial intelligence assistant.",
   message: "{{system}}\n\n### Instruction:\n{{user}}\n\n### Response:\n{{assistant}}",
   stop: ["###"],
   transcript: [],
   type: "chat",
   char: "llama",
   user: "User",
-  fullprompt: "",
+  fullprompt: "", // debug
 })
 const params = signal({
   n_predict: 400,
-  temperature: 0.7,
-  repeat_last_n: 256,
-  repeat_penalty: 1.18,
   top_k: 40,
-  top_p: 0.5,
+  top_p: 0.95,
+  tfs_z: 1.0,
+  typical_p: 1.0,
+  temperature: 0.7,
+  repeat_penalty: 1.18,
+  frequency_penalty: 0.0,
+  presence_penalty: 0.0,
+  repeat_last_n: 256,
+  mirostat: 0,
+  mirostat_tau: 5.0,
+  mirostat_eta: 0.1,
+  cfg_scale: 1.0,
+  penalize_nl: true,
 })
 const llamaStats = signal(null)
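
For orientation (this note and the sketch are not part of the commit): cfg_scale defaults to 1.0, so classifier-free guidance stays off until the slider added further down raises it above 1. Roughly, the next hunk then hands the completion call an object like the one below; the name exampleLlamaParams and the sample values are illustrative only, and cfg_negative_prompt is the transcript re-rendered with system_cfg in place of system.

// Illustrative sketch, not part of the diff.
const exampleLlamaParams = {
  ...params.value,       // n_predict, top_k, top_p, ..., mirostat_*, cfg_scale, penalize_nl
  cfg_scale: 1.5,        // anything > 1 switches guidance on (see the server hunk at the end)
  cfg_negative_prompt:
    "A chat between a curious user and an artificial intelligence assistant.\n\n" +
    "### Instruction:\nHello\n\n### Response:",
  stop: ["###"],
}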
@@ -187,12 +198,19 @@
   const system = history.length == 0 ? session.value.system : ""
   transcriptUpdate([...history, { system, user: msg, assistant: "" }])
-  const prompt = session.value.transcript.map(t => template(session.value.message, t)).join("").trimEnd()
-  session.value = { ...session.value, fullprompt: prompt } // debug
+  const prompt = session.value.transcript.map(t =>
+    template(session.value.message, t)).join("").trimEnd()
+  const cfg_negative_prompt = params.value.cfg_scale > 1 ? session.value.transcript.map(t =>
+    template(session.value.message, { ...t, system: session.value.system_cfg })
+  ).join("").trimEnd() : ""
+  session.value = { ...session.value, fullprompt: cfg_negative_prompt } // debug
   let currentMessage = ''
   const llamaParams = {
     ...params.value,
+    cfg_negative_prompt,
     stop: session.stop,
   }
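
A short illustration of what the two prompts differ by. The template() helper is outside this diff, so the plain {{placeholder}} substitution below is an assumption; the point is that the positive prompt keeps whatever system text each turn carries (the pirate persona on the first turn), while the negative prompt re-renders the same turns with system_cfg, and that string travels to the server as cfg_negative_prompt.

// Sketch only: assume template() does simple {{placeholder}} substitution.
const template = (str, vals) =>
  str.replace(/{{(\w+)}}/g, (_, key) => vals[key] ?? "")

const message = "{{system}}\n\n### Instruction:\n{{user}}\n\n### Response:\n{{assistant}}"
const turn = { system: "A chat between a curious user and a pirate.", user: "Hello", assistant: "" }

const prompt = template(message, turn).trimEnd()
// -> "A chat between a curious user and a pirate.\n\n### Instruction:\nHello\n\n### Response:"

const negative = template(message, {
  ...turn,
  system: "A chat between a curious user and an artificial intelligence assistant.",
}).trimEnd()
// -> the same turn, with the neutral assistant persona substituted for the pirate one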
@@ -284,6 +302,18 @@
       <pre>${JSON.stringify(session.value.transcript, null, 2)}</pre>` // debug
 }
+const ParamSlider = ({param, min, max, step, children}) => {
+  const updateParamsFloat = (el) => params.value = { ...params.value, [param]: parseFloat(el.target.value) }
+  return html`
+    <div>
+      <label for="${param}"><code>${param}</code></label>
+      <input type="range" id="${param}" min="${min}" max="${max}" step="${step}" name="${param}" value="${params.value[param]}" oninput=${updateParamsFloat} />
+      <span>${params.value[param]}</span>
+      <span>${children}</span>
+    </div>
+  `
+}
 const ConfigForm = (props) => {
   const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value }
   const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value }
@@ -318,30 +348,46 @@
         <pre>${JSON.stringify(session.value.stop)/* debug */}</pre>
       </div>
-      <div>
-        <label for="temperature">Temperature</label>
-        <input type="range" id="temperature" min="0.0" max="2.0" step="0.01" name="temperature" value="${params.value.temperature}" oninput=${updateParamsFloat} />
-        <span>${params.value.temperature}</span>
-      </div>
+      <${ParamSlider} min=1 max=10 step=0.1 param=cfg_scale>CFG scale<//>
+      ${params.value.cfg_scale > 1 && html`
+        <div>
+          <label for="system_cfg">CFG System prompt</label>
+          <textarea type="text" name="system_cfg" value="${session.value.system_cfg}" rows=4 oninput=${updateSession}/>
+        </div>
+      `}
-      <div>
-        <label for="nPredict">Predictions</label>
-        <input type="range" id="nPredict" min="1" max="2048" step="1" name="n_predict" value="${params.value.n_predict}" oninput=${updateParamsFloat} />
-        <span>${params.value.n_predict}</span>
-      </div>
+      <${ParamSlider} min=1 max=1000 step=1 param=n_predict>Predict N tokens<//>
+      <${ParamSlider} min=0 max=1000 step=1 param=repeat_last_n>Penalize last N tokens<//>
+      ${params.value.repeat_last_n > 0 && html`
+        <${ParamSlider} min=0 max=4 step=0.01 param=repeat_penalty>Penalize repeat sequence<//>
+        <${ParamSlider} min=0 max=4 step=0.01 param=frequency_penalty>Penalize frequent tokens<//>
+        <${ParamSlider} min=0 max=4 step=0.01 param=presence_penalty>Penalize tokens not present in prompt<//>
+      `}
+      <${ParamSlider} min=0 max=2 step=0.01 param=temperature>Temperature<//>
+      ${params.value.temperature > 0 && html`
+        <div>
+          <input id=mirostat_0 type=radio name=mirostat checked=${params.value.mirostat == 0} value=0 oninput=${updateParamsFloat} />
+          <label for=mirostat_0>Temperature</label>
-      <div>
-        <label for="repeat_penalty">Penalize repeat sequence</label>
-        <input type="range" id="repeat_penalty" min="0.0" max="4.0" step="0.01" name="repeat_penalty" value="${params.value.repeat_penalty}" oninput=${updateParamsFloat} />
-        <span>${params.value.repeat_penalty}</span>
-      </div>
+          <input id=mirostat_1 type=radio name=mirostat checked=${params.value.mirostat == 1} value=1 oninput=${updateParamsFloat} />
+          <label for=mirostat_1>Mirostat v1</label>
-      <div>
-        <label for="repeat_last_n">Consider N tokens for penalize</label>
-        <input type="range" id="repeat_last_n" min="0.0" max="2048" name="repeat_last_n" value="${params.value.repeat_last_n}" oninput=${updateParamsFloat} />
-        <span>${params.value.repeat_last_n}</span>
-      </div>
+          <input id=mirostat_2 type=radio name=mirostat checked=${params.value.mirostat == 2} value=2 oninput=${updateParamsFloat} />
+          <label for=mirostat_2>Mirostat v2</label>
+        </div>
+        ${params.value.mirostat == 0 && html`
+          <${ParamSlider} min=1 max=1000 step=1 param=top_k>Top K<//>
+          <${ParamSlider} min=0 max=1 step=0.01 param=tfs_z>Tail free sampling<//>
+          <${ParamSlider} min=0 max=1 step=0.01 param=typical_p>Typical P<//>
+          <${ParamSlider} min=0 max=1 step=0.01 param=top_p>Top P<//>
+        `}
+        ${params.value.mirostat > 0 && html`
+          <${ParamSlider} min=0 max=1 step=0.01 param=mirostat_eta>Mirostat eta, learning rate<//>
+          <${ParamSlider} min=0 max=1000 step=1 param=mirostat_tau>Mirostat tau, target entropy<//>
+        `}
+      `}
+      <pre>${JSON.stringify(params.value, null, 2)/*debug*/}</pre>
     </fieldset>
   </form>
 `
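
An aside on the markup above (not part of the commit): the ${condition && html`…`} blocks lean on preact rendering, which ignores boolean and null children, so each slider group disappears as soon as its controlling value (cfg_scale, repeat_last_n, temperature, mirostat) makes the condition falsy. A minimal standalone sketch; the import URL is illustrative, not the one this page uses:

// Sketch only.
import { html, render } from 'https://unpkg.com/htm/preact/standalone.module.js'

const Demo = ({ cfg_scale }) => html`
  <div>
    <p>always rendered</p>
    ${cfg_scale > 1 && html`<p>only rendered while CFG is enabled</p>`}
  </div>
`
render(html`<${Demo} cfg_scale=${1.5} />`, document.body)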


@@ -398,7 +398,7 @@ struct llama_server_context {
         evaluator.evaluate(params.n_threads, params.n_batch);
         if (cfg_enabled) {
-            evaluator_guidance.evaluate(params.n_threads, params.n_batch);
+            evaluator_guidance.evaluate(params.n_threads, params.n_batch);
         }
         if (params.n_predict == 0) {
@@ -1067,7 +1067,7 @@ int main(int argc, char ** argv) {
             llama.loadPrompt();
             llama.beginCompletion();
-            if (llama.params.cfg_negative_prompt.size() > 0) {
+            if (llama.params.cfg_scale > 1.0f && llama.params.cfg_negative_prompt.size() > 0) {
                 llama.cfg_enabled = true;
                 llama.loadGuidancePrompt();
             }
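
For reference, the sampling side is not part of this diff: classifier-free guidance in its usual formulation evaluates the prompt and the negative prompt in parallel (which is what the two evaluator objects above do) and then mixes the per-token logits roughly as

    guided_logits[i] = guidance_logits[i] + cfg_scale * (logits[i] - guidance_logits[i])

so cfg_scale = 1.0 reproduces the ordinary logits unchanged, which is why the handler only sets cfg_enabled when cfg_scale > 1.0f and a non-empty cfg_negative_prompt is supplied.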