Interpretable Deep Neural Networks Part 1: Self Explaining Neural Networks

Deep learning-based systems have not been widely adopted in critical areas such as healthcare and criminal justice due to their lack of interpretability. In addition to high performance, interpretability is necessary to establish an appropriate level of trust in such systems. In this series of posts, we discuss five recent works on developing deep neural networks that are interpretable by design; that is, they incorporate the interpretability objective into the learning process. The methods discussed are self-explaining neural networks, ProtoAttend, concept whitening, a framework to learn with interpretation (FLINT), and entropy-based logic explanations of neural networks. We present and analyze their novelty and contributions as well as their potential drawbacks and gaps.

Introduction

The popularization of deep learning-based systems has led to unprecedented achievements in fields such as computer vision, speech recognition, natural language processing, and medical diagnosis. However, these systems are usually referred to as “black-box models” in the literature because of the high complexity of the non-linear functions they learn and the large number of parameters they require. Therefore, the predictions made by these models, although they may be accurate, cannot be traced by humans. As a consequence, they are considered unreliable for critical applications that deal with uncertainty.

Deep neural networks (DNNs), one popular deep learning technique, now power applications used daily by a very diverse group of people (e.g., mobile applications, route prediction, and search engines). As such, they no longer serve only small teams of researchers or expert users in highly specialized settings. Because DNNs behave in ways that are difficult for humans to understand, and because their role in inference tasks has grown rapidly, the field of explainable artificial intelligence (XAI) has attracted increasing interest in the last few years.

XAI encompasses two main approaches: explainability and interpretability. Recently, there has been a call for the use of interpretable models instead of explainable models for high-stakes decisions that deeply impact human lives (e.g., healthcare and criminal justice) [1]. In this context, interpretations allow humans to identify cause-effect relationships between the system’s inputs and outputs. Thus, in this series, we focus on recent interpretability methods; specifically, those whose goal is to make deep neural networks interpretable by design. In Part I, we describe the self-explaining neural networks proposed by Alvarez-Melis and Jaakkola [2]. Part II focuses on ProtoAttend, an attention-based prototypical learning method proposed by Arik and Pfister [3]. Part III discusses a method called concept whitening proposed by Chen, Bei, and Rudin [4]. Part IV describes a framework to learn with interpretation (FLINT) presented by Parekh, Mozharovskyi, and d’Alché-Buc [5]. In Part V, we discuss the entropy-based logic explanations of neural networks proposed by Barbiero et al. [6]. Open research questions are discussed in Part VI, where we also offer some concluding remarks.

Generalized Interpretable Linear Models

In this section, we build on simple linear regression models, which are widely accepted as interpretable models, and generalize them towards more complex models, as suggested by Alvarez-Melis and Jaakkola [2].

Consider a variable \(\textbf{x} \in \mathbb{R}^d\) that consists of \(d\) features \(\textbf{x} = (x_1, \dots , x_d)\). Let \(\theta = (\theta_0, \dots , \theta_d) \in \mathbb{R}^{d+1}\) be the associated regression parameters. Then, the linear regression model is given by \(f(\textbf{x}) = \sum_{i=0}^d \theta_i x_i = \theta^\top \textbf{x}\), where \(\textbf{x}\) is augmented with \(x_0 = 1\) so that \(\theta_0\) acts as the intercept. This model is interpretable because each parameter \(\theta_i\) quantifies the positive or negative contribution of the \(i\)-th feature to the prediction. Moreover, the feature-specific terms \(\theta_i x_i\) are aggregated additively; that is, the contribution of each feature for a given input \(\textbf{x}\) can be read independently of the other features.
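As a minimal sketch of this additive structure (the coefficients and input values below are hypothetical), the prediction of a linear model decomposes exactly into per-feature terms \(\theta_i x_i\), which is what makes each term readable on its own:

```python
import numpy as np

# Hypothetical coefficients (theta_0 is the intercept) for d = 3 features.
theta = np.array([0.5, 2.0, -1.0, 0.25])
x_raw = np.array([3.0, 4.0, 2.0])

# Augment the input with x_0 = 1 so that theta_0 acts as the intercept.
x = np.concatenate(([1.0], x_raw))

contributions = theta * x      # feature-specific terms theta_i * x_i
prediction = theta @ x         # f(x) = theta^T x

for i, c in enumerate(contributions):
    print(f"term {i}: {c:+.2f}")
print(f"f(x) = {prediction:.2f}")
```

Each printed term is the signed contribution of one feature; summing them recovers the prediction exactly.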

The flexibility of the linear model can be increased by allowing the coefficients \(\theta\) to depend on the input \(\textbf{x}\). The new model can then be expressed as \(f(\textbf{x}) = \theta(\textbf{x}) ^ \top \textbf{x}\), where \(\theta(\cdot)\) may be chosen from a complex model class (e.g., a neural network). In general, \(f(\textbf{x})\) is then no longer interpretable. However, interpretability can be maintained at a local level by ensuring that the coefficients of all inputs \(\textbf{x}\) that are neighbors of an input \(\textbf{x}'\) in \(\mathbb{R}^d\) do not differ significantly; that is, \(\nabla_\textbf{x} f(\textbf{x}) \approx \theta(\textbf{x}'), \;\; \forall \textbf{x} \in \mathcal{N}(\textbf{x}')\), where \(\mathcal{N}(\textbf{x}')\) denotes the neighborhood of \(\textbf{x}'\). Thus, the model acts locally around each \(\textbf{x}'\) as a linear model with stable, interpretable coefficients \(\theta(\textbf{x}')\).
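To make the local condition \(\nabla_\textbf{x} f(\textbf{x}) \approx \theta(\textbf{x}')\) concrete, the following toy sketch (our own construction, not taken from [2]) builds \(\theta(\textbf{x})\) as a small smooth perturbation of fixed base coefficients, scaled by a hypothetical factor `eps`, and compares a finite-difference gradient of \(f\) with \(\theta(\textbf{x}')\). The names `a`, `W`, and `eps` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
a = np.array([0.5, -1.0, 2.0])   # hypothetical "base" coefficients
W = rng.normal(size=(d, d))       # hypothetical weights inside theta(.)


def theta(x, eps):
    """Input-dependent coefficients: a smooth perturbation of a, scaled by eps."""
    return a + eps * np.tanh(W @ x)


def f(x, eps):
    """Generalized linear model f(x) = theta(x)^T x."""
    return theta(x, eps) @ x


def numerical_grad(fun, x, step=1e-5):
    """Gradient of a scalar function by central finite differences."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = step
        grad[i] = (fun(x + e) - fun(x - e)) / (2 * step)
    return grad


x_prime = np.array([1.0, 0.3, -0.7])
for eps in (1.0, 0.01):
    grad = numerical_grad(lambda z: f(z, eps), x_prime)
    gap = np.linalg.norm(grad - theta(x_prime, eps))
    print(f"eps = {eps}: ||grad f(x') - theta(x')|| = {gap:.4f}")
```

By the product rule, the gap equals \(\epsilon \lVert W^\top D \textbf{x}' \rVert\), where \(D\) holds the tanh derivatives, so it shrinks linearly as the coefficients vary more slowly: the model then behaves locally like a linear model with coefficients \(\theta(\textbf{x}')\).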

Furthermore, it may be difficult for humans to base their analyses on raw features. For example, pixels are the basic unit of digital images; however, humans rely on spatial structures and higher-order features to understand images. Alvarez-Melis and Jaakkola refer to these features as interpretable basis concepts [2]. Let us consider a function \(h: \mathcal{X} \rightarrow \mathcal{Z} \subset \mathbb{R}^k\) that maps the raw features from the input space \(\mathcal{X}\) into a latent space \(\mathcal{Z}\) of \(k\) interpretable atoms. Hence, the model is further generalized as follows: \begin{equation} f(\textbf{x}) = \sum_{i=1}^k \theta_i(\textbf{x}) h_i(\textbf{x}) = \theta(\textbf{x})^\top h(\textbf{x}), \end{equation} where $h_i(\textbf{x})$ can be interpreted as the degree to which the $i$-th concept is present in $\textbf{x}$. Here, locality means that the model is locally linear in the concepts rather than in the raw inputs $\textbf{x}$, with $\theta$ acting as the local coefficients (Eq. 1):
\begin{equation} \nabla_{h(\textbf{x})} f(\textbf{x}) \approx \theta(\textbf{x}'), \;\; \forall \textbf{x} \text{ s.t. } h(\textbf{x}) \in \mathcal{N}(h(\textbf{x}')) . \label{eq:locality} \end{equation}

The final generalization concerns how the terms $\theta_i(\textbf{x}) h_i(\textbf{x})$ are aggregated. Instead of aggregating them additively, we can use a function $g$ from a more flexible class (Eq. 2): \begin{equation} f(\textbf{x}) = g(\theta_1(\textbf{x}) h_1(\textbf{x}), \dots , \theta_k(\textbf{x}) h_k(\textbf{x})). \label{eq:generallinear} \end{equation}

For $g$ to preserve interpretability, it should obey certain principles, such as permutation invariance (to avoid uninterpretable higher-order effects caused by the relative position of the arguments) and preservation of the sign and relative magnitude of the impact of the relevance values $\theta_i(\textbf{x})$.

Self-explaining Neural Networks

Self-explaining neural networks (SENNs) [2] are complex models that are interpretable by design. They were designed to behave as generalized interpretable linear models. As such, they retain desirable characteristics of linear models without inheriting their limited predictive performance.

Methodology

SENNs build on the idea of local stability introduced in Eq. 1: we want \(\theta\) to be consistent for neighboring inputs. This suggests bounding the difference \(\lVert \theta(\textbf{x}) - \theta(\textbf{x}') \rVert\), for inputs whose concept vectors \(h(\textbf{x})\) and \(h(\textbf{x}')\) are neighbors, as follows (Eq. 3): \begin{equation} \lVert \theta(\textbf{x}) - \theta(\textbf{x}') \rVert \leq L \lVert h(\textbf{x}) - h(\textbf{x}') \rVert, \label{eq:SENNbounded} \end{equation} where \(L\) is some constant. In this case, we say that $\theta$ is locally difference-bounded by \(h\).
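The constant \(L\) in Eq. 3 can be probed empirically. The sketch below is a toy construction of our own: `h` is a random one-layer encoder and `theta` depends on the input only through \(h(\textbf{x})\) via a hypothetical matrix `B`. It takes the largest ratio \(\lVert \theta(\textbf{x}) - \theta(\textbf{x}') \rVert / \lVert h(\textbf{x}) - h(\textbf{x}') \rVert\) over sampled pairs whose concept vectors lie within a hypothetical `radius`. Such an estimate is only a lower bound on the smallest valid \(L\) for the sampled data, not a certificate.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k = 4, 2
A = rng.normal(size=(k, d))   # hypothetical concept-encoder weights
B = rng.normal(size=(k, k))   # hypothetical relevance weights


def h(x):
    """Toy concept encoder h: R^d -> R^k."""
    return np.tanh(A @ x)


def theta(x):
    """Toy relevance function that depends on x only through h(x)."""
    return B @ h(x)


def estimate_L(xs, radius, min_dist=1e-6):
    """Largest ratio ||theta(x) - theta(x')|| / ||h(x) - h(x')|| over pairs
    whose concept vectors are neighbors (min_dist avoids numerical noise)."""
    H = np.array([h(x) for x in xs])
    T = np.array([theta(x) for x in xs])
    best = 0.0
    for i in range(len(xs)):
        for j in range(i + 1, len(xs)):
            dh = np.linalg.norm(H[i] - H[j])
            if min_dist < dh < radius:
                best = max(best, np.linalg.norm(T[i] - T[j]) / dh)
    return best


xs = rng.normal(size=(200, d))
L_hat = estimate_L(xs, radius=0.5)
print(f"empirical L = {L_hat:.3f}, analytic bound ||B||_2 = {np.linalg.norm(B, 2):.3f}")
```

Because this `theta` is linear in the concepts, every ratio is bounded by the spectral norm of `B`; a relevance network that ignores \(h\) would offer no such guarantee.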

A self-explaining prediction has the general form shown in Eq. 2. The explanation for \(f(\textbf{x})\) is defined as the set \(\{ (h_i(\textbf{x}), \theta_i(\textbf{x})) \}_{i=1} ^ k\) of basis concepts and their influence scores. In addition, a self-explaining model must satisfy the following properties:

  • P1: The aggregation function \(g\) is monotone and additively separable.
  • P2: For every \(z_i = \theta_i(\textbf{x}) h_i(\textbf{x})\), \(g\) satisfies \(\frac{\partial g}{\partial z_i} \geq 0\).
  • P3: \(\theta\) is locally difference-bounded by \(h\).

Specifically, a SENN is a self-explaining model whose coefficients $\theta(\cdot)$ are realized by deep neural networks. Properties P1 and P2 are enforced by using a simple linear aggregation function, $g(z_1, \dots, z_k) = \sum_i A_i z_i$ with $A_i \geq 0$.
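A minimal forward-pass sketch of such a model follows. Only the aggregation \(g(z) = \sum_i A_i z_i\) with \(A_i \geq 0\) comes from the text; the layer sizes, activations, randomly initialized (untrained) weights, and the single scalar output are illustrative assumptions of ours, not the architecture used in [2].

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, hidden = 6, 3, 8

# Hypothetical, randomly initialized weights; a real SENN would learn these.
W_h = rng.normal(size=(k, d))       # concept encoder h(.)
W1 = rng.normal(size=(hidden, d))   # relevance network theta(.), layer 1
W2 = rng.normal(size=(k, hidden))   # relevance network theta(.), layer 2
A = np.ones(k)                      # aggregation weights, all A_i >= 0


def h(x):
    """Concept activations h_i(x), here squashed into (-1, 1)."""
    return np.tanh(W_h @ x)


def theta(x):
    """One relevance score per concept, from a small ReLU network."""
    return W2 @ np.maximum(0.0, W1 @ x)


def senn_forward(x):
    concepts, relevances = h(x), theta(x)
    z = relevances * concepts       # z_i = theta_i(x) * h_i(x)
    prediction = A @ z              # g(z) = sum_i A_i z_i (monotone, additive)
    explanation = list(zip(concepts, relevances))
    return prediction, explanation


x = rng.normal(size=d)
y_hat, expl = senn_forward(x)
for i, (c, r) in enumerate(expl):
    print(f"concept {i}: h = {c:+.3f}, theta = {r:+.3f}")
print(f"prediction = {y_hat:.3f}")
```

The returned explanation is exactly the set of (concept, relevance) pairs, and because every \(A_i\) is non-negative, increasing any \(z_i\) can never decrease the output.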

The remaining question is how to enforce P3. Given a point \(\textbf{x}'\), we want \(\theta(\textbf{x}')\) to behave as the derivative of \(f\) with respect to the concept vector, as stated in Eq. 1. However, it is difficult to impose the constraint \(\theta(\textbf{x}') \approx \nabla_{h(\textbf{x})} f(\textbf{x})\) explicitly. Therefore, Alvarez-Melis and Jaakkola [2] proposed to add a penalty term (Eq. 4):

\[\mathcal{L}_\theta (f(\textbf{x})) = \lVert \nabla_{\textbf{x}} f(\textbf{x}) - \theta(\textbf{x})^\top J_\textbf{x}^h(\textbf{x}) \rVert,\]

where \(J_\textbf{x}^h\) denotes the Jacobian of \(h\) with respect to \(\textbf{x}\). By the chain rule, \(\nabla_{\textbf{x}} f(\textbf{x}) = \nabla_{h(\textbf{x})} f(\textbf{x})^\top J_\textbf{x}^h(\textbf{x})\). Hence, minimizing \(\mathcal{L}_\theta\) drives \(\theta(\textbf{x})\) toward \(\nabla_{h(\textbf{x})} f(\textbf{x})\): when \(\mathcal{L}_\theta (f(\textbf{x})) \approx 0\) and \(J_\textbf{x}^h(\textbf{x})\) has full row rank (which requires \(k \leq d\)), then \(\theta(\textbf{x}) \approx \nabla_{h(\textbf{x})} f(\textbf{x})\).
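The penalty in Eq. 4 can be sketched as follows, using central finite differences as a stand-in for automatic differentiation. The toy encoder, the two relevance functions, and all weights are hypothetical. For \(f(\textbf{x}) = \theta(\textbf{x})^\top h(\textbf{x})\), the penalty reduces to \(\lVert J_\textbf{x}^\theta(\textbf{x})^\top h(\textbf{x}) \rVert\), so it vanishes when \(\theta\) is constant (then \(f\) is exactly linear in the concepts) and is generally positive otherwise.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 5, 3
W_h = rng.normal(size=(k, d))   # hypothetical concept-encoder weights
W_t = rng.normal(size=(k, d))   # hypothetical relevance-network weights


def h(x):
    return np.tanh(W_h @ x)


def theta_varying(x):
    return np.tanh(W_t @ x)


def theta_constant(x):
    return np.array([0.5, -1.0, 2.0])


def jacobian(fun, x, step=1e-5):
    """Central-difference Jacobian; column i holds d fun / d x_i."""
    cols = []
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = step
        cols.append((fun(x + e) - fun(x - e)) / (2 * step))
    return np.stack(cols, axis=-1)


def robustness_penalty(theta, x):
    """|| grad_x f(x) - theta(x)^T J_x^h(x) || for f(x) = theta(x)^T h(x)."""
    f = lambda z: np.atleast_1d(theta(z) @ h(z))
    grad_f = jacobian(f, x)[0]      # shape (d,)
    J_h = jacobian(h, x)            # shape (k, d)
    return np.linalg.norm(grad_f - theta(x) @ J_h)


x = rng.normal(size=d)
print(f"constant theta: {robustness_penalty(theta_constant, x):.2e}")
print(f"varying theta:  {robustness_penalty(theta_varying, x):.2e}")
```

In a real implementation, an autodiff framework would supply both the gradient and the Jacobian, and the penalty would be averaged over a batch.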

The gradient and the Jacobian in Eq. 4 can be computed through automatic differentiation. Thus, the final gradient-regularized objective is expressed as follows:

\[\mathcal{L} = \mathcal{L}_y (f(\textbf{x}), y) + \lambda \mathcal{L}_\theta (f(\textbf{x})),\]

where $\mathcal{L}_y (f(\textbf{x}), y)$ represents a typical classification or regression loss, $y$ represents the target variable, and $\lambda$ is a tunable parameter that trades off performance against stability. Fig. 1 depicts the process of estimating the concept and relevance values, and how they are aggregated to generate the class labels and the set of explanations.

Fig. 1. SENN architecture [2].

Learning Interpretable Basis Concepts

As explained in the second section of this post, raw data features may be transformed into basis concepts that are easier for humans to interpret.
Transformations from raw data into basis concepts can be based on expert knowledge or can be learned automatically. Alvarez-Melis and Jaakkola [2] presented a set of principles for learning interpretable concepts:

  1. Fidelity: Concepts should preserve relevant information from the original data.
  2. Diversity: Original data should be represented with a few non-overlapping concepts.
  3. Grounding: Concepts should have an immediate, human-understandable interpretation.

To satisfy the first principle (fidelity), the authors proposed training \(h\) (which is part of the SENN architecture) as part of an autoencoder. Specifically, the autoencoder has two components: an encoder that implements the function \(h\) (i.e., it encodes the data input \(\textbf{x}\) into concepts \(h(\textbf{x})\)) and a decoder \(h_{dec}\) that attempts to reconstruct the original data from the concept scores \(h(\textbf{x})\) (i.e., \(h_{dec}(h(\textbf{x})) = \hat{\textbf{x}}\)). Thus, an additional objective \(\mathcal{L}_h (\textbf{x}, \hat{\textbf{x}})\) is introduced to encourage the reconstructed data to be close to the original data. The loss function is now expressed as:

\[\mathcal{L}_y (f(\textbf{x}), y) + \lambda \mathcal{L}_\theta (f(\textbf{x})) + \xi \mathcal{L}_h (\textbf{x}, \hat{\textbf{x}}).\]

Diversity of concepts is encouraged by enforcing sparsity. However, the paper does not specify how sparsity is implemented. By analyzing the authors’ code repository, we found that \(\mathcal{L}_h (\textbf{x}, \hat{\textbf{x}})\) involves two penalties: the mean squared error (MSE) between \(\textbf{x}\) and \(\hat{\textbf{x}}\), and the L1 norm of \(h(\textbf{x})\), which is used for sparsity regularization.
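A minimal sketch of these loss terms follows. It assumes that the L1 penalty is summed over the concept activations and added to the MSE with a weight we call `sparsity_weight`; these names, the default weights, and the reductions are our own assumptions and may differ from the authors’ repository. The prediction loss \(\mathcal{L}_y\) and the robustness penalty \(\mathcal{L}_\theta\) are passed in as precomputed numbers.

```python
import numpy as np


def concept_loss(x, x_hat, concepts, sparsity_weight):
    """L_h: reconstruction MSE plus an L1 penalty on the concept activations."""
    mse = np.mean((x - x_hat) ** 2)
    l1 = np.sum(np.abs(concepts))
    return mse + sparsity_weight * l1


def senn_objective(pred_loss, robustness_penalty, x, x_hat, concepts,
                   lam=1e-2, xi=1e-1, sparsity_weight=1e-3):
    """L_y + lambda * L_theta + xi * L_h, with hypothetical default weights."""
    return (pred_loss + lam * robustness_penalty
            + xi * concept_loss(x, x_hat, concepts, sparsity_weight))


x = np.array([0.0, 1.0, 2.0, 3.0])
x_hat = np.array([0.0, 1.0, 2.0, 1.0])   # imperfect reconstruction
concepts = np.array([0.5, -0.5, 0.0])    # h(x)
print(concept_loss(x, x_hat, concepts, sparsity_weight=0.1))
print(senn_objective(0.3, 2.0, x, x_hat, concepts,
                     lam=0.5, xi=1.0, sparsity_weight=0.1))
```

Note that the L1 term only pushes individual activations \(h_i(\textbf{x})\) toward zero; it says nothing about whether two concepts encode overlapping information.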

Finally, the authors proposed a basic form of prototyping to satisfy the third principle (grounding). Given a representative sample of data \(\textbf{X}\), each concept \(i\) is represented by the subset \(\hat{\textbf{X}}^i\) of \(l\) samples that maximize its activation: \(\hat{\textbf{X}}^i = \arg \max_{\hat{\textbf{X}} \subseteq \textbf{X}, |\hat{\textbf{X}}| = l} \sum_{\textbf{x} \in \hat{\textbf{X}}} h_i(\textbf{x})\).
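Because the objective above is a sum of per-sample scores, it is maximized by simply taking the \(l\) samples with the largest \(h_i(\textbf{x})\). A short sketch, assuming the concept activations of the sample have already been stacked into a hypothetical matrix `H` of shape (number of samples, \(k\)):

```python
import numpy as np


def concept_prototypes(H, l):
    """For each concept i, return the indices of the l samples with the largest h_i(x).

    H has shape (n_samples, k): row j holds h(x_j). Since the objective is a
    sum of individual activations, the top-l samples maximize it exactly.
    """
    return [np.argsort(-H[:, i], kind="stable")[:l] for i in range(H.shape[1])]


# Toy activations for 4 samples and k = 2 concepts.
H = np.array([[0.10, 0.90],
              [0.80, 0.20],
              [0.50, 0.50],
              [0.95, 0.00]])
print(concept_prototypes(H, l=2))
```

Here concept 0 is represented by samples 3 and 1, and concept 1 by samples 0 and 2.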

Experimental Results

Experiments were carried out on open-access datasets such as MNIST (digit recognition) and benchmark UCI datasets. Explanations obtained by the SENN architecture were compared to those obtained by other interpretability methods such as local interpretable model-agnostic explanations (LIME) [7], kernel Shapley values (SHAP) [8], and \(\epsilon\)-layer-wise relevance propagation (\(\epsilon\)-LRP) [9].

Evaluation of the results is based on:

  1. Explicitness: Are the explanations immediate and understandable?
  2. Faithfulness: Are relevance scores indicative of true importance?
  3. Stability: How consistent are the explanations for similar examples?

The authors report that the SENN encoder is able to learn non-overlapping concepts. For example, Fig. 2 shows the results of using a SENN for digit recognition on the MNIST dataset. The network was trained to learn five concepts ($k = 5$), each represented by nine prototype samples ($l = 9$). Most of the samples within each concept belong to the same class (i.e., represent the same digit) and share similar features (e.g., the “8” of Concept 3 has a similar shape to that of all the other “7”s).

To verify the faithfulness of the trained models, the “true” influence of each feature needs to be known. The authors proposed to estimate these influence values by removing one feature at a time and measuring the resulting drop in the predicted probability. Then, they computed the correlation between these influence values and the estimated relevance scores (i.e., the parameters $\theta$). The SENN model obtained the highest correlation value.
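This evaluation protocol can be sketched on a toy logistic model whose relevance scores are \(\theta_i x_i\). All values are hypothetical, and “removing” a feature is implemented here by setting it to a baseline of 0, which is our assumption rather than a detail confirmed by the text.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical logistic model; its relevance scores are theta_i * x_i.
theta = np.array([1.5, -0.5, 0.8, 0.1, -1.2])
bias = 0.2
x = np.array([1.0, 2.0, -1.0, 3.0, -0.5])


def predict(x):
    return sigmoid(theta @ x + bias)


relevance = theta * x

# "Remove" feature i by setting it to 0 (an assumed baseline) and record
# the drop in predicted probability.
p = predict(x)
drops = []
for i in range(x.size):
    x_removed = x.copy()
    x_removed[i] = 0.0
    drops.append(p - predict(x_removed))
drops = np.array(drops)

faithfulness = np.corrcoef(relevance, drops)[0, 1]
print(f"Pearson correlation between relevance and probability drop: {faithfulness:.3f}")
```

For this model each drop is a strictly increasing function of the corresponding relevance score, so the correlation is positive and the two rankings agree; with correlated inputs, single-feature removal can be far less informative.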

Fig. 2. Comparison of classic input-based explanations and SENN’s concept-based ones for digit recognition on MNIST [2].

Critique

One of the most important contributions of the paper is that it lays out the general properties that any interpretable model should have. In addition, it was one of the first works to propose DNNs that are interpretable by design. It also provided a set of criteria (explicitness, faithfulness, and stability) for evaluating the explanations generated by any type of method.

Even though the experimental results showed improved interpretations over other classic approaches, a few aspects of the methodology and the evaluation could be improved. For instance, our notion of locality, expressed in Eq. 1 and Eq. 3, is not exactly the same as the one used by Alvarez-Melis and Jaakkola. The original paper states that the difference \(\lVert \theta(\textbf{x}) - \theta(\textbf{x}') \rVert\) is bounded by \(L \lVert h(\textbf{x}) - h(\textbf{x}') \rVert\) when \(\textbf{x}\) and \(\textbf{x}'\) are neighbors. However, the idea is that the model should be locally linear in the concepts rather than in the raw data. This means that neighboring concept vectors \(h(\textbf{x})\) and \(h(\textbf{x}')\) should have similar coefficients (\(\theta(\textbf{x}) \approx \theta(\textbf{x}')\)). We argue that \(h(\textbf{x})\) and \(h(\textbf{x}')\) being neighbors does not imply that \(\textbf{x}\) and \(\textbf{x}'\) are. Take the case of image understanding. Fig. 3 shows two images of the same type of bird; however, the pixel-wise similarity between them may be low due to differences in illumination, background, object location, and scale. In spite of these differences, their encoded concept values (e.g., related to feather color, beak size, and wing shape) could still be similar.

Fig. 3. Two images of the same type of bird. Pixel-wise comparison would yield a low similarity value even though the same basis concepts are present in both images.

SENNs were designed to behave as generalized interpretable linear models, which assume that the input features (whether raw data or basis concepts) are independent. However, this assumption is not always easy to satisfy. Two or more input features may be highly correlated (a phenomenon known as multicollinearity), yet removing features may discard information relevant to the learning task. In addition, no constraint ensures that the learned concepts are independent. The sparsity regularization applied in \(\mathcal{L}_h\) only encourages sparsity of the activations \(\{ h_i(\textbf{x})\}_{i=1}^k\) (which convey the degree to which each concept is present in an input \(\textbf{x}\)); it does not force the concepts themselves to be non-overlapping (i.e., the principle of “diversity” is not satisfied). As a consequence of using linear models with possibly correlated input features, the coefficients \(\theta\) may not accurately quantify the contribution of each feature.
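The effect of multicollinearity on linear coefficients can be illustrated with synthetic data of our own: two nearly identical features, a target that depends on their shared signal, and ordinary least squares (`np.linalg.lstsq`) standing in for a linear model’s coefficients. This is a generic illustration, not an experiment from [2].

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
x1 = rng.normal(size=n)
x2 = x1 + 1e-3 * rng.normal(size=n)       # an almost exact copy of x1
y = 2.0 * x1 + 0.1 * rng.normal(size=n)   # only the combined effect is identifiable

X = np.column_stack([x1, x2])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
print(f"condition number of X: {np.linalg.cond(X):.0f}")
print(f"coefficients: {coef}, sum: {coef.sum():.3f}")
```

The sum of the two coefficients is well determined (close to 2), but the split between them is dominated by noise, so neither coefficient on its own reliably quantifies its feature’s contribution; rerunning with a different seed changes the split.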

Furthermore, one of the principles proposed to evaluate interpretable models is faithfulness. Although it would be ideal to compare the relevance scores \(\theta\) to the “true” influence values of the input features, this is not always possible. The authors estimated the features’ relevance by removing one feature at a time and calculating the drop in the predicted probability. As discussed above, the input features may be correlated, in which case this approach would not estimate their influence accurately. Take the case of hyperspectral band selection. Here, we deal with hundreds of highly correlated input features (spectral bands). For this reason, estimating the relevance of individual bands (and thus selecting a subset of the most relevant spectral bands) is a complex task that requires exploratory approaches [10].
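A small hypothetical example shows how one-at-a-time removal can miss redundant features. The model below relies on either of two perfectly correlated inputs (via a maximum), and removal is again implemented as setting features to an assumed baseline of 0. Removing either feature alone leaves the prediction unchanged, while removing both changes it substantially.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def predict(x):
    """Hypothetical model that can rely on either of two redundant features."""
    return sigmoid(5.0 * max(x[0], x[1]))


x = np.array([1.0, 1.0])   # two perfectly correlated inputs
p = predict(x)


def drop_after_removing(indices):
    x_removed = x.copy()
    x_removed[list(indices)] = 0.0   # baseline of 0 is an assumption
    return p - predict(x_removed)


print("remove x1:  ", drop_after_removing([0]))
print("remove x2:  ", drop_after_removing([1]))
print("remove both:", drop_after_removing([0, 1]))
```

Single-feature removal would assign both inputs zero influence even though, together, they drive the prediction.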

Finally, we note that explaining predictions by means of concept prototypes defined as representative data samples does not provide a complete understanding of the model’s decision. For example, in Fig. 2, given an input image that represents the digit “9”, the interpretation is that the network classified it as such because it is more similar to Concept 3 than to the other concepts. Although this seems logical, it does not explain why the same input was not classified as a “7”, which would activate the same concept (i.e., the principle of “grounding” is not satisfied). Therefore, we argue that representative data samples are not ideal prototypes; instead, we should focus on learning new interpretable representations.

References

  1. C. Rudin, “Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead,” Nature Machine Intelligence, vol. 1, no. 5, pp. 206–215, 2019.
  2. D. Alvarez-Melis and T. Jaakkola, “Towards robust interpretability with self-explaining neural networks,” in Advances in Neural Information Processing Systems, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds., vol. 31. Curran Associates, Inc., 2018.
  3. S. O. Arik and T. Pfister, “ProtoAttend: Attention-based prototypical learning,” The Journal of Machine Learning Research, vol. 21, no. 1, pp. 8691–8725, 2020.
  4. Z. Chen, Y. Bei, and C. Rudin, “Concept whitening for interpretable image recognition,” Nature Machine Intelligence, vol. 2, no. 12, pp. 772–782, Dec. 2020.
  5. J. Parekh, P. Mozharovskyi, and F. d’Alché-Buc, “A framework to learn with interpretation,” in Advances in Neural Information Processing Systems, M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, Eds., vol. 34, 2021, pp. 24273–24285.
  6. P. Barbiero, G. Ciravegna, F. Giannini, P. Liò, M. Gori, and S. Melacci, “Entropy-based logic explanations of neural networks,” Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, no. 6, pp. 6046–6054, Jun. 2022.
  7. M. Ribeiro, S. Singh, and C. Guestrin, “Why should I trust you?: Explaining the predictions of any classifier,” in Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Demonstrations, San Diego, California, Jun. 2016, pp. 97–101.
  8. S. M. Lundberg and S.-I. Lee, “A unified approach to interpreting model predictions,” in Advances in Neural Information Processing Systems, I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, Eds., vol. 30. Curran Associates, Inc., 2017.
  9. S. Bach, A. Binder, G. Montavon, F. Klauschen, K.-R. Müller, and W. Samek, “On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation,” PLOS ONE, vol. 10, no. 7, pp. 1–46, Jul. 2015.
  10. G. Morales, J. W. Sheppard, R. D. Logan, and J. A. Shaw, “Hyperspectral dimensionality reduction based on inter-band redundancy analysis and greedy spectral selection,” Remote Sensing, vol. 13, no. 18, 2021.
Written on October 23, 2022