diff --git a/site/en/r1/tutorials/keras/basic_classification.ipynb b/site/en/r1/tutorials/keras/basic_classification.ipynb index 14950538ce4..53f9f536ab0 100644 --- a/site/en/r1/tutorials/keras/basic_classification.ipynb +++ b/site/en/r1/tutorials/keras/basic_classification.ipynb @@ -728,6 +728,7 @@ "plot_image(i, predictions[i], test_labels, test_images)\n", "plt.subplot(1,2,2)\n", "plot_value_array(i, predictions[i], test_labels)\n", + "plt.xticks(range(10), class_names, rotation=45, fontsize=6)\n", "plt.show()" ] }, @@ -745,6 +746,7 @@ "plot_image(i, predictions[i], test_labels, test_images)\n", "plt.subplot(1,2,2)\n", "plot_value_array(i, predictions[i], test_labels)\n", + "plt.xticks(range(10), class_names, rotation=45, fontsize=6)\n", "plt.show()" ] }, @@ -779,6 +781,45 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Display 6 incorrect predictions and 6 correct predictions\n", + "num_incorrect = 6\n", + "num_correct = 6 \n", + "\n", + "incorrect_indices = []\n", + "correct_indices = []\n", + "\n", + "# Find correct and incorrect predictions\n", + "for i in range(len(test_labels)):\n", + " predicted_label = np.argmax(predictions[i])\n", + " true_label = test_labels[i]\n", + " if predicted_label != true_label and len(incorrect_indices) < num_incorrect:\n", + " incorrect_indices.append(i)\n", + " elif predicted_label == true_label and len(correct_indices) < num_correct:\n", + " correct_indices.append(i)\n", + " # Stop when both lists have enough examples\n", + " if len(incorrect_indices) == num_incorrect and len(correct_indices) == num_correct:\n", + " break\n", + "\n", + "selected_indices = incorrect_indices + correct_indices\n", + "num_images = len(selected_indices)\n", + "\n", + "num_rows = 5 \n", + "num_cols = 3 \n", + "plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))\n", + "for i, idx in enumerate(selected_indices):\n", + " plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)\n", + " plot_image(idx, predictions[idx], test_labels, test_images)\n", + " plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)\n", + " plot_value_array(idx, predictions[idx], test_labels)\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {