Skip to content

Commit e38336d

Browse files
authored
Correct typo enbedding -> embedding (#157)
fix: correct typo enbedding -> embedding
1 parent 0572532 commit e38336d

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

jetstream_pt/layers.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ def create_quantized_from_nn_linear(
313313
return obj
314314

315315

316-
def get_quantized_enbedding_layer(config: "QuantizationConfig"):
316+
def get_quantized_embedding_layer(config: "QuantizationConfig"):
317317
if not config.enable_weight_quantization:
318318
return nn.Embedding
319319
else:

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
RMSNorm,
1515
WeightOnlyBlockwiseQuantizedLinear,
1616
WeightOnlyPerChannelQuantizedLinear,
17-
get_quantized_enbedding_layer,
17+
get_quantized_embedding_layer,
1818
get_quantized_linear_layer,
1919
)
2020
from torch import nn
@@ -188,7 +188,7 @@ def __init__(
188188
self.vocab_size = params.vocab_size
189189
self.n_layers = params.n_layers
190190

191-
Embedding = get_quantized_enbedding_layer(env.quant_config)
191+
Embedding = get_quantized_embedding_layer(env.quant_config)
192192
self.tok_embeddings = Embedding(
193193
params.vocab_size,
194194
params.dim,

jetstream_pt/third_party/mixtral/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from torch import Tensor
2222
from torch.nn import functional as F
2323
from .config import ModelArgs, find_multiple
24-
from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
24+
from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_embedding_layer
2525

2626
import jax
2727

@@ -33,7 +33,7 @@ def __init__(self, config: ModelArgs, env) -> None:
3333
self.config = config
3434
self.env = env
3535

36-
Embedding = get_quantized_enbedding_layer(env.quant_config)
36+
Embedding = get_quantized_embedding_layer(env.quant_config)
3737
self.tok_embeddings = Embedding(
3838
config.vocab_size, config.dim, device=config.device
3939
)

0 commit comments

Comments (0)