[wip] chat now has parameter and cfg

This commit is contained in:
Henri Vasserman 2023-07-20 15:37:31 +03:00
parent 082dd81286
commit e4db70720d
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986
2 changed files with 79 additions and 33 deletions

View File

@ -15,7 +15,8 @@
min-width: 300px; min-width: 300px;
line-height: 1.2; line-height: 1.2;
margin: 0 auto; margin: 0 auto;
padding: 0 0.5em; } padding: 0 0.5em;
}
#container { #container {
margin: 0em auto; margin: 0em auto;
@ -50,6 +51,7 @@
background-color: #161616; background-color: #161616;
color: #d6d6d6; color: #d6d6d6;
margin-left: 20%; margin-left: 20%;
border-bottom-right-radius: 0;
} }
.asst { .asst {
@ -57,6 +59,7 @@
color: #161616; color: #161616;
text-align: left; text-align: left;
margin-right: 20%; margin-right: 20%;
border-top-left-radius: 0;
} }
.typing { .typing {
@ -113,7 +116,7 @@
fieldset label { fieldset label {
margin: 0.5em 0; margin: 0.5em 0;
display: block; /*display: block;*/
} }
header, footer { header, footer {
@ -135,22 +138,30 @@
const session = signal({ const session = signal({
system: "A chat between a curious user and a pirate.", system: "A chat between a curious user and a pirate.",
system_cfg: "A chat between a curious user and an artificial intelligence assistant.",
message: "{{system}}\n\n### Instruction:\n{{user}}\n\n### Response:\n{{assistant}}", message: "{{system}}\n\n### Instruction:\n{{user}}\n\n### Response:\n{{assistant}}",
stop: ["###"], stop: ["###"],
transcript: [], transcript: [],
type: "chat", type: "chat",
char: "llama", fullprompt: "", // debug
user: "User",
fullprompt: "",
}) })
const params = signal({ const params = signal({
n_predict: 400, n_predict: 400,
temperature: 0.7,
repeat_last_n: 256,
repeat_penalty: 1.18,
top_k: 40, top_k: 40,
top_p: 0.5, top_p: 0.95,
tfs_z: 1.0,
typical_p: 1.0,
temperature: 0.7,
repeat_penalty: 1.18,
frequency_penalty: 0.0,
presence_penalty: 0.0,
repeat_last_n: 256,
mirostat: 0,
mirostat_tau: 5.0,
mirostat_eta: 0.1,
cfg_scale: 1.0,
penalize_nl: true,
}) })
const llamaStats = signal(null) const llamaStats = signal(null)
@ -187,12 +198,19 @@
const system = history.length == 0 ? session.value.system : "" const system = history.length == 0 ? session.value.system : ""
transcriptUpdate([...history, { system, user: msg, assistant: "" }]) transcriptUpdate([...history, { system, user: msg, assistant: "" }])
const prompt = session.value.transcript.map(t => template(session.value.message, t)).join("").trimEnd() const prompt = session.value.transcript.map(t =>
session.value = { ...session.value, fullprompt: prompt } // debug template(session.value.message, t)).join("").trimEnd()
const cfg_negative_prompt = params.value.cfg_scale > 1 ? session.value.transcript.map(t =>
template(session.value.message, { ...t, system: session.value.system_cfg })
).join("").trimEnd() : ""
session.value = { ...session.value, fullprompt: cfg_negative_prompt } // debug
let currentMessage = '' let currentMessage = ''
const llamaParams = { const llamaParams = {
...params.value, ...params.value,
cfg_negative_prompt,
stop: session.stop, stop: session.stop,
} }
@ -284,6 +302,18 @@
<pre>${JSON.stringify(session.value.transcript, null, 2)}</pre>` // debug <pre>${JSON.stringify(session.value.transcript, null, 2)}</pre>` // debug
} }
const ParamSlider = ({param, min, max, step, children}) => {
const updateParamsFloat = (el) => params.value = { ...params.value, [param]: parseFloat(el.target.value) }
return html`
<div>
<label for="${param}"><code>${param}</code></label>
<input type="range" id="${param}" min="${min}" max="${max}" step="${step}" name="${param}" value="${params.value[param]}" oninput=${updateParamsFloat} />
<span>${params.value[param]}</span>
<span>${children}</span>
</div>
`
}
const ConfigForm = (props) => { const ConfigForm = (props) => {
const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value } const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value }
const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value } const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value }
@ -318,30 +348,46 @@
<pre>${JSON.stringify(session.value.stop)/* debug */}</pre> <pre>${JSON.stringify(session.value.stop)/* debug */}</pre>
</div> </div>
<${ParamSlider} min=1 max=10 step=0.1 param=cfg_scale>CFG scale<//>
${params.value.cfg_scale > 1 && html`
<div> <div>
<label for="temperature">Temperature</label> <label for="system_cfg">CFG System prompt</label>
<input type="range" id="temperature" min="0.0" max="2.0" step="0.01" name="temperature" value="${params.value.temperature}" oninput=${updateParamsFloat} /> <textarea type="text" name="system_cfg" value="${session.value.system_cfg}" rows=4 oninput=${updateSession}/>
<span>${params.value.temperature}</span> </div>
</div> `}
<div> <${ParamSlider} min=1 max=1000 step=1 param=n_predict>Predict N tokens<//>
<label for="nPredict">Predictions</label> <${ParamSlider} min=0 max=1000 step=1 param=repeat_last_n>Penalize last N tokens<//>
<input type="range" id="nPredict" min="1" max="2048" step="1" name="n_predict" value="${params.value.n_predict}" oninput=${updateParamsFloat} /> ${params.value.repeat_last_n > 0 && html`
<span>${params.value.n_predict}</span> <${ParamSlider} min=0 max=4 step=0.01 param=repeat_penalty>Penalize repeat sequence<//>
</div> <${ParamSlider} min=0 max=4 step=0.01 param=frequency_penalty>Penalize frequent tokens<//>
<${ParamSlider} min=0 max=4 step=0.01 param=presence_penalty>Penalize tokens not present in prompt<//>
<div> `}
<label for="repeat_penalty">Penalize repeat sequence</label> <${ParamSlider} min=0 max=2 step=0.01 param=temperature>Temperature<//>
<input type="range" id="repeat_penalty" min="0.0" max="4.0" step="0.01" name="repeat_penalty" value="${params.value.repeat_penalty}" oninput=${updateParamsFloat} /> ${params.value.temperature > 0 && html`
<span>${params.value.repeat_penalty}</span> <div>
</div> <input id=mirostat_0 type=radio name=mirostat checked=${params.value.mirostat == 0} value=0 oninput=${updateParamsFloat} />
<label for=mirostat_0>Temperature</label>
<div>
<label for="repeat_last_n">Consider N tokens for penalize</label> <input id=mirostat_1 type=radio name=mirostat checked=${params.value.mirostat == 1} value=1 oninput=${updateParamsFloat} />
<input type="range" id="repeat_last_n" min="0.0" max="2048" name="repeat_last_n" value="${params.value.repeat_last_n}" oninput=${updateParamsFloat} /> <label for=mirostat_1>Mirostat v1</label>
<span>${params.value.repeat_last_n}</span>
<input id=mirostat_2 type=radio name=mirostat checked=${params.value.mirostat == 2} value=2 oninput=${updateParamsFloat} />
<label for=mirostat_2>Mirostat v2</label>
</div> </div>
${params.value.mirostat == 0 && html`
<${ParamSlider} min=1 max=1000 step=1 param=top_k>Top K<//>
<${ParamSlider} min=0 max=1 step=0.01 param=tfs_z>Tail free sampling<//>
<${ParamSlider} min=0 max=1 step=0.01 param=typical_p>Typical P<//>
<${ParamSlider} min=0 max=1 step=0.01 param=top_p>Top P<//>
`}
${params.value.mirostat > 0 && html`
<${ParamSlider} min=0 max=1 step=0.01 param=mirostat_eta>Mirostat eta, learning rate<//>
<${ParamSlider} min=0 max=1000 step=1 param=mirostat_tau>Mirostat tau, target entropy<//>
`}
`}
<pre>${JSON.stringify(params.value, null, 2)/*debug*/}</pre>
</fieldset> </fieldset>
</form> </form>
` `

View File

@ -1067,7 +1067,7 @@ int main(int argc, char ** argv) {
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
if (llama.params.cfg_negative_prompt.size() > 0) { if (llama.params.cfg_scale > 1.0f && llama.params.cfg_negative_prompt.size() > 0) {
llama.cfg_enabled = true; llama.cfg_enabled = true;
llama.loadGuidancePrompt(); llama.loadGuidancePrompt();
} }