How to Plot a Confusion Matrix in Matplotlib
Confusion matrices are a fundamental tool in machine learning for visualizing the performance of a classification algorithm. They show the actual versus predicted classifications in a grid, helping to identify where a model is making errors. In this article, we will explore how to plot a confusion matrix using the Matplotlib library in Python. We will provide detailed examples with complete code snippets that can be run independently.
Understanding Confusion Matrices
Before we dive into the code, let’s understand what a confusion matrix is. A confusion matrix is a table with two axes, “Actual” and “Predicted,” each with one entry per class in the classification problem. The cell in row i and column j counts the samples whose actual class is i and whose predicted class is j, so correct predictions lie on the diagonal and errors lie off it. For a binary classification problem, the matrix is 2×2 and its four cells are the True Negatives (TN), False Positives (FP), False Negatives (FN), and True Positives (TP).
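As a minimal sketch of that layout, the snippet below builds a binary confusion matrix with scikit-learn and unpacks its four cells. It assumes the labels are 0 (negative) and 1 (positive), and the sample labels are made up for illustration. With scikit-learn, rows are true labels and columns are predicted labels, so flattening the 2×2 matrix gives the order TN, FP, FN, TP.

```python
from sklearn.metrics import confusion_matrix

# Illustrative labels (assumed): 0 = negative class, 1 = positive class
y_true = [0, 1, 0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0, 1, 0]

cm = confusion_matrix(y_true, y_pred)

# Rows are true labels and columns are predicted labels,
# so the flattened order is TN, FP, FN, TP.
tn, fp, fn, tp = cm.ravel()
print(cm)
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
```

The diagonal (TN and TP) holds the correct predictions, and the two off-diagonal cells hold the two kinds of error.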
Basic Confusion Matrix Plot
Let’s start with a basic example of plotting a confusion matrix using Matplotlib.
Example 1: Basic Confusion Matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
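One thing to watch in examples like this: the tick labels are typed by hand, so they only line up with the matrix if every class actually occurs in the data. A hedged sketch of one way to keep them in sync is to fix the class list up front and pass it to confusion_matrix through its labels parameter. The names class_ids and class_names are illustrative and not part of the original example, and class 3 is an invented extra class that never occurs in the data.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Sample data in which class 1 is never predicted and class 3 never occurs
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

# Illustrative names (assumptions): the full set of classes the model can output
class_ids = [0, 1, 2, 3]
class_names = [f"Class {c}" for c in class_ids]

# labels= fixes both the size and the row/column order of the matrix
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
fig.colorbar(im, ax=ax)

# Derive tick positions from the matrix itself so they cannot drift apart
ax.set_xticks(np.arange(cm.shape[1]))
ax.set_xticklabels(class_names, rotation=45)
ax.set_yticks(np.arange(cm.shape[0]))
ax.set_yticklabels(class_names)
ax.set_ylabel("True label")
ax.set_xlabel("Predicted label")
fig.tight_layout()
```

Call plt.show() afterwards to display the figure. The row and column for class 3 are all zeros, but they still appear in the plot with the correct label.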
Adding Text Annotations
To make the confusion matrix more informative, we can add text annotations inside each square to show the counts.
Example 2: Confusion Matrix with Annotations
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools
# Sample data
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Wistia)
plt.title('Confusion Matrix with Annotations - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1'])
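# Adding text annotations
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    # White text on cells above half the maximum count, black text otherwise
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()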
plt.show()
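Because the same plotting steps recur in every example, it can be convenient to wrap them in a function. The following is a minimal sketch, assuming Matplotlib's object-oriented interface. The name draw_confusion_matrix and its parameters are hypothetical and do not come from Matplotlib or scikit-learn.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def draw_confusion_matrix(cm, class_names, ax=None, cmap="Blues"):
    """Hypothetical helper: draw an annotated confusion matrix on an Axes."""
    if ax is None:
        _, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set_xticks(np.arange(cm.shape[1]))
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(np.arange(cm.shape[0]))
    ax.set_yticklabels(class_names)
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")

    # Same contrast rule as Example 2: white text above half the maximum count
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], "d"),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    return ax


y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]
cm = confusion_matrix(y_true, y_pred)
ax = draw_confusion_matrix(cm, ["Class 0", "Class 1"], cmap="Wistia")
ax.set_title("Confusion Matrix with Annotations")
ax.figure.tight_layout()
```

Returning the Axes lets the caller adjust the title or place several matrices side by side. Call plt.show() to display the result.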
Normalizing the Confusion Matrix
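Sometimes, it’s useful to normalize the confusion matrix to show proportions rather than counts. Dividing each row by its total gives, for every true class, the fraction of its samples assigned to each predicted class, which makes classes of different sizes comparable.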
Example 3: Normalized Confusion Matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools
# Sample data
y_true = [1, 2, 0, 1, 2, 0]
y_pred = [0, 2, 0, 2, 2, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# Plotting
plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Spectral)
plt.title('Normalized Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])
# Adding text annotations
for i, j in itertools.product(range(cm_normalized.shape[0]), range(cm_normalized.shape[1])):
    # Show each proportion with two decimal places
    plt.text(j, i, "{:0.2f}".format(cm_normalized[i, j]),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm_normalized[i, j] > 0.5 else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
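Two details of row normalization may be worth a sketch. First, confusion_matrix accepts a normalize argument (assuming scikit-learn 0.22 or later), and normalize="true" should match the manual division by row totals. Second, if a class never appears in y_true, its row total is zero and plain division yields NaN. The extra class 3 below is an invented label used only to demonstrate that case.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [1, 2, 0, 1, 2, 0]
y_pred = [0, 2, 0, 2, 2, 0]

# Manual row normalization: divide each row by its total
cm = confusion_matrix(y_true, y_pred)
manual = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

# Built-in alternative: normalize="true" divides each row by its total
builtin = confusion_matrix(y_true, y_pred, normalize="true")
print(np.allclose(manual, builtin))

# Class 3 (invented for this sketch) never occurs, so its row sums to zero.
# Dividing only where the row total is positive leaves that row at 0.0.
cm_wide = confusion_matrix(y_true, y_pred, labels=[0, 1, 2, 3])
row_totals = cm_wide.sum(axis=1, keepdims=True)
safe = np.divide(cm_wide, row_totals,
                 out=np.zeros(cm_wide.shape, dtype=float),
                 where=row_totals > 0)
print(safe)
```

Either normalized array can be passed to plt.imshow in place of the raw counts.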
Confusion Matrix with Different Color Maps
Changing the color map can help in highlighting different aspects of the confusion matrix.
Example 4: Confusion Matrix with a Different Color Map
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
# Colormap names are case-sensitive: the colormap is plt.cm.cool (lowercase)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.cool)
plt.title('Confusion Matrix with Cool Color Map - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
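Colormaps can also be passed to imshow by name as a string, which makes it easy to compare several of them on the same data. The sketch below assumes three built-in colormaps (cool, Blues, and viridis) chosen only for illustration. Names are case-sensitive, so a misspelled or wrongly capitalized name raises an error.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)

# Illustrative selection of built-in colormap names (case-sensitive)
cmap_names = ["cool", "Blues", "viridis"]

fig, axes = plt.subplots(1, len(cmap_names), figsize=(12, 4))
for ax, name in zip(axes, cmap_names):
    # Passing the name as a string lets Matplotlib look the colormap up
    im = ax.imshow(cm, interpolation="nearest", cmap=name)
    ax.set_title(name)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    fig.colorbar(im, ax=ax)
fig.tight_layout()
```

Sequential colormaps such as Blues or viridis tend to suit counts, because darker or brighter cells read directly as larger values. Call plt.show() to display the comparison.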
Plotting a Confusion Matrix for Multiclass Classification
In multiclass classification, the confusion matrix grows with the number of classes: five classes give a 5×5 grid. The plotting code stays the same; only the tick positions and labels change.
Example 5: Multiclass Confusion Matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = np.random.randint(0, 5, 100)
y_pred = np.random.randint(0, 5, 100)
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Accent)
plt.title('Multiclass Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(5)
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2', 'Class 3', 'Class 4'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2', 'Class 3', 'Class 4'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
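With many classes, the plot alone can be hard to read, so it may help to list the largest off-diagonal cells as text alongside it. The following is a sketch under assumptions and not part of the original example: top_confusions is a hypothetical helper name, and the 4×4 matrix is made-up data.

```python
import numpy as np


def top_confusions(cm, k=3):
    """Hypothetical helper: return the k largest off-diagonal cells.

    Each result is a (true_class, predicted_class, count) tuple.
    """
    off_diagonal = cm.copy()
    np.fill_diagonal(off_diagonal, 0)  # ignore correct predictions
    # Sort the flattened matrix, largest first, and keep the first k positions
    flat_order = np.argsort(off_diagonal, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat_order, off_diagonal.shape)
    return [(int(r), int(c), int(off_diagonal[r, c])) for r, c in zip(rows, cols)]


# Made-up confusion matrix: rows are true classes, columns are predicted classes
cm = np.array([
    [50,  2,  0,  3],
    [ 4, 40,  6,  0],
    [ 1,  9, 45,  0],
    [ 0,  0,  2, 38],
])

for true_class, predicted_class, count in top_confusions(cm):
    print(f"true {true_class} predicted as {predicted_class}: {count} samples")
```

Here the largest error is class 2 predicted as class 1 (9 samples), followed by class 1 predicted as class 2 (6) and class 1 predicted as class 0 (4).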
Adding a Grid to the Confusion Matrix
A grid can help in visually separating the classes in the confusion matrix.
Example 6: Confusion Matrix with Grid
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Oranges)
plt.title('Confusion Matrix with Grid - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1'])
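# Adding grid on the cell boundaries: the major ticks sit at the cell centers,
# so place minor ticks at the half-integer positions and draw the grid there
ax = plt.gca()
ax.set_xticks(np.arange(-0.5, cm.shape[1], 1), minor=True)
ax.set_yticks(np.arange(-0.5, cm.shape[0], 1), minor=True)
ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)
ax.tick_params(which='minor', length=0)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()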
plt.show()
Customizing Text Properties
You can customize the text properties for better readability.
Example 7: Customizing Text in Confusion Matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools
# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Purples)
plt.title('Custom Text Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])
# Customizing text properties
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    # Bold, larger text; white on dark cells and black on light cells
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm[i, j] > 0.5 * cm.max() else "black",
             fontweight='bold', fontsize=14)
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
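A fixed threshold on the count works for colormaps that run from light to dark, but it may pick a poor text color with other colormaps. One possible alternative, sketched below, is to choose the text color from the brightness of the cell's actual background color. The helper name text_color_for and the 0.5 threshold are assumptions, and the RGB weights are a common approximation of perceived brightness.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def text_color_for(im, value, threshold=0.5):
    """Hypothetical helper: pick black or white text from the cell's background."""
    # im.norm maps the value to the 0..1 range; im.cmap maps that to an RGBA color
    r, g, b, _ = im.cmap(im.norm(value))
    # Approximate perceived brightness of the background color
    brightness = 0.299 * r + 0.587 * g + 0.114 * b
    return "black" if brightness > threshold else "white"


y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation="nearest", cmap="Purples")
fig.colorbar(im, ax=ax)
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, format(cm[i, j], "d"),
                ha="center", va="center",
                fontweight="bold", fontsize=14,
                color=text_color_for(im, cm[i, j]))
```

Because the color is derived from the image's own colormap and normalization, the same helper should keep working if the colormap is swapped. Call plt.show() to display the figure.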
Using Different Shapes for the Confusion Matrix
Instead of the traditional square cells, you can use different shapes such as circles or ellipses to represent the values.
Example 8: Confusion Matrix with Ellipses
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
fig, ax = plt.subplots()
cmap = plt.cm.Blues
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        # Ellipse color encodes the count relative to the largest cell
        ellipse = patches.Ellipse((j, i), width=0.5, height=0.8, fill=True, color=cmap(cm[i, j] / cm.max()))
        ax.add_patch(ellipse)
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > 0.5 * cm.max() else "black")
ax.set_xlim(-0.5, cm.shape[1] - 0.5)
# Reverse the y-axis so row 0 is at the top, matching the imshow-based examples
ax.set_ylim(cm.shape[0] - 0.5, -0.5)
ax.set_xticks(np.arange(cm.shape[1]))
ax.set_yticks(np.arange(cm.shape[0]))
ax.set_xticklabels(['Class 0', 'Class 1', 'Class 2'])
ax.set_yticklabels(['Class 0', 'Class 1', 'Class 2'])
plt.title('Confusion Matrix with Ellipses - how2matplotlib.com')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.grid(False)
plt.show()
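For the circle variant mentioned at the start of this section, one possible approach is to encode the count in the size of each circle rather than in its color. The sketch below assumes a maximum radius of 0.45 so the largest circle stays inside its cell, and it scales the radius with the square root of the count so that the circle's area is proportional to the count. The color and transparency are arbitrary choices.

```python
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots()
max_radius = 0.45  # assumed value: keeps the largest circle inside its cell
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        # Square-root scaling makes the circle area proportional to the count
        radius = max_radius * np.sqrt(cm[i, j] / cm.max())
        ax.add_patch(patches.Circle((j, i), radius=radius, color="tab:blue", alpha=0.6))
        ax.text(j, i, format(cm[i, j], "d"), ha="center", va="center")

ax.set_xlim(-0.5, cm.shape[1] - 0.5)
ax.set_ylim(cm.shape[0] - 0.5, -0.5)  # reversed so row 0 is at the top
ax.set_aspect("equal")                # keep the circles round
ax.set_xticks(np.arange(cm.shape[1]))
ax.set_yticks(np.arange(cm.shape[0]))
ax.set_xticklabels(["Class 0", "Class 1", "Class 2"])
ax.set_yticklabels(["Class 0", "Class 1", "Class 2"])
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
```

Cells with a count of zero get a circle of radius zero, so only the annotation is visible there. Call plt.show() to display the figure.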
Adding Color Bars with Custom Labels
Color bars can be customized to show specific labels that describe the range of values.
Example 9: Confusion Matrix with Custom Color Bar Labels
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Greens)
plt.title('Confusion Matrix with Custom Color Bar - how2matplotlib.com')
cbar = plt.colorbar()
cbar.set_label('Number of Predictions', rotation=270, labelpad=20)
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
# Call tight_layout last so the axis labels are included in the layout
plt.tight_layout()
plt.show()
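Beyond the axis label, the tick positions and tick labels of the color bar can be set as well. The sketch below assumes a reasonably recent Matplotlib and places one tick per integer count, since the matrix holds whole numbers. The wording of the tick labels is purely illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation="nearest", cmap="Greens")
cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Number of Predictions", rotation=270, labelpad=20)

# Counts are integers, so place one tick per integer value
ticks = np.arange(cm.min(), cm.max() + 1)
cbar.set_ticks(ticks)
# Illustrative wording (an assumption): describe each tick value in words
cbar.set_ticklabels([f"{t} predictions" for t in ticks])

ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
fig.tight_layout()
```

Fixing the ticks this way avoids fractional values such as 0.5 or 1.5 on a color bar that represents counts. Call plt.show() to display the figure.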
Conclusion
In this article, we have explored various ways to plot a confusion matrix using Matplotlib in Python. We’ve covered basic plots, annotations, normalization, different color maps, shapes, and customization options. Each example provided is a complete, standalone code snippet that can be run independently to understand the different aspects of plotting confusion matrices. By mastering these techniques, you can effectively visualize the performance of classification models and gain deeper insights into their behavior.