mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
lora : raise error if lm_head is ignored (#9103)
* lora : raise error if lm_head is ignored * fix style * clarify comment
This commit is contained in:
parent
2a825116b6
commit
d4c3c10fad
@ -363,7 +363,13 @@ if __name__ == '__main__':
|
|||||||
yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
|
yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
dest = super().modify_tensors(data_torch, name, bid)
|
dest = list(super().modify_tensors(data_torch, name, bid))
|
||||||
|
# some archs may have the same tensor for lm_head and output (tie word embeddings)
|
||||||
|
# in this case, adapters targeting lm_head will fail when using llama-export-lora
|
||||||
|
# therefore, we ignore them for now
|
||||||
|
# see: https://github.com/ggerganov/llama.cpp/issues/9065
|
||||||
|
if name == "lm_head.weight" and len(dest) == 0:
|
||||||
|
raise ValueError("lm_head is present in adapter, but is ignored in base model")
|
||||||
for dest_name, dest_data in dest:
|
for dest_name, dest_data in dest:
|
||||||
assert isinstance(dest_data, LoraTorchTensor)
|
assert isinstance(dest_data, LoraTorchTensor)
|
||||||
lora_a, lora_b = dest_data.get_lora_A_B()
|
lora_a, lora_b = dest_data.get_lora_A_B()
|
||||||
|
Loading…
Reference in New Issue
Block a user