Interpretable Deep Neural Networks Part 2: ProtoAttend

The second part of our review on interpretable-by-design deep neural networks focuses on ProtoAttend, a method proposed by Arik and Pfister [1]. ProtoAttend aims to build inherently interpretable deep neural networks based on the principle that “prototypes should constitute a minimal subset of samples with a high interpretable value that can serve as distillation or condensed view of a dataset”. The ProtoAttend mechanism can be integrated into a wide range of neural network architectures. It generates encoded representations and relates them to data samples through an attention mechanism that selects the prototypes.

Methodology

Consider a training set that consists of data-target pairs \(\{\textbf{x}_i, y_i \}\). The goal is to ensure that decisions are made based on a small number of prototypes. To this end, the following principles were proposed:

  1. The network has an encoder component, represented by \(v_i = f(\textbf{x}_i; \theta_f)\), that encodes the information from \(\textbf{x}_i\) needed for the final decision. \(\theta_f\) is a set of learnable parameters.
  2. The decision component is represented by \(g(v_i; \theta_g)\). It maps the \(v_i\) values into the target space so that \(g(v_i; \theta_g) \approx y_i\). \(\theta_g\) is a set of learnable parameters.
  3. Given a sample \(\textbf{x}_i\) and the prototype candidates \(\{\textbf{x}_j^{(c)} \}_{j=1}^D\), there exist weights \(p_{i,j}\) that relate \(\textbf{x}_i\) to each candidate \((p_{i,j} \geq 0, \sum_{j=1}^D p_{i,j} = 1)\) such that \(g(\sum_{j=1}^D p_{i,j} v_j^{(c)}; \theta_g)\) is close to \(y_i\), where \(v_j^{(c)} = f(\textbf{x}_j^{(c)}; \theta_f)\) is the encoded value of the \(j\)-th candidate.
  4. Prototypes with higher weights have a higher contribution to the decision.
  5. Weights should be sparse.
  6. Weights depend on the relation between the input and the candidates, i.e., \(p_{i,j} = r(\textbf{x}_i, \textbf{x}_j^{(c)}; \theta_r)\), where \(\theta_r\) is a set of learnable parameters.

Fig. 1. ProtoAttend method for training and testing [1]. During training, the relational mechanism is taught to select the most relevant prototypes by comparing the encoded values of the input and the prototype candidates. During testing, the encoded values of the input are not needed, as only those from the prototypes are used.

Fig. 1 shows the ProtoAttend network architecture and how it behaves during training and testing. Its three main blocks are described below:

-Encoder: The encoder component, denoted as \(E\), is a trainable feature mapping function. It is used to process both the input samples and the samples from the database of prototype candidates. \(E\) transforms the input batch \(\{\textbf{x}_i \}_{i=1}^B\) into values \(V \in \mathbb{R}^{B \times d_{out}}\) (only during training) and queries \(Q \in \mathbb{R}^{B \times d_{att}}\). In addition, \(E\) transforms the candidate batch \(\{ \textbf{x}_j^{(c)} \}_{j=1}^D\) into values \(V^{(c)} \in \mathbb{R}^{D \times d_{out}}\) and keys \(K^{(c)} \in \mathbb{R}^{D \times d_{att}}\).

-Relational Attention: The relational attention component yields the weights \(p_{i, j}\) by aligning the query of each input sample with the key of each candidate: \(p_{i,j} = n \left( K_j^{(c)} Q_i^\top / \sqrt{d_{att}} \right)\), where \(n(\cdot)\) is a normalization function (e.g., softmax) applied over the \(D\) candidates so that \(\sum_{j=1}^D p_{i,j} = 1\).

-Decision Making: This is the final block, which consists of a linear mapping \(g\) applied to a combination of the “values” of the input samples and the prototype candidates: \(\hat{y}_i(\alpha) = g \left( (1 - \alpha) v_i + \alpha \sum_{j=1}^D p_{i,j} v_j^{(c)} \right)\), where \(\alpha \in [0, 1]\) is a mixing coefficient. When \(\alpha = 0\), the prototypes are ignored and the relational mechanism is not trained (equivalent to conventional supervised learning). When \(\alpha = 1\), the encoded value of the input is not taken into account, and the decision relies only on the prototypes. To train both the encoder and the relational mechanism, the authors proposed a loss function of the form \(\mathcal{L} = L(y, \hat{y}(0)) + L(y, \hat{y}(1)) + L(y, \hat{y}(0.5))\), where \(L(\cdot)\) is a penalty function such as cross-entropy.
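
The relational attention and decision-making steps, together with the three-term loss, can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the implementation from [1]: the encoder outputs \(Q\), \(V\), \(K^{(c)}\), and \(V^{(c)}\) are taken as given lists, the decision function \(g\) is a hypothetical linear layer with weights `W` and bias `b`, \(L\) is cross-entropy, softmax is used as \(n(\cdot)\), and the loss is averaged over the batch. All function names are illustrative.

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of scores."""
    m = max(z)
    e = [math.exp(a - m) for a in z]
    s = sum(e)
    return [a / s for a in e]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def relational_attention(Q, K_c, normalize=softmax):
    """p[i][j] = n(K_j^(c) . Q_i / sqrt(d_att)), normalized over the D candidates."""
    scale = math.sqrt(len(Q[0]))
    return [normalize([dot(k, q) / scale for k in K_c]) for q in Q]

def mixed_value(v_i, p_i, V_c, alpha):
    """(1 - alpha) * v_i + alpha * sum_j p_ij * v_j^(c)."""
    proto = [sum(p * v[k] for p, v in zip(p_i, V_c)) for k in range(len(v_i))]
    return [(1 - alpha) * a + alpha * b for a, b in zip(v_i, proto)]

def linear(W, b, v):
    """Hypothetical linear decision layer g: logits = W v + b."""
    return [dot(row, v) + bias for row, bias in zip(W, b)]

def cross_entropy(logits, label):
    """-log softmax(logits)[label], computed with the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(a - m) for a in logits))
    return log_z - logits[label]

def task_loss(Q, V, K_c, V_c, labels, W, b):
    """Batch mean of L(y, y_hat(0)) + L(y, y_hat(1)) + L(y, y_hat(0.5))."""
    P = relational_attention(Q, K_c)
    total = 0.0
    for v_i, p_i, y in zip(V, P, labels):
        for alpha in (0.0, 1.0, 0.5):
            logits = linear(W, b, mixed_value(v_i, p_i, V_c, alpha))
            total += cross_entropy(logits, y)
    return total / len(labels)

# Toy batch: B = 1 input, D = 2 candidates, d_att = d_out = 2, two classes.
Q, V = [[1.0, 0.0]], [[0.5, -0.5]]
K_c, V_c = [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]
W, b = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
print(relational_attention(Q, K_c))  # one row of weights summing to 1
print(task_loss(Q, V, K_c, V_c, [0], W, b))
```

At test time only \(\hat{y}_i(1)\) (here, `mixed_value(v_i, p_i, V_c, 1.0)`) is needed, and in that term the input's own value `v_i` has zero weight, which matches the testing path in Fig. 1.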

The goal is to represent each input sample with a few representative prototypes; hence, the prototype weights should be sparse. Similar to the loss function used to train SENNs (see the previous [post](https://giorgiomorales.github.io/Interpretable-Deep-Neural-Networks-Part-1-Self-Explaining-Neural-Networks/)), sparsity can be encouraged through regularization. Here, the regularization term added to the loss function aims to reduce the entropy of the weights: \(L_{sparse} (\textbf{p}) = -\frac{1}{B} \sum_{i=1}^B \sum_{j=1}^D p_{i, j} \log (p_{i, j} + \epsilon)\), where \(\epsilon\) is a small number for numerical stability. The loss function then becomes \(\mathcal{L} = L(y, \hat{y}(0)) + L(y, \hat{y}(1)) + L(y, \hat{y}(0.5)) + \lambda_{sparse} L_{sparse}\), where \(\lambda_{sparse}\) is a weighting coefficient.
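
The entropy regularizer can be sketched directly from its definition. In this minimal Python illustration, `sparsity_loss` and `regularized_loss` are hypothetical names, the default \(\epsilon = 10^{-8}\) is an arbitrary choice, and \(\lambda_{sparse}\) must be supplied by the caller because no particular value is assumed here. A weight vector concentrated on one prototype gives a value near 0, while weights spread uniformly over \(D\) candidates give a value near \(\log D\).

```python
import math

def sparsity_loss(P, eps=1e-8):
    """Entropy regularizer: -(1/B) * sum_i sum_j p_ij * log(p_ij + eps)."""
    B = len(P)
    return -sum(p * math.log(p + eps) for p_i in P for p in p_i) / B

def regularized_loss(task_loss_value, P, lambda_sparse):
    """Three-term task loss plus lambda_sparse * L_sparse (lambda chosen by the caller)."""
    return task_loss_value + lambda_sparse * sparsity_loss(P)

peaked = [[1.0, 0.0, 0.0, 0.0]]       # all weight on one prototype
uniform = [[0.25, 0.25, 0.25, 0.25]]  # weight spread over four candidates
print(sparsity_loss(peaked))   # close to 0
print(sparsity_loss(uniform))  # close to log(4)
```

Minimizing this term therefore pushes each row of weights toward a few dominant prototypes.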

In a classification setting, a prediction deserves more confidence when the prototypes with the highest weights belong to the predicted class. Thus, a confidence score based on the agreement between prototypes was proposed: \(C_i = \sum_{j=1}^D p_{i, j} \cdot \textbf{I}(y_j^{(c)} = \hat{y}_i)\), where \(\textbf{I}(\cdot)\) is the indicator function. This score is used in a new penalty term, \(L_{conf}(\textbf{p}) = -\frac{1}{B} \sum_{i=1}^B C_i\). The complete loss function is (Eq. 1): \begin{equation} \mathcal{L} = L(y, \hat{y}(0)) + L(y, \hat{y}(1)) + L(y, \hat{y}(0.5)) + \lambda_{sparse} L_{sparse} + \lambda_{conf} L_{conf}. \end{equation}
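
The confidence score and its penalty term can be sketched as follows. This is a minimal Python illustration with hypothetical function names: the predicted labels \(\hat{y}_i\) are taken as given, so the indicator acts as a fixed mask over the candidates and only the weights \(p_{i,j}\) vary. The last function simply assembles Eq. 1 from precomputed terms, with \(\lambda_{sparse}\) and \(\lambda_{conf}\) left to the caller.

```python
def confidence_scores(P, candidate_labels, predicted_labels):
    """C_i = sum_j p_ij * I(y_j^(c) == y_hat_i) for every input i."""
    return [
        sum(p for p, y_c in zip(p_i, candidate_labels) if y_c == y_hat)
        for p_i, y_hat in zip(P, predicted_labels)
    ]

def confidence_loss(P, candidate_labels, predicted_labels):
    """L_conf = -(1/B) * sum_i C_i."""
    C = confidence_scores(P, candidate_labels, predicted_labels)
    return -sum(C) / len(C)

def total_loss(task, l_sparse, l_conf, lambda_sparse, lambda_conf):
    """Eq. 1: three task terms + lambda_sparse * L_sparse + lambda_conf * L_conf."""
    return task + lambda_sparse * l_sparse + lambda_conf * l_conf

# Two inputs, three candidates labeled 0, 1, 0; predictions are classes 0 and 1.
P = [[0.7, 0.2, 0.1], [0.5, 0.5, 0.0]]
print(confidence_scores(P, [0, 1, 0], [0, 1]))  # about [0.8, 0.5]
print(confidence_loss(P, [0, 1, 0], [0, 1]))    # about -0.65
```

In the first row, the two candidates labeled 0 carry 0.7 + 0.1 of the weight, so the prediction of class 0 receives a confidence of 0.8.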

Experimental Results

Experiments were carried out on image, text, and tabular data classification problems. For image classification, the MNIST, Fashion-MNIST, and CIFAR-10 datasets were used. For text classification, the DBPedia dataset was used. For tabular data classification, the Income dataset was used. 2-D convolutional neural networks (CNNs) were used for image classification, 1-D CNNs for text classification, and a long short-term memory (LSTM) architecture for tabular data classification.

Table 1 compares the classification accuracy of the baseline encoder (trained using conventional supervised learning) with that of several variants of the ProtoAttend mechanism. The variants differ in the normalization function used in the relational attention mechanism (softmax or sparsemax) and in whether the network was trained with sparsity regularization. Table 1 also shows the median number of prototypes required to account for a given portion of the decision (50%, 90%, or 95%). Fig. 2 depicts example inputs and the two most influential prototypes selected by ProtoAttend.
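
The two normalization functions and the prototype-count statistic can be illustrated with a short Python sketch. Sparsemax here follows Martins and Astudillo (2016): it projects the scores onto the probability simplex and can return exact zeros, whereas softmax always gives every candidate some weight. The counting rule (sort the weights in decreasing order and stop once their cumulative sum reaches the target share) is our assumed reading of the statistic in Table 1, and the scores below are made up; neither comes from [1].

```python
import math
import statistics

def softmax(z):
    m = max(z)
    e = [math.exp(a - m) for a in z]
    s = sum(e)
    return [a / s for a in e]

def sparsemax(z):
    """Projection of z onto the probability simplex; low scores can get exactly 0."""
    zs = sorted(z, reverse=True)
    cumsum, k_star, cumsum_star = 0.0, 1, zs[0]
    for k, zk in enumerate(zs, start=1):
        cumsum += zk
        if 1.0 + k * zk > cumsum:  # condition holds for a prefix of k values
            k_star, cumsum_star = k, cumsum
    tau = (cumsum_star - 1.0) / k_star  # threshold subtracted from every score
    return [max(a - tau, 0.0) for a in z]

def prototypes_for_share(p_i, share):
    """Smallest number of top-weighted prototypes whose weights add up to `share`."""
    total = 0.0
    for count, p in enumerate(sorted(p_i, reverse=True), start=1):
        total += p
        if total >= share - 1e-9:  # tolerance for floating-point round-off
            return count
    return len(p_i)

def median_prototypes(P, share):
    """Median over inputs of prototypes_for_share."""
    return statistics.median([prototypes_for_share(p_i, share) for p_i in P])

# Made-up relational scores for two inputs and five candidates.
scores = [[2.0, 1.5, 0.3, 0.1, -0.5], [1.0, 0.9, 0.8, -1.0, -2.0]]
for name, norm in (("softmax", softmax), ("sparsemax", sparsemax)):
    P = [norm(s) for s in scores]
    print(name, [median_prototypes(P, share) for share in (0.5, 0.9, 0.95)])
```

For the first row of scores, sparsemax keeps only two candidates (weights 0.75 and 0.25), so a 95% share is reached with two prototypes, while softmax spreads weight over all five.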

Table 1. Performance with and without the ProtoAttend mechanism [1].

Fig. 2. Example inputs and ProtoAttend prototypes for (left) MNIST (with sparsemax) and (right) Fashion-MNIST (with sparsemax and sparsity regularization) [1].

Critique

The main advantage of ProtoAttend is that it can add an interpretability component to a wide range of neural network architectures, including pre-trained ones. The proposed sparse attention mechanism and sparsity regularization make it possible to find a small set of prototypes that are similar to the model inputs. In fact, experiments showed that fewer than 10 prototypes per input were needed to explain at least 95% of the classification result. Experiments also showed no statistically significant drop in performance after including the ProtoAttend mechanism, compared with the baseline networks.

Similar to SENNs, the prototypes retrieved by ProtoAttend consist of representative data samples. However, unlike SENNs, ProtoAttend does not explicitly exploit the notion of basic interpretable concepts. That is, interpretation depends entirely on the human realization that a given input is visually similar to a small set of prototypes (representative raw data samples). When the input data consist of hundreds of features, prototypes with the same number of features would hardly be considered interpretable.

Furthermore, using the three terms \(L(y, \hat{y}(0))\), \(L(y, \hat{y}(1))\), and \(L(y, \hat{y}(0.5))\) in the loss function (Eq. 1) seems redundant. Clearly, we should avoid minimizing only \(L(y, \hat{y}(0))\) or only \(L(y, \hat{y}(1))\) (i.e., we look for \(0 < \alpha < 1\)) in order to satisfy the six proposed principles. However, minimizing only \(L(y, \hat{y}(0.5))\) (or the loss for another \(\alpha\) value) should be enough to achieve a trade-off between the encoder learning task and the relational mechanism learning task. The paper presents an ablation study with different forms of the loss function. The authors decided to use the three terms \(L(y, \hat{y}(0)) + L(y, \hat{y}(1)) + L(y, \hat{y}(0.5))\) instead of just \(L(y, \hat{y}(0.5))\) because it yielded an apparent improvement in classification accuracy. We argue that these results are not conclusive. The reported difference in performance does not seem to be statistically significant (e.g., 94.25% versus 94.45%). Besides, the comparison was carried out using only one dataset (Fashion-MNIST) and a single training-test split.
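
A back-of-the-envelope check supports this point. The Python sketch below applies an unpaired two-proportion z-test to the two accuracies, assuming both were measured on the full Fashion-MNIST test set of 10,000 images (an assumption about the ablation setup). Because both models are scored on the same images, a paired test such as McNemar's would be more appropriate, but it requires per-sample predictions, which we do not have; the test also ignores variation across training runs, which is the larger concern with a single split.

```python
import math

def two_proportion_z_test(acc_a, acc_b, n_a, n_b):
    """Unpaired two-sided z-test for the difference between two accuracies."""
    pooled = (acc_a * n_a + acc_b * n_b) / (n_a + n_b)
    se = math.sqrt(pooled * (1.0 - pooled) * (1.0 / n_a + 1.0 / n_b))
    z = (acc_b - acc_a) / se
    p_value = math.erfc(abs(z) / math.sqrt(2.0))  # equals 2 * (1 - Phi(|z|))
    return z, p_value

# Assumed: both accuracies measured on the 10,000-image Fashion-MNIST test set.
z, p = two_proportion_z_test(0.9425, 0.9445, 10_000, 10_000)
print(f"z = {z:.2f}, p = {p:.2f}")
```

Under these assumptions the sketch gives z ≈ 0.61 and p ≈ 0.54, far from the usual 0.05 threshold.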

The ProtoAttend method resembles an information retrieval system that uses queries and keys to find similarities between the input samples and the set of prototype candidates. As such, the encoding mechanism deserved more emphasis, since it is in charge of transforming the \(\textbf{x}_i\) samples into small vectors of \(d_{att}\) features (\(d_{att}\) can be seen as the number of encoding bits for the queries and keys). One could argue that this reduced set of features can be interpreted as basis concepts; however, the paper did not explain what type of features or concepts were learned by the system or what they represented.

The role of \(d_{att}\) was not discussed either; in fact, the value of \(d_{att}\) used in the experiments was not reported. Intuitively, the more encoding bits are used, the more complex the encoded information becomes and the less interpretable it is. Thus, further research could focus on the trade-off between classification performance and the interpretability of the encoded information for various values of \(d_{att}\).

References

  1. S. O. Arik and T. Pfister, “ProtoAttend: Attention-based prototypical learning,” The Journal of Machine Learning Research, vol. 21, no. 1, pp. 8691–8725, 2020.
Written on November 7, 2022