Interpretable Deep Neural Networks Part 3: Concept Whitening

In this third part of our review on interpretable deep neural networks by design, we discuss a method called Concept Whitening. Recall that both SENNs (Part 1) and ProtoAttend (Part 2) use encoders that transform the original data input into a reduced representation using several neural network layers. However, the information encoded in the latent space is difficult to interpret directly. Hence, Chen, Bei, and Rudin [1] introduced a mechanism called concept whitening (CW) that transforms the output of a chosen layer of the network so that the latent space is constrained to represent target concepts.

Methodology

Let \(\{ \textbf{x}_1, \dots, \textbf{x}_n\} \in \mathcal{X}\) represent the data samples and \(\{ y_1, \dots, y_n\} \in \mathcal{Y}\) their labels. A DNN classifier can be divided into two parts: a feature extractor \(\phi: \mathcal{X} \rightarrow \mathcal{Z}\) and a classifier \(g: \mathcal{Z} \rightarrow \mathcal{Y}\), where \(\mathcal{Z}\) is the latent space defined by a hidden layer. Furthermore, \(\textbf{z} = \phi(\textbf{x}; \theta_\phi)\) is the latent representation of the input \(\textbf{x}\) and \(f(\textbf{x}) = g(\textbf{z}; \theta_g)\) is the predicted label. \(\theta_\phi\) and \(\theta_g\) are sets of learnable parameters. The \(k\) concepts and their corresponding auxiliary datasets are denoted as \(\{ c_i \}_{i=1}^k\) and \(\{ \textbf{X}_{c_i} \}_{i=1}^k\), respectively (the samples in \(\textbf{X}_{c_i}\) are the most representative samples of concept \(c_i\)). Thus, the goal is not only to predict labels accurately but also to align \(z_i\) (i.e., the \(i\)-th dimension of the latent representation \(\textbf{z}\)) with concept \(c_i\).

The CW module consists of two parts: whitening and orthogonal transformation. The first is denoted by \(\psi\). It decorrelates and standardizes the data by \(\psi(\textbf{Z}) = \textbf{W} (\textbf{Z} - \mu \textbf{1}_{n \times 1}^\top)\), where \(\textbf{Z} \in \mathbb{R}^{d \times n}\) is the latent representation of a batch of \(n\) samples (each \(\textbf{z}_i \in \textbf{Z}\) is represented with \(d\) features), \(\mu\) is the sample mean, and \(\textbf{W} \in \mathbb{R}^{d \times d}\) is the whitening matrix such that \(\textbf{W}^\top \textbf{W} = \Sigma^{-1}\) (\(\Sigma\) is the covariance matrix of \(\textbf{Z}\)).
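
To make the whitening step concrete, below is a minimal NumPy sketch (not the authors' implementation, which builds the whitening into the network layer and maintains running estimates of \(\mu\) and \(\Sigma\)) that computes a ZCA whitening matrix \(\textbf{W} = \Sigma^{-1/2}\) for a batch \(\textbf{Z}\) and checks that the whitened activations have (approximately) identity covariance:

```python
import numpy as np

def zca_whiten(Z, eps=1e-5):
    """Whiten a batch Z of shape (d, n): decorrelate and standardize the d features.

    Returns psi(Z) = W (Z - mu) together with the whitening matrix W = Sigma^{-1/2},
    which satisfies W^T W = Sigma^{-1}.
    """
    d, n = Z.shape
    mu = Z.mean(axis=1, keepdims=True)            # sample mean, shape (d, 1)
    Zc = Z - mu                                   # centered batch
    Sigma = Zc @ Zc.T / n                         # covariance matrix, shape (d, d)
    # Sigma = U diag(lam) U^T  =>  Sigma^{-1/2} = U diag(1/sqrt(lam)) U^T
    lam, U = np.linalg.eigh(Sigma)
    W = U @ np.diag(1.0 / np.sqrt(lam + eps)) @ U.T
    return W @ Zc, W

# Sanity check on random data: the whitened batch has identity covariance.
rng = np.random.default_rng(0)
Zw, W = zca_whiten(rng.normal(size=(8, 256)))
print(np.allclose(Zw @ Zw.T / Zw.shape[1], np.eye(8), atol=1e-3))  # True
```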

Note that \(\textbf{W}\) is rotation free, which means that \(\textbf{W}' = \textbf{Q}^\top \textbf{W}\) is also a valid whitening matrix whenever \(\textbf{Q}\) is an orthogonal matrix. Therefore, the second part of the CW module, the orthogonal transformation, consists of rotating the samples in the latent space so that the samples of \(\textbf{X}_{c_i}\) are highly activated along the \(i\)-th axis. The matrix \(\textbf{Q}\) is found by optimizing (Eq. 1):

\[\max_{\textbf{q}_1, \dots, \textbf{q}_d} \sum_{i=1}^k \frac{1}{n_i} \textbf{q}_i ^ \top \psi(\textbf{Z}_{c_i}) \textbf{1}_{n_i \times 1}, \;\;\; s.t.\;\; \textbf{Q}^\top \textbf{Q} = \textbf{I}_d,\]

where \(\textbf{q}_i\) is the \(i\)-th column of \(\textbf{Q}\), \(\textbf{Z}_{c_i}\) is a \(d \times n_i\) matrix denoting the latent representation of \(\textbf{X}_{c_i}\), and \(n_i = | \textbf{X}_{c_i} |\) is the number of samples in the auxiliary dataset of concept \(c_i\).
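
In words, Eq. 1 rewards \(\textbf{Q}\) for mapping the whitened samples of concept \(c_i\) onto the \(i\)-th axis with large mean activation. A minimal sketch of this objective, with illustrative variable names and random data standing in for \(\psi(\textbf{Z}_{c_i})\):

```python
import numpy as np

def concept_alignment_score(Q, whitened_concept_batches):
    """Objective of Eq. 1: for each concept c_i, the mean activation of its
    whitened samples psi(Z_{c_i}) along the i-th column (axis) of Q."""
    score = 0.0
    for i, Zw_ci in enumerate(whitened_concept_batches):
        n_i = Zw_ci.shape[1]
        score += Q[:, i] @ Zw_ci.sum(axis=1) / n_i   # (1/n_i) * q_i^T psi(Z_ci) 1
    return score

# Illustration with d = 8 latent features and k = 3 concepts.
rng = np.random.default_rng(0)
batches = [rng.normal(size=(8, 30)) for _ in range(3)]
Q = np.linalg.qr(rng.normal(size=(8, 8)))[0]         # a random orthogonal matrix
print(concept_alignment_score(Q, batches))
```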

Two types of data are handled during training: the dataset \(\textbf{X}\) used for calculating the main objective, and the auxiliary datasets \(\textbf{X}_{c_i}\) used for concept alignment. As such, the optimization of the model alternates between two objectives. The first is the main objective:

\[\min_{\theta_\phi, \theta_g, \textbf{W}, \mu} \frac{1}{n} \sum_{i=1}^n \ell(g(\textbf{Q}^\top \psi (\phi(\textbf{x}_i; \theta_\phi); \textbf{W}, \mu ); \theta_g ), y_i),\]

where \(\ell(\cdot)\) is a loss function (e.g., cross-entropy in a classification setting), \(\phi\) and \(g\) are the network layers located before and after the CW module, respectively, and the matrix \(\textbf{Q}\) is kept fixed during this step.

The second objective is the concept alignment loss presented in Eq. 1. This is a linear programming problem with quadratic constraints (LPQC), which can be solved approximately using gradient-based methods on the Stiefel manifold. For the sake of computational efficiency, this second optimization is carried out only once every 20 batches; the authors reported that this schedule caused no significant slowdown in training.
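
As a self-contained toy sketch of this second step (the step size, variable names, and the plain Cayley-transform update below are illustrative simplifications of the curvilinear search used in [1]), the following code rotates \(\textbf{Q}\) toward a set of fixed "concept means" while keeping it exactly orthogonal; in the full method, this update would be interleaved with the main-objective SGD steps, once every 20 batches:

```python
import numpy as np

def cayley_ascent_step(Q, G, lr=0.1):
    """One gradient-ascent step for Q on the Stiefel manifold via the Cayley
    transform. G is the gradient of the alignment objective (to be maximized)
    with respect to Q. Orthogonality is preserved exactly because
    (I + aA)^{-1} (I - aA) is orthogonal whenever A is skew-symmetric."""
    A = Q @ G.T - G @ Q.T                           # skew-symmetric ascent direction
    I = np.eye(Q.shape[0])
    return np.linalg.solve(I + (lr / 2) * A, I - (lr / 2) * A) @ Q

# Toy demonstration: align the first k axes of Q with k fixed "concept means",
# which stand in for the per-concept averages (1/n_i) * psi(Z_{c_i}) * 1.
rng = np.random.default_rng(0)
d, k = 8, 3
concept_means = rng.normal(size=(d, k))
alignment = lambda Q: np.sum(Q[:, :k] * concept_means)   # objective of Eq. 1

G = np.zeros((d, d))
G[:, :k] = concept_means               # gradient of the (linear) objective w.r.t. Q
Q = np.eye(d)
for _ in range(300):                   # in the full method: once every 20 batches
    Q = cayley_ascent_step(Q, G)

print(np.allclose(Q.T @ Q, np.eye(d), atol=1e-8))  # orthogonality is preserved
print(alignment(np.eye(d)), "->", alignment(Q))    # the alignment score should grow
```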

Experimental Results

The first set of experiments showed that replacing a few batch-normalization (BN) layers with CW modules did not affect classification accuracy significantly. To show this, four well-known network architectures were trained on the Places365 dataset. The auxiliary concept datasets were extracted from the MS COCO dataset so that each of its labels corresponds to a concept (80 concepts in total). First, only three concepts were selected at random per run and only one BN layer was replaced with a CW module; the classification accuracy was averaged over five runs. The results showed that the accuracy of the networks equipped with a CW module was on par with that of the original models.

More experiments were performed replacing more BN layers (up to 16) with CW modules and learning multiple concepts (up to nine). These results showed that performance consistently decreased as more concepts were learned simultaneously, although the drop was not significant with respect to the performance of the original model.

Furthermore, the concept axes produced by the CW module were visualized to validate that they were aligned with their assigned concepts and to verify the interpretability benefits of equipping models with a CW module. Fig. 1 shows the top-10 largest activations for three different concepts (person, bed, and airplane). From Fig. 1.a, it can be verified that if CW is applied to an early layer (e.g., the 2nd layer of the network), the axes capture low-level information such as color or texture. However, the top-activated images all share the same semantic meaning when the CW module is applied to a higher layer of the network (e.g., the 16th layer).

Fig. 1. Top-10 images activated on axes representing different concepts [1]. (a) Results when the 2nd layer (BN) is replaced by CW. (b) Results when the 16th layer (BN) is replaced by CW.

Fig. 1 helps us understand that different layers express different levels of semantic meaning. Correspondingly, Fig. 2 allows us to track how the representation of an input image changes as the CW module is applied at different layers. Here, a point in the plot represents the percentile ranking of the activation value on each of the two axes (i.e., bed and airplane). The points are connected by arrows according to the depth of the layer in which the CW module is applied. As in Fig. 1.a, Fig. 2 confirms that lower layers of the network capture lower-level information. For example, Fig. 2.a, a bedroom image, shows that the ranking corresponding to the “airplane” axis is initially higher because the blue color is a low-level feature more commonly associated with images of airplanes. Nevertheless, the deeper in the network the CW module is applied, the higher the ranking assigned to the “bed” axis, as higher-level information becomes available.

Fig. 2. Percentile rank for the activation values of two input images on two axes (bed vs. airplane) throughout the network [1].

Finally, the properties of the spatial distribution of the concepts in the latent space were studied. Three approaches were compared: a conventional CNN, a CNN with an auxiliary loss function whose goal is to classify the different concepts in the latent space, and a CNN that uses a CW module in the 16th layer instead of a BN module. Fig. 3 shows the normalized intra-concept and inter-concept similarities for the three approaches. It can be seen that the greatest separability between latent representations of concepts was achieved by the CW module: the inter-concept similarities were very small, which indicates that the concept axes in the latent space of CW are nearly orthogonal; this was not the case with the other two approaches.
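
As a rough illustration of how such a comparison can be computed (the exact normalization used in [1] may differ), one can take the mean pairwise cosine similarity of the latent representations within each concept and between pairs of concepts:

```python
import numpy as np

def concept_similarity_matrix(latents):
    """Mean pairwise cosine similarity between concepts.

    latents: list of k arrays, the i-th of shape (d, n_i), holding the latent
    representations of the samples of concept c_i. Returns a k x k matrix whose
    diagonal holds intra-concept similarities and whose off-diagonal entries
    hold inter-concept similarities.
    """
    normed = [Z / np.linalg.norm(Z, axis=0, keepdims=True) for Z in latents]
    k = len(normed)
    S = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            S[i, j] = (normed[i].T @ normed[j]).mean()
    return S

# Example with random latents for three hypothetical concepts.
rng = np.random.default_rng(0)
print(np.round(concept_similarity_matrix([rng.normal(size=(16, 50)) for _ in range(3)]), 3))
```

With a well-aligned CW module, the off-diagonal (inter-concept) entries of such a matrix should be close to zero while the diagonal (intra-concept) entries remain large.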

Fig. 3. Normalized intra-concept and inter-concept similarities [1]. (a) The 16th layer is a BN module. (b) The 16th layer is a BN module with an auxiliary loss to classify these concepts. (c) The 16th layer is a CW module.

Critique

The main contribution of this work is that the inter-concept similarity in the latent space is reduced drastically, yielding concept axes that are nearly orthogonal. This means that the principle of “diversity”, which looks for non-overlapping basis concepts (see Part 1), is accomplished. In addition, the computational cost of this approach is manageable: pre-trained convolutional neural networks can be reused and require only a few extra training epochs after the CW modules are inserted. Moreover, this can be done without a significant drop in accuracy with respect to classical CNNs with BN modules.

A possible limitation of the CW module is that it requires the use of an additional set of auxiliary datasets for concept alignment. This approach is based on the assumption that the concepts of interest can be identified a priori and that there are accessible datasets from which we can learn to discern them.
However, in the likely case that expert knowledge is not available and the set of basis concepts relevant to a given learning task is unknown, the CW module cannot be applied.

References

  1. Z. Chen, Y. Bei, and C. Rudin, “Concept whitening for interpretable image recognition,” Nature Machine Intelligence, vol. 2, no. 12, pp. 772–782, Dec. 2020.
Written on February 1, 2023