How to Plot a Confusion Matrix in Matplotlib

Confusion matrices are a fundamental tool in machine learning for visualizing the performance of a classification algorithm. They show the actual versus predicted classifications in a grid, helping to identify where a model is making errors. In this article, we will explore how to plot a confusion matrix using the Matplotlib library in Python. We will provide detailed examples with complete code snippets that can be run independently.

Understanding Confusion Matrices

Before we dive into the code, let’s understand what a confusion matrix is. A confusion matrix is a table with two dimensions, “Actual” and “Predicted,” each with one entry per class in the classification problem. For a binary classification problem, the matrix is 2×2, and its four cells count the True Negatives (TN), False Positives (FP), False Negatives (FN), and True Positives (TP). scikit-learn’s confusion_matrix, which the examples in this article use, puts the actual classes on the rows and the predicted classes on the columns.
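
To make the four cells concrete, here is a minimal sketch that tallies them by hand for a binary problem, without any libraries. The helper name `binary_counts` and the sample labels are illustrative assumptions, and class 1 is treated as the positive class.

```python
def binary_counts(y_true, y_pred, positive=1):
    """Tally TP, FP, TN, FN for a binary problem (hypothetical helper)."""
    tp = fp = tn = fn = 0
    for actual, predicted in zip(y_true, y_pred):
        if predicted == positive:
            if actual == positive:
                tp += 1   # predicted positive, actually positive
            else:
                fp += 1   # predicted positive, actually negative
        else:
            if actual == positive:
                fn += 1   # predicted negative, actually positive
            else:
                tn += 1   # predicted negative, actually negative
    return tp, fp, tn, fn


y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]
tp, fp, tn, fn = binary_counts(y_true, y_pred)

# Rows are actual classes, columns are predicted classes
matrix = [[tn, fp],
          [fn, tp]]
print(matrix)  # [[2, 0], [2, 2]]
```

The two off-diagonal cells are the errors: here the model never raises a false alarm (FP = 0) but misses two of the four positives (FN = 2).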

Basic Confusion Matrix Plot

Let’s start with a basic example of plotting a confusion matrix using Matplotlib.

Example 1: Basic Confusion Matrix

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

# Sample data
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()


Adding Text Annotations

To make the confusion matrix more informative, we can add text annotations inside each square to show the counts.

Example 2: Confusion Matrix with Annotations

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools

# Sample data
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Wistia)
plt.title('Confusion Matrix with Annotations - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1'])

# Adding text annotations
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()

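
The plotting steps used so far (image, ticks, per-cell text, axis labels) repeat in most of the examples, so they can be wrapped in a single function. The following is a minimal sketch under that assumption; `plot_confusion_matrix` is a hypothetical helper written for this article, not a Matplotlib function, and it uses the object-oriented `fig, ax` interface instead of the `plt.*` calls above.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, class_names, title="Confusion Matrix", cmap="Blues"):
    """Draw an annotated confusion matrix and return (fig, ax).

    Hypothetical helper for illustration; not part of Matplotlib.
    """
    cm = np.asarray(cm)
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    fig.colorbar(im, ax=ax)

    # One tick per class, derived from the matrix shape rather than the data
    ax.set_xticks(np.arange(cm.shape[1]))
    ax.set_yticks(np.arange(cm.shape[0]))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticklabels(class_names)

    # White text on cells above half the maximum, black text elsewhere
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]), ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    ax.set_title(title)
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    fig.tight_layout()
    return fig, ax
```

A call such as `plot_confusion_matrix([[2, 0], [2, 2]], ["Class 0", "Class 1"])` followed by `plt.show()` should reproduce the annotated plot. The half-maximum threshold assumes a colormap that runs from light to dark, such as Blues; for other colormaps the text colors may need adjusting.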

Normalizing the Confusion Matrix

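Sometimes it is more useful to show proportions rather than counts. Dividing each row by its total expresses every cell as a fraction of the samples that truly belong to that class, which makes classes of different sizes directly comparable.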

Example 3: Normalized Confusion Matrix

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools

# Sample data
y_true = [1, 2, 0, 1, 2, 0]
y_pred = [0, 2, 0, 2, 2, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

# Plotting
plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Spectral)
plt.title('Normalized Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])

# Adding text annotations
for i, j in itertools.product(range(cm_normalized.shape[0]), range(cm_normalized.shape[1])):
    plt.text(j, i, "{:0.2f}".format(cm_normalized[i, j]),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm_normalized[i, j] > 0.5 else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()

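
One caveat with row-wise division: a class that never appears in y_true has a row that sums to zero, and dividing by it produces NaN values. The sketch below shows one way to guard against that with NumPy; `normalize_rows` is a hypothetical helper name and the sample matrix is made up for illustration. Recent scikit-learn versions also accept `normalize='true'` in `confusion_matrix` for row-wise normalization, which may be simpler if your installed version supports it.

```python
import numpy as np


def normalize_rows(cm):
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips the zero-sum rows, so they keep the zeros from `out`
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


cm = np.array([[2, 0, 0],
               [1, 0, 1],
               [0, 0, 0]])  # class 2 has no true samples
print(normalize_rows(cm))
```

Every non-empty row of the result sums to 1, so the diagonal entries can be read as per-class recall.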

Confusion Matrix with Different Color Maps

Changing the color map can help in highlighting different aspects of the confusion matrix.

Example 4: Confusion Matrix with a Different Color Map

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting (colormap names are case-sensitive: the attribute is `cool`, not `Cool`)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.cool)
plt.title('Confusion Matrix with Cool Color Map - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()

Plotting a Confusion Matrix for Multiclass Classification

In multiclass classification, the confusion matrix grows with the number of classes, but the plotting steps stay the same. The example below plots a 5×5 matrix built from randomly generated labels.

Example 5: Multiclass Confusion Matrix

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

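# Sample data: 100 random labels from 5 classes (seeded so the plot is reproducible)
np.random.seed(0)
y_true = np.random.randint(0, 5, 100)
y_pred = np.random.randint(0, 5, 100)

# Generating the confusion matrix; labels= keeps it 5x5 even if a class is absent
cm = confusion_matrix(y_true, y_pred, labels=np.arange(5))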

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.viridis)  # sequential map suits counts; Accent is qualitative
plt.title('Multiclass Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(5)
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2', 'Class 3', 'Class 4'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2', 'Class 3', 'Class 4'])

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()

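
With many classes, one of them can easily be missing from a particular sample. In that case scikit-learn infers a smaller matrix, which no longer lines up with a fixed list of tick labels. A minimal sketch of one way to pin the shape is to pass the full class list through the `labels` parameter of `confusion_matrix`; the sample data and the `class_names` list are illustrative assumptions.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

n_classes = 5
labels = np.arange(n_classes)
class_names = [f"Class {k}" for k in labels]

# Class 4 never occurs in this sample, neither as truth nor as prediction
y_true = [0, 1, 2, 3, 3, 1]
y_pred = [0, 1, 2, 3, 0, 2]

cm_inferred = confusion_matrix(y_true, y_pred)              # classes inferred from the data
cm_fixed = confusion_matrix(y_true, y_pred, labels=labels)  # one row and column per listed class

print(cm_inferred.shape, cm_fixed.shape)  # (4, 4) (5, 5)
```

Building the tick labels from the same class list, instead of from `len(set(y_true))`, keeps the ticks and the matrix in step.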

Adding a Grid to the Confusion Matrix

A grid can help in visually separating the classes in the confusion matrix.

Example 6: Confusion Matrix with Grid

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

# Sample data
y_true = [1, 0, 1, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Oranges)
plt.title('Confusion Matrix with Grid - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1'])

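# Adding grid lines on the cell boundaries (minor ticks at -0.5, 0.5, 1.5, ...);
# a grid on the major ticks would run through the cell centers instead
ax = plt.gca()
ax.set_xticks(np.arange(-0.5, cm.shape[1], 1), minor=True)
ax.set_yticks(np.arange(-0.5, cm.shape[0], 1), minor=True)
ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)
ax.tick_params(which='minor', length=0)  # hide the minor tick marks themselves

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout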
plt.show()


Customizing Text Properties

You can customize the text properties for better readability.

Example 7: Customizing Text in Confusion Matrix

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
import itertools

# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Purples)
plt.title('Custom Text Confusion Matrix - how2matplotlib.com')
plt.colorbar()
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])

# Customizing text properties
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             verticalalignment="center",
             color="white" if cm[i, j] > 0.5 * cm.max() else "black",
             fontweight='bold', fontsize=14)

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()


Using Different Shapes for the Confusion Matrix

Instead of the traditional square cells, you can use different shapes such as circles or ellipses to represent the values.

Example 8: Confusion Matrix with Ellipses

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import confusion_matrix
import numpy as np

# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
fig, ax = plt.subplots()
cmap = plt.cm.Blues
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ellipse = patches.Ellipse((j, i), width=0.5, height=0.8, fill=True, color=cmap(cm[i, j] / cm.max()))
        ax.add_patch(ellipse)
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > 0.5 * cm.max() else "black")

ax.set_xlim(-0.5, cm.shape[1] - 0.5)
ax.set_ylim(cm.shape[0] - 0.5, -0.5)  # inverted so row 0 sits at the top, as with imshow
ax.set_xticks(np.arange(cm.shape[1]))
ax.set_yticks(np.arange(cm.shape[0]))
ax.set_xticklabels(['Class 0', 'Class 1', 'Class 2'])
ax.set_yticklabels(['Class 0', 'Class 1', 'Class 2'])
plt.title('Confusion Matrix with Ellipses - how2matplotlib.com')
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.grid(False)
plt.show()


Adding Color Bars with Custom Labels

The color bar can be given its own label that describes what the color scale represents, such as the number of predictions in each cell.

Example 9: Confusion Matrix with Custom Color Bar Labels

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

# Sample data
y_true = [1, 2, 1, 0, 2, 0]
y_pred = [1, 0, 1, 0, 1, 0]

# Generating the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plotting
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Greens)
plt.title('Confusion Matrix with Custom Color Bar - how2matplotlib.com')
cbar = plt.colorbar()
cbar.set_label('Number of Predictions', rotation=270, labelpad=20)
tick_marks = np.arange(len(set(y_true)))
plt.xticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'], rotation=45)
plt.yticks(tick_marks, ['Class 0', 'Class 1', 'Class 2'])

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # call last so the axis labels are included in the layout
plt.show()

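
Beyond a single label, the tick marks on the color bar can be customized as well. Because the cells hold integer counts, one option is to place a tick at every whole number and give each tick its own text. The sketch below assumes the same 3×3 matrix as Example 9, written out by hand; the `n = ...` label format is an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Confusion matrix from Example 9, written out directly
cm = np.array([[2, 0, 0],
               [0, 2, 0],
               [1, 1, 0]])

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation="nearest", cmap="Greens")

# One color bar tick per integer count, each with a custom text label
cbar = fig.colorbar(im, ax=ax)
ticks = np.arange(cm.max() + 1)
cbar.set_ticks(ticks)
cbar.set_ticklabels([f"n = {t}" for t in ticks])
cbar.set_label("Number of Predictions", rotation=270, labelpad=20)
```

Calling `plt.show()` afterwards displays the figure. For matrices with large counts, a tick at every integer would be too dense, so a coarser step in `np.arange` is likely the better choice.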

Conclusion

In this article, we have explored various ways to plot a confusion matrix using Matplotlib in Python. We’ve covered basic plots, annotations, normalization, different color maps, shapes, and customization options. Each example provided is a complete, standalone code snippet that can be run independently to understand the different aspects of plotting confusion matrices. By mastering these techniques, you can effectively visualize the performance of classification models and gain deeper insights into their behavior.
