1. Introduction
Thanks to remarkable advances in Convolutional Neural Networks (CNNs) [1,2,3,4], real-life applications of deep learning have become a popular subject, and the feasibility of applications of artificial intelligence, such as self-driving vehicles [5,6,7], medical imaging [8,9,10], finance [11], and recommendation generation [12], has increased. However, deep neural network-based approaches have been treated as ‘black box’ functions due to their complex structure and consecutive non-linearities. Despite their high accuracy, they cannot justify their reliability, and interaction with users is also difficult. This opaque nature can cause problems when the model returns wrong decisions or malfunctions. Therefore, methods have been proposed to present insight into the deep learning model’s output. There are two main ways to describe a model. One is to explain the prediction of individual instances as a heatmap, and the other is to explain the behavior of the model.
Class Activation Mapping (CAM) and its variants [13,14,15] are widely used explanation methods for the prediction of single instances. A heatmap is generated that highlights the input pixels that had the most influence on decision-making by using a linear combination of the target-layer feature maps; thus, the definition of the coefficients for the linear combination is important. Grad-CAM [14] uses the gradient values of the target layer with respect to class c as the coefficients, and Score-CAM [15] uses the rate of change of the model’s prediction for class c when the input is perturbed. However, a heatmap based on CAM has the same size as the target-layer feature map, so the final heatmap looks blurry because of resizing. Thus, approaches with a relevance score [16,17,18] have been proposed to produce a heatmap in which each pixel represents its importance to the output of the model. They redistribute the relevance score from the final layer of the model to the input layer so that no resizing is required when creating the heatmap. However, they provide similar heatmaps for any class c, so their explanations look alike even if the predictions of the model differ, which makes the interpretation ambiguous.
To explain the output of a model, the following questions should be addressed: ‘Why did it judge that way?’, ‘How did it work?’, and ‘What are the possible failures of the model?’. However, methods based on single-instance explanation can only answer the first question, and they are only valid when the explanation for the correct label is valid.
To explain model behavior, research into model approximation and learned representations has been proposed. Local Interpretable Model-agnostic Explanations (LIME) [19] provides an explanation based on a local linear approximation of the model, but a linear model is not sufficient to approximate complex neural network structures, potentially resulting in poor performance for CNN approximation. Other approaches to interpreting the learned representation of a CNN, Testing with Concept Activation Vectors (TCAV) [20] and Automatic Concept-based Explanations (ACE) [21], attempt to convert the internal states of a neural network into human-friendly concepts based on concept vectors. The Prototypical Part Network (ProtoPNet) [22] tries to attain model transparency via training with prototypes that summarize each layer representation with a set of instances. However, these methods cannot answer the last of the three questions presented above. Furthermore, when using prototypes or concepts to interpret learned representations, it is necessary to define them, and additional partial datasets must be built for them.
Recently, some graph-based model explanation methods [23,24,25] have been used to disentangle each CNN filter response with probabilistic modeling and to learn a graphical model that reveals the hierarchy of CNN activations. They have the advantage of decomposing neural networks from a probabilistic point of view, but they require many hyperparameters, and the results are sensitive to how the hyperparameters are set. Moreover, they do not satisfy the third requirement of model explanation.
To solve this problem, we present an algorithm that draws a chain graph of a model, revealing the channel activations in each layer that play a role in decision boundaries, and a confusion information dictionary that reveals class relationships in which the difference between the formed features is relatively small. The proposed confusion graph is the chain graph of the activation maps that are common to, and play an important role in, decision changes of the model. Our method does not require any additional data or hyperparameter settings and satisfies the three requirements for an explanation method written as questions above. Our contributions are as follows.
Through unit channel suppression in two directions, so-called ‘correction’ and ‘violation’, we propose a confusion graph that connects the channels that are vulnerable on the decision boundary of each layer in the CNN.
We propose a confusion information dictionary that can reveal the class relationship in which the difference in the formed features is relatively small. Furthermore, it is possible to confirm the effect of the change in the feature level on the class prediction based on the proposed dictionary.
The proposed method is model agnostic, requires no additional data, and does not need any hyperparameters.
The rest of the paper is organized as follows. Section 2 presents previous work relevant to our research. Section 3 describes the construction details of the confusion graph on consecutive layers in deep convolutional neural networks. In Section 4, we comprehensively show how the proposed confusion graph affects the model in both quantitative and qualitative ways and depict the possible class confusions based on the intermediate sources of the graph construction. Finally, in Section 5, we conclude the paper with possible applications of our proposed method.
3. Proposed Method
Given the properties of visual representations made by CNNs [31], there is no doubt that each layer of a CNN takes part in image recognition. It also turns out that the channels of each layer can affect the accuracy for a particular class [27]. However, how they affect the relationships between classes has not been studied in detail. In this section, we measure the effect of each channel on the decision boundary in image classification by unit channel suppression and create a confusion graph that connects the channels that commonly cause confusion by associating the selected key channels according to the layer order.
3.1. Extract Key Channels for Class Confusion Based on Unit Channel Suppression
Our channel suppression sets the target channel to zero, because we want to observe the information loss that comes from the absence of the targeted channel. Figure 1 shows how unit channel suppression works on the i-th layer of a CNN model.
To measure the effect of an individual channel on classification confusion, it is necessary to define the types of decision conversion that can be observed by removing channels. There are three cases in which the class changes. ‘Violation Confusion’ occurs when the model makes correct predictions when the full feature is used but incorrect decisions when the unit channel is deleted. ‘Correction Confusion’ refers to the opposite case, in which the model makes the wrong prediction when the full feature is used, but the decision becomes correct when the unit channel is deleted. In the final confusion case, the model makes the wrong prediction when the full feature is used and also when channel deletion is performed, but the two predictions are different. Based on these observations, we collect the confusion information for the ‘violation’ and ‘correction’ cases to single out the channels that the model was sensitive to in terms of correct and incorrect predictions.
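To make the suppression and the confusion labeling concrete, the following PyTorch sketch zeroes out one channel of a chosen layer with a forward hook and categorizes the resulting decision change. It is a minimal illustration under our own naming assumptions (`predict_with_channel_suppressed`, `confusion_type`), not the released implementation.

```python
import torch

def predict_with_channel_suppressed(model, layer, channel, x):
    """Forward pass with one channel of `layer` zeroed out (unit channel suppression)."""
    def zero_channel(module, inputs, output):
        output = output.clone()
        output[:, channel] = 0.0      # suppress the target feature map
        return output                 # the returned tensor replaces the layer output

    handle = layer.register_forward_hook(zero_channel)
    try:
        with torch.no_grad():
            logits = model(x)
    finally:
        handle.remove()               # restore the full-featured model
    return logits.argmax(dim=1)

def confusion_type(label, pred_full, pred_suppressed):
    """Categorize the decision change caused by suppressing a single channel."""
    if pred_full == pred_suppressed:
        return None                   # no confusion
    if pred_full == label:
        return "violation"            # correct -> incorrect
    if pred_suppressed == label:
        return "correction"           # incorrect -> correct
    return "other"                    # incorrect -> a different incorrect class
```

Only the ‘violation’ and ‘correction’ cases are kept for the dictionary and graph construction described below.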
Suppose that we have a fixed target model f, an input x, and a label y. We can select the k layers of f for which we want to observe the key channels, and the set of their activation channels is defined as A = {a_1, …, a_M}, where M indicates the total number of channels across the k selected activations. For example, if we choose 3 layers with 64, 128, and 256 channels, then the set consists of a_1, …, a_64 belonging to the layer with 64 channels, a_65, …, a_192 composing the layer with 128 channels, and the remaining channels belonging to the final layer. Algorithm 1 shows the pseudo code of the confusion graph construction with confusion dictionary collection.
First, we collect the violation confusion information from each layer. The confusion information should contain the three items below (a minimal sketch of one possible dictionary layout follows the list):
The channel, together with the layer it belongs to, that causes the confusion.
The class predictions of the model before and after the confusion.
The images that exhibit the confusion while the unit channel is suppressed.
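A minimal sketch of one possible layout for this dictionary is shown below. The key structure and file names are illustrative assumptions; the paper does not prescribe a concrete data format.

```python
# Hypothetical layout: one sub-dictionary per inspected layer, keyed by a
# (channel index, prediction before suppression, prediction after suppression)
# confusion relation and storing the images that exhibited that confusion.
violation_dict = {
    0: {  # first selected layer
        (17, "ox", "cow"): ["img_0042.jpg", "img_0107.jpg"],
        (233, "wolf", "fox"): ["img_0015.jpg"],
    },
    # ... one entry for each remaining selected layer
}
```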
Algorithm 1: Pseudo Code for Proposed Confusion Graph Explanation Construction (Violation)
Input: fixed model f, inputs X, labels Y, model predictions f(X), selected channels A = {a_1, …, a_M}, zero-initialized neighbor matrices N
Output: violation confusion information dictionary D_v; violation confusion graph G_v
Step (1) Confusion relation collection
  D_v ← empty dictionary
  For (x, y) in (X, Y):
    For (i, a_c) in enumerate(A):
      If f(x) = y and f(x with a_c zeroed) ≠ y:  # collection criteria (violation)
        key ← (a_c, f(x), f(x with a_c zeroed))
        If key not in D_v[i]: D_v[i][key] ← [x]
        Else: append x to D_v[i][key]
Step (2) Filter common confusion relations
  Common_Confusion ← {key : key appears in D_v[i] for every selected layer i}
Step (3) Neighboring between consecutive layers
  For key in Common_Confusion:
    U ← channels of the upper layer with key in D_v
    L ← channels of the lower layer with key in D_v
    For (u, l) in U × L: N[u, l] ← N[u, l] + 1
Step (4) Edge extraction from the neighbor matrices
  For each neighbor matrix N:
    rows ← row coordinates of N where the value is max(N)
    cols ← column coordinates of N where the value is max(N)
    add the edges formed by (rows, cols) to G_v
Return D_v, G_v
Algorithm 2: Pseudo Code for Total Confusion Graph Construction
Input: violation confusion graph G_v from Algorithm 1; correction confusion graph G_c from Algorithm 1
Output: total confusion graph G_t
Init: G_t ← ∅
Process: For each edge e in G_v ∪ G_c: G_t ← G_t ∪ {e}
Return G_t
After repeating the information collection over the evaluation samples, we extract the violation confusion relations that appear in common across the layers in order to observe the most common feature-level confusions in the target model, and we extract the channels in each layer that induce these confusions. The selected channels are key channels since they cause the common confusions in each layer. Algorithm 2 shows the pseudo code of the total confusion graph construction based on the violation and correction confusion graphs.
3.2. Graph Construction Based on Key Channel Neighboring
From the acquired key channels of the layers, neighboring is performed by constructing a neighbor matrix between each pair of upper and lower layers to mine channels with the same role. Thus, for k selected layers we need k − 1 neighbor matrices to make consecutive connections between the layers.
Suppose that the zero-initialized neighbor matrix between an upper layer with m channels and a lower layer with n channels is an m × n matrix, meaning that no relations between the two layers have been recorded yet. If some channels in the upper layer and some channels in the lower layer had the same confusion relation, the cartesian product of the two lists of channels is taken, and each element corresponding to a cartesian coordinate in the neighbor matrix increases in value by 1, indicating that the channels constituting the coordinate play a common role. Thus, if we find n_c common confusions, the maximum value that each neighbor matrix element can have is n_c; this means that the two channels generate all the common confusions and therefore play similar roles in class confusion. Likewise, a larger weight in the neighbor matrix indicates that the two connected nodes take similar roles in class confusion. Therefore, when we extract the violation confusion graph, we extract the edges that have the highest weight in each neighbor matrix.
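A numpy sketch of this neighboring step is given below, assuming the key channels of the upper and lower layers have already been grouped by the common confusion relation they produce; the function and variable names are illustrative, not taken from the paper.

```python
import itertools
import numpy as np

def build_neighbor_matrix(n_upper, n_lower, common_relations,
                          upper_channels, lower_channels):
    """Accumulate co-occurrence weights between key channels of two adjacent layers.

    upper_channels[r] / lower_channels[r] list the key channels of the upper and
    lower layer that produced the common confusion relation r.
    """
    neighbor = np.zeros((n_upper, n_lower))
    for r in common_relations:
        for u, l in itertools.product(upper_channels[r], lower_channels[r]):
            neighbor[u, l] += 1       # +1 for every shared confusion relation
    return neighbor

def strongest_edges(neighbor):
    """Return the channel pairs whose weight equals the matrix maximum."""
    rows, cols = np.where(neighbor == neighbor.max())
    return list(zip(rows.tolist(), cols.tolist()))
```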
In the case of the correction confusion graph construction, all procedures are identical to the steps for generating a violation confusion graph. The only difference is that the confusion information to be collected is no longer a violation but a correction. The obtained violation graph and correction graph each show the connection relationships of the most confusion-related channels in non-overlapped confusion directions. Therefore, by merging these two graphs, we can obtain a chain graph of the network based on feature-level class confusions. Algorithm 1 is the pseudo code of our violation/correction confusion graph construction and Algorithm 2 is the pseudo code for the total confusion graph.
4. Experiments
For the following experiments, we use VGG16 [21] initialized with the ImageNet [22] pre-trained model and fine-tuned on the Animals with Attributes 2 (AwA2) dataset [23]. AwA2 has 37,322 images of 50 animal classes with pre-extracted feature representations for each image. To avoid confusion caused by class imbalance, we excluded the classes with fewer than 500 images and constructed the dataset by randomly sampling 500 images from each of the remaining 25 classes. The training and evaluation data were composed by dividing the dataset in an 8:2 ratio, and the graph analysis was performed on the evaluation data through unit channel suppression. The model was optimized with Stochastic Gradient Descent (SGD) with a learning rate of 0.01 and achieves 91.82% accuracy on the validation data. Our experiments were conducted on an NVIDIA Titan Xp 12 GB. We investigated the output of each convolution block and the classification layers in VGG16.
Moreover, we compare our work with [28], the state-of-the-art method for approximating the importance of neurons in a model, on the ImageNet validation data. For a fair comparison, we followed the same settings as [28] by dividing the released ImageNet validation set in half (25,000 images each) and using one part for graph construction and the other as the test set. Since the proposed method does not require any hyperparameters except the batch size, we set the batch size to 128, the same as in [28].
4.1. Confusion Graphs Based on Feature-Level Confusion and How They Affect the Model’s Decision
In this section, we describe the effect of our proposed graph on the model’s performance.
Figure 2 shows the violation and correction confusion graphs of VGG16. As mentioned above, we select the last convolution of each convolution block and the classification layers as our target layers. Compared to other studies whose results are complicated to visualize, our proposed graph selects a relatively small number of channels, less than 5% of the total, so the key channels can be drawn in a single figure.
To observe whether the proposed graph, which collects channels that behaved identically for a common confusion relationship, points to a weak point of the model, we observed the effect of each graph on model performance. We compared the accuracy of the model after deleting the channels that constitute the confusion graph with the accuracy after random channel deletion. To keep the comparison fair, the number of channels turned off in each layer for random deletion is the same as the number of channels turned off in that layer by the proposed confusion graph. We repeated this procedure five times to preserve the randomness of the experiment and report the average value.
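The matched random baseline can be sketched as follows, assuming the per-layer key channels of the confusion graph and the per-layer channel counts are available; the names are illustrative.

```python
import random

def matched_random_deletion(graph_channels, layer_sizes, seed=None):
    """Pick, for each layer, as many random channels as the confusion graph selects there.

    graph_channels[i] lists the key channels the graph uses in layer i, and
    layer_sizes[i] is the total number of channels in that layer.
    """
    rng = random.Random(seed)
    return {
        i: rng.sample(range(layer_sizes[i]), len(channels))
        for i, channels in graph_channels.items()
    }
```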
Table 1 shows the performance drop of the model with each type of graph and the full-featured model. The total number of channels in our selected layer was 9152, and the classification accuracy was 91.82% on the evaluation data.
The first column shows the accuracy of the model according to each state. The full-featured state indicates a state in which no zero-out manipulation is applied to the model, and violation, correction, and total confusion are the results of measuring the accuracy of the model after deleting each type of graph from the full-featured state. The second column shows the model accuracy after zero-out by randomly selecting the same number of channels as the proposed graph for each layer. The last column indicates the total number of channels dropped by each graph state.
The results show that deleting the 54 channels in the confusion graphs, less than 1% of the total number of channels, decreases the performance by 58.7%, dropping it to less than one-third of the original accuracy. Furthermore, when these 54 channels are deleted, the accuracy is 47.11% lower than that of random deletion of the same number of channels, indicating that the selected channels play a common, critical role. This suggests that the proposed graph points to the weak channels of the model, which should be modified to improve the robustness of the model in the future.
Table 2 shows the performance drop caused by the proposed graph and by Neuron Shapley (NShap) [28]. According to [28], the 25,000 ImageNet test images they used yielded 74% accuracy for the pretrained Inception-V3; dropping 10 filters lowered the model’s accuracy to 38%, and dropping 20 lowered it to as little as 8%. Since we also separate the ImageNet validation set into two parts, one for graph construction and one for testing, we report the performance of the pretrained Inception-V3 on our test part. We obtain 76% accuracy, which is indicated in parentheses in the last column of Table 2. The model performance observed when deleting the proposed graphs is shown in the 3rd to 5th rows of Table 2. Both the violation graph and the correction graph delete fewer than 20 neurons, and the violation graph lowers the accuracy to 0.75% while the correction graph lowers it to 2%. This implies that the proposed method can find more efficient neurons than [28]. Moreover, even though the total confusion graph uses a larger number of neurons than [28], it can completely destroy the model, reducing the accuracy to 0.1%. The time taken to obtain the graph is covered in Section 4.2.
4.2. Efficiency of Proposed Confusion Graph
Additionally, we checked the effect of each graph by performing random selection within each type of confusion graph. We randomly turn off x% of the channels in each graph and leave the remaining (100 − x)% of the channels in the graph untouched. For example, 0% selection on the violation graph means that none of the channels in the graph are zeroed out, and 90% selection means that we choose 14 × 0.9 = 12.6 ≈ 12 channels in the graph to be zeroed out and leave the remaining two channels with their own values. Thus, the larger x is, the more channels are deleted. We also compare the results with random channel deletion of the same number of channels.
Figure 3 shows the results of random channel selection on the violation, correction, and total confusion graphs, respectively. The blue line in each figure indicates the performance with x% selection of each graph, and the orange line is that of random channel deletion with the same number of channels. From each figure, it can be observed that even with fewer channels from the graph, the performance drop is larger than for random channel selection. Further, even if only about 50% of the channels in each graph are selected, the degradation is greater than the random-off accuracy in Table 1. This confirms that the channels found by the proposed method operate at the boundaries of the class features, which are the weak points of the model.
Efficiency not only indicates how each graph corruption affects model accuracy but also includes the time needed for graph creation. Therefore, we also analyzed the computational details of the proposed algorithm. To measure the importance of each neuron for the model’s performance, NShap formulates the problem as a Multi-Armed Bandit (MAB) problem and applies early truncation to speed up a single iteration of its algorithm. However, it iterates repeatedly until the algorithm converges. For example, it took 21 h and a total of 3000 iterations with 100 parallelized computing machines for their algorithm to converge and obtain the Shapley values of Inception-V3. This indicates that a lot of computing power is needed for their algorithm to converge. Meanwhile, the proposed method has no convergence problem, so it can be run on a single computing machine in a single iteration. It took 23 h 48 m on a single computing machine to identify the neurons for our graph. Although this is about three hours longer than [28], we used only one computing machine, far fewer than 100, which implies it is better suited to real-world applications.
It took 8 h 51 m to collect the confusion dictionary for VGG16 on a single machine, and this could be further reduced by half if multiple machines were used. Once the confusion dictionary is built, it takes only 31 s to assemble it into the proposed confusion graph, without GPU support.
The total number of forward passes is another way to compare computational complexity. According to [28], their algorithm can approximate the top 100 important neurons while skipping about 1500 of the 17,216 neurons in one iteration. Thus, they search about 15,000 neurons per batch and repeat this 3000 times, which requires about 4.5 × 10^7 forward passes. In contrast, the proposed method needs only 17,216 forward passes for each batch, which is far fewer than the state-of-the-art method.
4.3. Observation of Confusion Relations between Two Classes
In addition to the graph, since the proposed method concentrates on the role of each neuron, the proposed confusion information dictionary can provide clues for understanding feature-level class confusions.
When building a confusion matrix from image classification results, we cannot consider where each image lies at the feature level; we only reflect the output for that image. However, based on the confusion information dictionary, which records what kind of images caused class confusion when the features were modified by unit channel suppression, it is possible to distinguish the images that can be changed to another class at the feature level from those that cannot. Thus, from the confusion relations based on channel deletion, a closer confusion relationship between two specific classes can be revealed.
We experimented with ox and cow as an example. Figure 4 shows the ox–cow confusion relationship that can be seen through a simple confusion matrix, and Figure 5 shows the ox–cow confusion relationship drawn through the confusion information dictionary. In both figures, the blue frames represent samples from the cow class, and the red frames indicate samples from the ox class. Model decisions of ‘ox’ are placed on the left side of the figures, and decisions of ‘cow’ are placed on the right side. In Figure 4, we can observe the samples where the label was ‘ox’ but the model prediction was ‘cow’, and vice versa.
However, in Figure 5, we can observe the classification of the samples in greater detail. Here, the red arrows indicate cases where the decision changed from cow to ox during unit channel suppression, and the blue arrows indicate the opposite. Cases with dark brown or black objects predicted as cow show red arrows regardless of the presence or absence of horns. Moreover, if a cow is incorrectly predicted as an ox, the decision does not change at the feature level when the color of the object is dark brown or black. Similarly, for the blue arrows, if the object has a relatively pale color between white and brown, confusion occurs at the feature level between ox and cow regardless of the presence or absence of horns. Furthermore, in the case of a completely white ox, even if it is incorrectly predicted as cow, the decision does not change at all during channel suppression.
In addition to explaining the cause of confusion for the two classes in Figure 5, we also checked the concepts found by ACE. It should be noted that ACE is mainly used to reveal which concept was used to determine a specific class; it cannot by itself identify the confusion between two specific classes. Therefore, after identifying the confusion relationship with the proposed method, ACE can additionally be used to find the confused patches for the confused classes. Figure 6 presents concept patches from the ACE algorithm obtained with the purpose of distinguishing between ox and cow in our evaluation set. The collected patches are the subset of concepts that achieve more than 90% accuracy on cow and ox classification with a p-value < 0.05. Except for the patches related to the background, which have a green or sky-blue color, the color range of the ox patches is darker than that of the cow patches. However, the concept patterns are noisy, so the concepts from ACE are also difficult to translate into simple human language. Thus, we investigated the difference between ox and cow in the pre-defined predicates provided by AwA2. The color-related predicates (e.g., brown, black, white) are marked as common to ox and cow; therefore, color becomes a factor that can confuse the two classes. This indicates that class-discriminative characteristics can be inferred by observing the sample distribution, as shown in Figure 6, through the confusion dictionary of the proposed method. This can be helpful when devising a future training strategy, such as planning additional training with light-colored oxen and dark-colored cows, enabling the two classes to be distinguished irrespective of color.
4.4. Possible Failures Based on Feature Relation from Confusion Information Dictionary
The most common confusions in the collected information were used to build the confusion graphs, but the confusion information dictionary also records how the predictions change when any part of each feature is lost, over the whole range of channels.
For the i-th layer activation A_i, let A_i,c denote its c-th channel, and let Â_i,c denote the feature map with the c-th channel zeroed out. To investigate the importance of the suppressed channel within its layer when the decision of the model is changed by unit channel suppression, we measured two quantities. First, whenever the predicted outputs from the two features cause a confusion, we measured the cosine similarity between A_i and Â_i,c to see how much the feature changes in terms of angular similarity. The similarity is close to 1 when the two features point in the same direction, 0 when they are orthogonal, and close to −1 when they are opposite; thus, a value close to 1 means that A_i,c has small importance from the angular perspective. Second, we calculate the sum of the absolute values of A_i,c and divide it by the sum of the absolute values of A_i to observe the value proportion of A_i,c in A_i; this is also measured only when suppressing A_i,c induces a confusion. The larger this ratio, the larger the value occupied by A_i,c when the confusion occurs, and hence the greater the importance of the corresponding channel in terms of magnitude. Table 3 shows these measurements.
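The two measurements could be computed as in the sketch below, assuming the activation tensor has shape (channels, height, width); this is an illustration of the stated definitions rather than the released implementation.

```python
import torch

def channel_importance_metrics(activation, channel):
    """Angular and magnitude importance of one channel within a layer activation.

    activation has shape (C, H, W); both metrics compare the full feature with
    the same feature after zeroing out `channel`.
    """
    suppressed = activation.clone()
    suppressed[channel] = 0.0

    full, supp = activation.flatten(), suppressed.flatten()
    cosine = torch.dot(full, supp) / (full.norm() * supp.norm())  # close to 1 -> small angular change

    value_ratio = activation[channel].abs().sum() / activation.abs().sum()
    return cosine.item(), value_ratio.item()
```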
According to Table 3, even for a channel that changed the decision of the model, the proportion of the layer’s value occupied by that channel is small, and the difference between the zeroed-out and non-zeroed-out features is small. This can be seen from the value ratio being less than 1% for all layers and the cosine similarity before and after zero-out being more than 0.98. Therefore, the importance of the suppressed channels within each layer is weak. This result indicates that the difference between the unit-channel-suppressed feature and the full feature is small, and class confusion caused by this small difference means that the distinction between the two class features is small. Therefore, two classes in such a relationship can be regarded as possible classification failure cases for each other. Reflecting this relationship, in Table 4 and Table 5 we describe the class relationships that are closest at the feature level.
Table 4 examines whether the confusion relations that appear frequently in the classification results are also frequently confused at the actual feature level. The first column of Table 4 represents the top 10 ‘label-prediction’ relations with the highest number of confusions at the image level. The remaining columns indicate the correction and violation relations that are among the top 10 frequencies over all layers and that also appear among the ‘label-prediction’ confusions. Among them, bold font indicates relationships that appear in the first column, i.e., relationships that were confused at the image level also appear in the feature-level confusion. That is, when observing the features of two such classes, the intra-class bias and inter-class variance are smaller than in their relationships with other classes. These cases accounted for 65.04% of the feature-level confusions, and they should be a priority when improving model performance in the future.
Table 5 shows unique relationships that did not exist at the image level but only at the feature level. This information cannot be obtained when only the input–output relationship is observed, but it should not be ignored, because it accounts for 34.96% of the feature-level confusions and can be understood as a feature-forming relationship between classes. Therefore, like the relationships shown in Table 4, these relations also have smaller intra-class bias and inter-class variance than unseen relations, and they can thus be used to improve the robustness of the model.
5. Discussion
The proposed method can identify the model’s vulnerabilities in a reasonable time. Since these vulnerabilities are based on class confusion in the model, they can provide feedback for future model training. The method gives hints about the propensity of input images to be confused, as described in Section 4.3, or about which classes should have more separable relationships within the feature space, as described in Section 4.4. If feedback is given to the model through these hints, the model’s confusion will be relieved, naturally shrinking the proposed confusion graph.
As another kind of feedback, from the obtained confusion dictionary we can extract the channels related to a particular confusion relationship. By zeroing them out, the decision tendency of the model can be manipulated in the desired direction. For example, the VGG16 model used in our experiments predicts 104 ox and 93 cow samples when full-featured. If all channels that cause confusion from ox to cow are deleted based on the proposed confusion dictionary, the model predicts 0 ox and 296 cow samples. Conversely, if we delete the channels that confuse cow with ox, the model gives 3 predictions for cow and 94 for ox. This is natural, as the proposed graph is constructed from role-based neurons. However, since we investigated each neuron independently, rather than along the graph that connects them, the effect of consecutive information loss is less clear, so decision adjustment at the current stage tends not to be stable.
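A sketch of this kind of targeted intervention is shown below, reusing the hypothetical dictionary layout from Section 3.1; the selection key and helper name are assumptions, not the released code.

```python
def channels_for_relation(confusion_dict, src_class, dst_class):
    """Collect (layer, channel) pairs whose suppression flips src_class to dst_class."""
    selected = []
    for layer, relations in confusion_dict.items():
        for (channel, before, after), images in relations.items():
            if before == src_class and after == dst_class:
                selected.append((layer, channel))
    return selected

# e.g., suppress every channel whose removal turns an 'ox' prediction into 'cow':
# to_delete = channels_for_relation(violation_dict, "ox", "cow")
```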
Based on these examples, the proposed method can be applied to various fields, such as medical imaging. Explainable AI methods proposed in the general computer vision domain might be directly applicable to the medical area [32,33]. Several papers use model activations to indicate the location of lesions for model decision-making [34,35] or provide the importance of neurons for the output of a CNN [36]. The proposed method also analyzes the trained model regardless of the input domain, so it is applicable to models trained on medical images. In this case, the first kind of feedback discussed above is the most relevant: the model designer can ask the doctor for additional data based on the tendencies of the model revealed by the analysis. Whether this works in the clinical setting as well remains to be discussed.
6. Conclusions and Future Work
In this paper, we proposed a model-agnostic method to draw a graph of the key channels affecting decisions in a CNN. Unlike previous research that concentrated on the effect of neurons on the accuracy drop, we observe the role of neurons in class confusion. Even though the proposed graph involves only a small part of the model, removing it causes catastrophic performance degradation. Thus, the channels that compose the proposed graph can be seen as weak points of the model. In addition, it was confirmed that the channels whose suppression changes the decision of the model do not occupy a large portion of the actual feature layer. Based on this observation, and unlike conventional studies that concentrate on the impact of neurons on the model’s performance, it is possible to analyze the cause of confusion between two specific classes through the confusion information dictionary and to observe both the relationships that appear in image-level confusion and those that appear only at the feature level. Through this, the proposed analysis method can be used to formulate a strategy to improve the model in the future.
Furthermore, the proposed method has advantages in time and complexity compared to the state-of-the-art method. It can be run on a single computing machine in a reasonable time, and it does not require any tedious hyperparameter settings, so it is easier to use. However, because the proposed method treats each layer separately during information collection and mines consecutive-layer relations in a post hoc way, it cannot reflect the consecutive information loss in neighboring layers or the effect of a feedback loop on the constructed graph. Therefore, in future research, we will study how to build the graph more quickly by reflecting the continuous information relations between layers, and we will apply more detailed training strategies to the weakness channel graph to enhance the robustness of the model from the bias–variance perspective of each class feature cluster.