diff --git a/tensorlayer/layers/embedding.py b/tensorlayer/layers/embedding.py index 861723c38..9d0d882d1 100644 --- a/tensorlayer/layers/embedding.py +++ b/tensorlayer/layers/embedding.py @@ -284,10 +284,8 @@ def forward(self, inputs, use_nce_loss=None): The nce_cost is returned only if the nce_loss is used. """ - if isinstance(inputs, list): - outputs = tf.nn.embedding_lookup(params=self.embeddings, ids=inputs[0]) - else: - outputs = tf.nn.embedding_lookup(params=self.embeddings, ids=inputs) + ids = inputs[0] if isinstance(inputs, list) else inputs + outputs = tf.nn.embedding_lookup(params=self.embeddings, ids=ids) if use_nce_loss is True and not self.activate_nce_loss: raise AttributeError(