Remove identical wte/etw logic for jais (#10203)

This commit is contained in:
Faisal Zaghloul 2024-11-07 11:46:12 -05:00 committed by GitHub
parent 5107e8cea3
commit 60e17ce23c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3748,10 +3748,7 @@ class JaisModel(Model):
# Embeddings scale # Embeddings scale
self.embeddings_scale = 1.0 self.embeddings_scale = 1.0
# note: For some JAIS flavors, output is tied to (same as) wte in original model
self.output_is_wte = False
if 'mup_embeddings_scale' in self.hparams: if 'mup_embeddings_scale' in self.hparams:
self.output_is_wte = True # Hack (?)
self.embeddings_scale = self.hparams['mup_embeddings_scale'] self.embeddings_scale = self.hparams['mup_embeddings_scale']
elif 'embeddings_scale' in self.hparams: elif 'embeddings_scale' in self.hparams:
self.embeddings_scale = self.hparams['embeddings_scale'] self.embeddings_scale = self.hparams['embeddings_scale']
@ -3808,10 +3805,7 @@ class JaisModel(Model):
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD): if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
tensors.append((new_name, data_torch * self.embeddings_scale)) tensors.append((new_name, data_torch * self.embeddings_scale))
if self.output_is_wte:
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
assert not self.output_is_wte
tensors.append((new_name, data_torch * self.width_scale)) tensors.append((new_name, data_torch * self.width_scale))
else: else:
tensors.append((new_name, data_torch)) tensors.append((new_name, data_torch))