From 1b2226a7ae90922895ecbbeba785c697db16cc6a Mon Sep 17 00:00:00 2001
From: ojotoxy
Date: Sun, 10 Oct 2021 12:10:20 +0200
Subject: [PATCH] Fix some example code in readme for einsum operation

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 88dabba7..2c1bc906 100644
--- a/README.md
+++ b/README.md
@@ -108,8 +108,8 @@ labels = mtf.import_tf_tensor(mesh, tf_labels, [batch_dim])
 w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
 w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])
 # einsum is a generalization of matrix multiplication (see numpy.einsum)
-hidden = mtf.relu(mtf.einsum(images, w1, output_shape=[batch_dim, hidden_dim]))
-logits = mtf.einsum(hidden, w2, output_shape=[batch_dim, classes_dim])
+hidden = mtf.relu(mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
+logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
 loss = mtf.reduce_mean(mtf.layers.softmax_cross_entropy_with_logits(
     logits, mtf.one_hot(labels, classes_dim), classes_dim))
 w1_grad, w2_grad = mtf.gradients([loss], [w1, w2])
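
Note: the corrected calls pass the operands to `mtf.einsum` as a single list; Mesh TensorFlow then sums over every dimension that appears in the inputs but not in `output_shape`. As a rough sanity check, here is a plain-NumPy sketch of the contraction the first corrected line performs. The sizes are hypothetical stand-ins for the `mtf.Dimension` objects, and it assumes `images` has shape `[batch_dim, rows_dim, cols_dim]` as in the surrounding README example.

```python
import numpy as np

# Hypothetical sizes standing in for the mtf.Dimension objects in the README.
batch, rows, cols, hidden = 32, 28, 28, 256

images = np.random.rand(batch, rows, cols)  # [batch_dim, rows_dim, cols_dim]
w1 = np.random.rand(rows, cols, hidden)     # [rows_dim, cols_dim, hidden_dim]

# mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]) contracts
# the shared rows/cols dimensions, keeping batch and hidden -- the same
# reduction as this explicit numpy.einsum, followed by a ReLU:
hidden_act = np.maximum(np.einsum("brc,rch->bh", images, w1), 0.0)
print(hidden_act.shape)  # (32, 256)
```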