Interpretable Deep Neural Networks Part 4: FLINT

The approaches discussed so far are set in the context of conventional supervised learning: they use a single prediction model and aim to enhance its interpretability. In contrast, in this fourth part of our review on interpretable deep neural networks by design, we discuss FLINT (a framework to learn with interpretation), proposed by Parekh, Mozharovskyi, and d’Alché-Buc [1], which jointly learns a predictive model and an associated interpretation model. This gives rise to a new generic task that the authors call supervised learning with interpretation (SLI).

Fig. 1. (Left) General view of FLINT. (Right) Detailed view of FLINT [1].

Methodology

In this section, we define the task of SLI, which is the learning paradigm used by FLINT. Then we describe the architecture of FLINT and how interpretations are obtained from it, as well as how interpretability properties are imposed.

Supervised Learning with Interpretation

Let \(S = \{ (\textbf{x}_i, y_i) \}_{i=1}^N\) denote the training set consisting of \(N\) sample-target pairs (\(\textbf{x}_i \in \mathbb{R}^d\)), and let \(\mathcal{X}\) and \(\mathcal{Y}\) denote the input and output spaces, respectively. In SLI, interpretation is treated as a task distinct from prediction: it requires a dedicated interpreter model \(g\), which nonetheless depends on the predictive model \(f\). The learning objective therefore extends the conventional supervised loss with an explicit interpretability term:

\[\arg \min_{f, g} \mathcal{L}_{pred}(f, S) + \mathcal{L}_{int}(f, g, S),\]

where \(\mathcal{L}_{pred}(f, S)\) is a loss related to the prediction error and \(\mathcal{L}_{int}(f, g, S)\) measures the quality of the interpretations provided by \(g\) for the predictions generated by \(f\).

Architecture

The predictive model \(f\) is a DNN with \(l\) hidden layers, each denoted as \(f_k\) (\(k \in \{ 1, \dots, l\}\)). The interpreter \(g\) receives as input the concatenated features generated by a set of intermediate layers of \(f\) with indices \(\mathcal{I} = \{ i_1, \dots, i_T \} \subset \{ 1, \dots, l\}\). The concatenated vector of all intermediate outputs for an input sample \(\textbf{x}\) is denoted as \(f_\mathcal{I} (\textbf{x})\). The network \(g\) then computes a dictionary of \(J\) attribute functions \(\boldsymbol\Phi: \mathcal{X} \rightarrow \mathbb{R}^J\) and an interpretable function \(h: \mathbb{R}^J \rightarrow \mathcal{Y}\) such that \(g(\textbf{x}) = h (\boldsymbol\Phi (\textbf{x}))\). Here, \(h\) takes the form \(h(\boldsymbol\Phi (\textbf{x})) = \text{softmax}(W^\top \boldsymbol\Phi (\textbf{x}))\), where \(W \in \mathbb{R}^{J \times C}\) is the weight matrix of a fully connected layer and \(C\) is the number of classes.

The attribute dictionary consists of functions \(\phi_j : \mathcal{X} \rightarrow \mathbb{R}^+\) (\(j \in \{ 1, \dots, J\}\)) whose outputs \(\phi_j(\textbf{x})\) are interpreted as the activation of some basis concept. The function \(\boldsymbol\Phi\) is computed by applying a deep neural network \(\boldsymbol\Psi\) (with three hidden layers) to the vector \(f_\mathcal{I} (\textbf{x})\), that is, \(\boldsymbol\Phi(\textbf{x}) = \boldsymbol\Psi (f_\mathcal{I} (\textbf{x}))\). The parameters of \(g\) are denoted by \(\boldsymbol\Theta_g = (\theta_\Psi, \theta_h)\). Fig. 1 shows a general and a detailed view of the FLINT architecture.
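To make the data flow concrete, the following is a minimal pure-Python sketch of the interpreter's forward pass \(g(\textbf{x}) = h(\boldsymbol\Psi(f_\mathcal{I}(\textbf{x})))\). It assumes that \(f_\mathcal{I}(\textbf{x})\) is already available as a flat list (here, two hypothetical intermediate outputs filled with random numbers), that \(\boldsymbol\Psi\) is a plain fully connected network with ReLU activations (the actual layer types used in [1] may differ), and that a final ReLU keeps the attributes non-negative. All function names and dimensions are illustrative.

```python
import math
import random

def relu(v):
    return [max(0.0, a) for a in v]

def dense(weights, bias, v):
    # weights is a list of rows with shape (out_dim, in_dim)
    return [sum(w * a for w, a in zip(row, v)) + b for row, b in zip(weights, bias)]

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in z]
    total = sum(exps)
    return [e / total for e in exps]

def random_dense(out_dim, in_dim, rng):
    weights = [[rng.gauss(0.0, 0.5) for _ in range(in_dim)] for _ in range(out_dim)]
    return weights, [0.0] * out_dim

def interpreter_forward(f_I_x, psi_layers, W):
    """Return Phi(x) = Psi(f_I(x)) and h(Phi(x)) = softmax(W^T Phi(x))."""
    a = f_I_x
    for weights, bias in psi_layers:
        a = relu(dense(weights, bias, a))  # the last ReLU makes every phi_j(x) >= 0
    phi = a
    J, C = len(W), len(W[0])  # W has shape (J, C)
    logits = [sum(W[j][c] * phi[j] for j in range(J)) for c in range(C)]
    return phi, softmax(logits)

rng = random.Random(0)
# Hypothetical f_I(x): concatenation of two intermediate outputs of f (sizes 5 and 3).
f_I_x = [rng.gauss(0.0, 1.0) for _ in range(5)] + [rng.gauss(0.0, 1.0) for _ in range(3)]
dims = [8, 6, 6, 6, 4]  # input size, three hidden layers, J = 4 attributes
psi_layers = [random_dense(dims[k + 1], dims[k], rng) for k in range(len(dims) - 1)]
W = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(4)]  # J = 4, C = 3
phi, probs = interpreter_forward(f_I_x, psi_layers, W)
```

Because \(h\) is a single linear layer followed by a softmax, each class score is a weighted sum of attribute activations, which is what the relevance scores of the next section build on.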

Interpretation in FLINT

Two types of interpretation are sought: global and local. Global interpretations reveal which attribute functions are most relevant for predicting a given class, while local interpretations convey which attribute functions are relevant for the prediction of a given sample.

Given an input \(\textbf{x}\), \(f\) and \(g\) are required to predict the same class, \(\hat{y} = f(\textbf{x}) = g(\textbf{x})\). The contribution of attribute \(\phi_j\) to this prediction is \(\alpha_{j, \hat{y}, \textbf{x}} = \phi_j(\textbf{x}) W_{j, \hat{y}}\), and its local relevance score is the normalized contribution \(r_{j, \textbf{x}} = \frac{\alpha_{j, \hat{y}, \textbf{x}}}{\max_i | \alpha_{i, \hat{y}, \textbf{x}} |}\). The local interpretation for sample \(\textbf{x}\) is then the set of attribute functions whose \(r_{j, \textbf{x}}\) exceeds a threshold \(1 / \tau\) (\(\tau > 1\)).
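The local relevance computation can be sketched in a few lines. The helper `local_interpretation` and the toy values of \(\boldsymbol\Phi(\textbf{x})\) and \(W\) are hypothetical, and the predicted class \(\hat{y}\) is assumed to be one on which \(f\) and \(g\) agree.

```python
def local_interpretation(phi, W, y_hat, tau):
    """Local relevance r_{j,x} of each attribute and the indices j with r_{j,x} > 1/tau.

    phi: attribute activations Phi(x) (length J); W: J x C weight matrix of h;
    y_hat: predicted class, assumed to be the same for f and g.
    """
    # Contribution of attribute j to the predicted class: alpha_j = phi_j(x) * W[j][y_hat]
    alpha = [p * W[j][y_hat] for j, p in enumerate(phi)]
    denom = max(abs(a) for a in alpha)
    if denom == 0.0:  # no attribute contributes to the prediction
        return [0.0] * len(alpha), []
    r = [a / denom for a in alpha]
    relevant = [j for j, rj in enumerate(r) if rj > 1.0 / tau]
    return r, relevant

phi = [1.0, 0.0, 0.5, 0.25]
W = [[2.0, -1.0], [1.0, 0.5], [1.5, 0.0], [-1.0, 3.0]]
r, relevant = local_interpretation(phi, W, y_hat=0, tau=2.0)
# r == [1.0, 0.0, 0.375, -0.125], relevant == [0]
_, relevant_loose = local_interpretation(phi, W, y_hat=0, tau=4.0)
# relevant_loose == [0, 2]
```

Since \(1/\tau > 0\), attributes with negative contributions are never selected, so increasing \(\tau\) (lowering the threshold) only admits additional attributes that positively support \(\hat{y}\).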

The global relevance of an attribute \(\phi_j\) for class \(c\) is obtained by averaging local relevance values over the subset \(S_c\) of the training set for which the predicted class is \(c\): \(r_{j, c} = \frac{1}{| S_c |} \sum_{\textbf{x} \in S_c} r_{j, \textbf{x}}\). The global interpretation provided by \(g\) is thus the set of class-attribute pairs \((c, \phi_j)\) whose global relevance exceeds the threshold \(1/ \tau\).
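Assuming the local relevance vectors have already been computed for the training samples, the global step reduces to a per-class average followed by thresholding. The sketch below takes hypothetical `(predicted_class, local_relevance_vector)` pairs as input; the function name and toy values are illustrative.

```python
from collections import defaultdict

def global_interpretation(samples, num_attributes, tau):
    """samples: (predicted_class, local_relevance_vector) pairs, one per training sample.

    Returns the global relevances r_{j,c} and the (class, attribute) pairs with r_{j,c} > 1/tau.
    """
    sums = defaultdict(lambda: [0.0] * num_attributes)
    counts = defaultdict(int)
    for c, r in samples:
        counts[c] += 1
        for j, rj in enumerate(r):
            sums[c][j] += rj
    # Average over S_c, the samples whose predicted class is c.
    r_global = {c: [s / counts[c] for s in sums[c]] for c in sums}
    pairs = [(c, j) for c in sorted(r_global)
             for j, rjc in enumerate(r_global[c]) if rjc > 1.0 / tau]
    return r_global, pairs

# Toy local relevances for three training samples (two predicted as class 0, one as class 1).
samples = [(0, [1.0, 0.0, 0.5]), (0, [0.5, 0.0, 1.0]), (1, [0.0, 1.0, -0.5])]
r_global, pairs = global_interpretation(samples, num_attributes=3, tau=2.0)
# r_global == {0: [0.75, 0.0, 0.75], 1: [0.0, 1.0, -0.5]}
# pairs == [(0, 0), (0, 2), (1, 1)]
```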

Imposing Interpretability Properties

As in Section “Learning Interpretable Basis Concepts” (Part 1), three interpretability principles are imposed, each of which is translated into a penalty term of \(\mathcal{L}_{int}\):

  • Fidelity to output: The output of \(g\) should be close to that of \(f\): \(\mathcal{L}_{of} = -\sum_{\textbf{x} \in S} g(\textbf{x})^\top \log (f(\textbf{x}))\).
  • Conciseness and diversity: Conciseness refers to using few attributes per interpretation and is promoted by minimizing the entropy of each sample's attribute vector, \(H(\boldsymbol\Psi (f_\mathcal{I} (\textbf{x})))\). Diversity refers to activating different attributes across samples and is promoted by maximizing the entropy of the average of \(\boldsymbol\Psi (f_\mathcal{I} (\textbf{x}))\) over a mini-batch. In addition, the \(\ell_1\) norm of the attributes is minimized to constrain the magnitude of the attribute activations: \(\mathcal{L}_{cd} = -H\left(\frac{1}{|S|}\sum_{\textbf{x} \in S} \boldsymbol\Psi (f_\mathcal{I} (\textbf{x}))\right) + \sum_{\textbf{x} \in S} H(\boldsymbol\Psi (f_\mathcal{I} (\textbf{x}))) + \eta \sum_{\textbf{x} \in S} \| \boldsymbol\Psi (f_\mathcal{I} (\textbf{x})) \|_1\), where \(\eta\) is a hyperparameter.
  • Fidelity to input: Inspired by the autoencoders used by SENN (Part 1), a decoder network \(\text{dec}\) is used to reconstruct \(\textbf{x}\) from the encoded attributes: \(\mathcal{L}_{if} = \sum_{\textbf{x} \in S} \| \text{dec}(\boldsymbol\Psi (f_\mathcal{I} (\textbf{x}))) - \textbf{x} \|_2^2\).

Finally, the loss for interpretability is \(\mathcal{L}_{int} = \beta \mathcal{L}_{of} + \gamma \mathcal{L}_{if} + \delta \mathcal{L}_{cd}\), where \(\beta\), \(\gamma\), and \(\delta\) are hyperparameters.
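The three penalty terms and their weighted sum can be evaluated directly for a mini-batch. This sketch works on plain Python lists rather than tensors, so it only evaluates the losses and does not train anything; in practice they would be computed inside an automatic-differentiation framework. Normalizing each attribute vector to sum to 1 before taking its entropy, and the epsilon inside the logarithm, are assumptions of this sketch; all names and hyperparameter values are hypothetical.

```python
import math

def entropy(v):
    """Shannon entropy of a non-negative vector after normalizing it to sum to 1 (assumption)."""
    total = sum(v)
    if total == 0.0:
        return 0.0
    return -sum((a / total) * math.log(a / total) for a in v if a > 0.0)

def loss_cd(batch_phi, eta):
    """-H(mean attribute vector) + sum of per-sample entropies + eta * sum of l1 norms."""
    n, J = len(batch_phi), len(batch_phi[0])
    mean_phi = [sum(phi[j] for phi in batch_phi) / n for j in range(J)]
    per_sample = sum(entropy(phi) for phi in batch_phi)
    l1 = sum(sum(abs(a) for a in phi) for phi in batch_phi)
    return -entropy(mean_phi) + per_sample + eta * l1

def loss_of(g_probs, f_probs, eps=1e-12):
    """Fidelity to output: -sum_x g(x)^T log f(x); eps guards against log(0)."""
    return -sum(sum(gc * math.log(max(fc, eps)) for gc, fc in zip(g, f))
                for g, f in zip(g_probs, f_probs))

def loss_if(reconstructions, inputs):
    """Fidelity to input: sum_x ||dec(Psi(f_I(x))) - x||_2^2."""
    return sum(sum((r - x) ** 2 for r, x in zip(rec, inp))
               for rec, inp in zip(reconstructions, inputs))

def loss_int(g_probs, f_probs, batch_phi, reconstructions, inputs,
             beta, gamma, delta, eta):
    """L_int = beta * L_of + gamma * L_if + delta * L_cd."""
    return (beta * loss_of(g_probs, f_probs)
            + gamma * loss_if(reconstructions, inputs)
            + delta * loss_cd(batch_phi, eta))

# Toy mini-batch of two samples with J = 2 attributes and C = 2 classes.
batch_phi = [[1.0, 0.0], [0.0, 1.0]]
g_probs = [[0.5, 0.5], [0.5, 0.5]]
f_probs = [[0.5, 0.5], [0.5, 0.5]]
reconstructions = [[1.0, 2.0], [0.0, 0.0]]
inputs = [[1.0, 1.0], [0.0, 0.0]]
total = loss_int(g_probs, f_probs, batch_phi, reconstructions, inputs,
                 beta=1.0, gamma=1.0, delta=1.0, eta=0.1)
```

In this toy batch each sample activates a single attribute (per-sample entropy 0) while the batch as a whole uses both attributes (mean entropy \(\log 2\)), which is exactly the combination that \(\mathcal{L}_{cd}\) rewards.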

Experimental Results

The experiments considered four datasets: MNIST, FashionMNIST, CIFAR-10, and QuickDraw. FLINT was compared to SENN [2] and to PrototypeDNN [3], an autoencoder-based DNN with a dedicated layer that encodes prototypes.

The methods were compared quantitatively in terms of predictive accuracy, fidelity of the interpreter (the percentage of samples for which the predictions of a model and its interpreter coincide), and conciseness of interpretations (the average number of relevant attributes). FLINT achieved significantly higher classification accuracy and fidelity than SENN and PrototypeDNN, which may be attributed to the more complex network architectures it uses. In addition, FLINT produced more concise interpretations than SENN (conciseness cannot be measured for PrototypeDNN), requiring significantly fewer relevant attributes for all tested values of the threshold \(1/ \tau\).
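The two interpretability metrics are straightforward to compute once predicted classes and local relevance scores are available. The sketch below follows the definitions given above; the function names and inputs are hypothetical, and conciseness is averaged over whatever set of samples is passed in.

```python
def fidelity(f_preds, g_preds):
    """Percentage of samples for which the predicted classes of f and g coincide."""
    agree = sum(1 for a, b in zip(f_preds, g_preds) if a == b)
    return 100.0 * agree / len(f_preds)

def conciseness(local_relevances, tau):
    """Average number of relevant attributes (r_{j,x} > 1/tau) per sample."""
    counts = [sum(1 for rj in r if rj > 1.0 / tau) for r in local_relevances]
    return sum(counts) / len(counts)

f_preds = [0, 1, 2, 1]
g_preds = [0, 1, 1, 1]
print(fidelity(f_preds, g_preds))  # 75.0
local_relevances = [[1.0, 0.375, -0.125], [1.0, 0.9, 0.6]]
print(conciseness(local_relevances, tau=2.0))  # 2.0
```

Evaluating `conciseness` over a range of \(\tau\) values corresponds to the threshold sweep used in the comparison above; a lower conciseness value means fewer attributes are needed per interpretation.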

The qualitative analysis examined the global and local interpretations generated by FLINT. Since this work focused on image classification, the encoded concepts were presented as sets of visual patterns in the input space that highly activate certain attributes. To this end, for each class-attribute pair \((c, \phi_j)\) in the global interpretation, the three training samples from class \(c\) that maximally activate \(\phi_j\) were selected; this subset of maximally activating samples is referred to as MAS. Each element of MAS was then analyzed through activation maximization with partial initialization (AM+PI) [4], which synthesizes an input that maximally activates \(\phi_j\) starting from an initialization derived from the MAS sample. Fig. 2 depicts some class-attribute pairs with global relevance \(r_{j,c} > 0.2\) for each dataset, showing the three MAS samples together with their corresponding AM+PI maps.
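The selection of MAS and the idea behind activation maximization can be illustrated on a toy problem. In this sketch, `select_mas` ranks the samples of class \(c\) by \(\phi_j\); the attribute itself is replaced by a hypothetical linear-ReLU function of a 3-dimensional input, the objective is \(\phi_j(\textbf{x}) - \lambda \|\textbf{x}\|_2^2\), and the "partial initialization" is assumed to be a scaled-down copy of a MAS sample. The actual AM+PI procedure optimizes image pixels through the full network and typically relies on image priors such as those discussed in [4], so this only conveys the mechanics.

```python
def select_mas(activations, labels, c, j, k=3):
    """Indices of the k samples of class c with the largest phi_j activation."""
    candidates = [i for i, y in enumerate(labels) if y == c]
    return sorted(candidates, key=lambda i: activations[i][j], reverse=True)[:k]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def toy_phi(w, x):
    # Hypothetical differentiable attribute: phi(x) = relu(w . x)
    return max(0.0, dot(w, x))

def activation_maximization(w, init, lam=0.1, lr=0.1, steps=200):
    """Gradient ascent on toy_phi(w, x) - lam * ||x||^2, starting from init."""
    x = list(init)
    for _ in range(steps):
        active = 1.0 if dot(w, x) > 0.0 else 0.0  # derivative of the ReLU
        grad = [active * wi - 2.0 * lam * xi for wi, xi in zip(w, x)]
        x = [xi + lr * gi for xi, gi in zip(x, grad)]
    return x

# Toy attribute activations (rows: samples, columns: attributes) and labels.
activations = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.5, 0.4], [0.9, 0.0]]
labels = [0, 0, 0, 0, 1]
mas = select_mas(activations, labels, c=0, j=1)  # -> [0, 2, 3]

w = [1.0, -0.5, 0.25]
mas_sample = [0.2, 0.1, 0.4]           # hypothetical input of one MAS element
init = [0.5 * v for v in mas_sample]   # assumed "partial" initialization
x_am = activation_maximization(w, init)
```

The L2 penalty keeps the synthesized input bounded (here it converges towards \(\textbf{w} / (2\lambda)\)); without some regularization, activation maximization of a ReLU attribute would grow the input indefinitely.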

Local interpretations for random test samples are shown in Fig. 3. The left side shows the top three relevant attributes with their corresponding AM+PI maps, while the right side shows how similar classes activate the same concept; for example, images of trucks and cars activate the concept of “wheels” similarly.

Fig. 2. Example class-attribute pair analysis [1].

Fig. 3. (Left) Local interpretations for test samples. (Right) Examples of attribute functions detecting the same part across various test samples.

Critique

FLINT was proposed as a new paradigm for the interpretation of DNNs. Unlike previous approaches, it delegates interpretation to a separate interpreter network trained with dedicated losses. One advantage of FLINT over previous methods such as SENN is that it does not restrict the complexity of the network architecture and therefore avoids a trade-off between accuracy and interpretability. In addition, the concepts it learns are not restricted to simple prototypes defined as representative data samples, as was the case for SENN and ProtoAttend; the attribute functions learned by the interpreter \(g\) can encode more abstract concepts. This is promising for cases where a dictionary of concepts is not known a priori, as assumed by methods such as CW.

Even though the proposed framework offers new ways to study the interpretability problem, some aspects remain unclear. For instance, the paper does not discuss how to choose the optimal set of intermediate layers \(\mathcal{I}\) whose feature maps serve as inputs to the interpreter network. Selecting \(\mathcal{I}\) may be difficult, especially when the network is very deep and involves hundreds of layers, and it is not clear whether different choices of \(\mathcal{I}\) would yield drastically different or less interpretable concepts. Another potential drawback is the need to tune several hyperparameters, namely \(\eta\), \(\beta\), \(\gamma\), \(\delta\), and \(\tau\). The experiments did not show how sensitive the results are to these hyperparameters or how computationally expensive it is to tune them.

Finally, the interpretability of AM+PI maps is debatable. For example, Fig. 2 includes two MNIST samples of the digit “1” whose AM+PI maps show artifacts near the image borders, in regions that are blank in the original images and thus contain no useful information for the classification task. This indicates that synthetic inputs generated to maximally activate a given attribute are not necessarily interpretable.

References

  1. J. Parekh, P. Mozharovskyi, and F. d’Alché-Buc, “A framework to learn with interpretation,” in Advances in Neural Information Processing Systems, M. Ranzato, A. Beygelzimer, Y. Dauphin, P. Liang, and J. W. Vaughan, Eds., vol. 34, 2021, pp. 24273–24285.
  2. D. Alvarez-Melis and T. Jaakkola, “Towards robust interpretability with self-explaining neural networks,” in Advances in Neural Information Processing Systems, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds., vol. 31. Curran Associates, Inc., 2018.
  3. O. Li, H. Liu, C. Chen, and C. Rudin, “Deep learning for case-based reasoning through prototypes: A neural network that explains its predictions,” in Proceedings of the 32nd AAAI Conference on Artificial Intelligence and Thirtieth Innovative Applications of Artificial Intelligence Conference and Eighth AAAI Symposium on Educational Advances in Artificial Intelligence, 2018.
  4. A. Mahendran and A. Vedaldi, “Visualizing deep convolutional neural networks using natural pre-images,” Int. J. Comput. Vision, vol. 120, no. 3, pp. 233–255, Dec. 2016.
Written on May 31, 2023