gemma : fix bfloat16 -> float16 conversion issue (#5810)
commit e743386728 (parent f49a535686)
@@ -1811,16 +1811,15 @@ class GemmaModel(Model):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
 
         for name, data_torch in self.get_tensors():
-            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
-            if name.endswith("norm.weight"):
-                data_torch = data_torch + 1
-
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
             data = data_torch.squeeze().numpy()
 
             # map tensor names
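Why the reorder likely matters (an illustration, not part of the commit): the referenced modeling_gemma.py line adds 1.0 to the RMSNorm weight at runtime, and the converter folds that +1 into the stored tensor. Before this change the +1 was applied while the tensor was still in the checkpoint's bfloat16 dtype; bfloat16 has only a 7-bit mantissa, so sums near 1.0 round to a spacing of 2**-7 and small weight values can be lost before the float32 upcast. The snippet below is a minimal sketch of that effect, assuming a bfloat16 source tensor; the value 0.001 is made up for illustration.

import torch

# Hypothetical norm weight value, stored in bfloat16 (illustrative, not from the model).
w = torch.tensor([0.001], dtype=torch.bfloat16)

# Old order: add 1 in bfloat16, then upcast to float32.
# Near 1.0 the bfloat16 spacing is 2**-7 = 0.0078125, so 1.001 rounds back to 1.0.
before_fix = (w + 1).to(torch.float32)

# New order (this commit): upcast to float32 first, then add 1.
after_fix = w.to(torch.float32) + 1

print(before_fix)  # tensor([1.])     -- the small offset is rounded away
print(after_fix)   # tensor([1.0010]) -- the offset is preserved

Performing the addition after the float32 conversion keeps the folded-in offset intact before the tensor is later quantized or stored as float16 in the GGUF file.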