DeepASD: a deep adversarial-regularized graph learning method for ASD diagnosis with multimodal data


DeepASD overview

DeepASD receives data from \(N\) patients, each associated with \(K\) modalities. Let \(\mathbf{X}=\{\mathbf{x}_i\}_{i=1}^{N}\) denote the raw multimodal features of the \(N\) patients, and \(\mathbf{Y}=\{\mathbf{y}_i\}_{i=1}^{N}\) the corresponding labels. For patient \(i\), the feature \(\mathbf{x}_i=\{\mathbf{x}_i^{m_1},\mathbf{x}_i^{m_2},\cdots,\mathbf{x}_i^{m_K}\}\) is composed of \(K\) modalities, where the superscript \(m_k\) indicates the modality; for example, \(\mathbf{x}_i^{m_k}\) denotes the \(d_{m_k}\)-dimensional features of the \(k\)-th modality of the \(i\)-th patient. As shown in Fig. 1, for each modality we feed \(\mathbf{x}^{m_k}\) into the feature extractor \(f_{m_k}(\cdot)\) (Section “Feature extractor”). The extractor \(f_{m_k}(\cdot)\) generates aligned feature representations in \(\mathbb{R}^{N\times d_c}\) through the multimodal adversarial-regularized encoder (Section “Multimodal adversarial-regularized encoder”), where \(d_c\) is the dimension of the aligned subspace. Then, in the multi-graph fusion GNN module (Section “Graph neural networks for multi-graph fusion”), we aggregate the adjacency matrices \(\{A_k\}_{k=1}^{K}\) and the node embeddings \(f(\cdot)=\{f_{m_k}(\mathbf{x}^{m_k})\}_{k=1}^{K}\) of the patient similarity networks (generated from all modalities) into a single fused adjacency matrix \(\mathscr{A}\in\mathbb{R}^{N\times N}\). Thus, a patient network \(G=(V,E,\mathbf{X})\) is built for ASD diagnosis, in which nodes \(V\) represent patients and \(\mathscr{A}_{ij}\in\mathscr{A}\) denotes the edge weight of \(e_{ij}\in E\). On this network, we employ Simple Spectral Graph Convolution (\(\text{S}^2\)GC) [35] and a one-layer Multilayer Perceptron (MLP) to predict node labels (i.e., the ASD prediction result for each patient).

Feature extractor

Since each modality has its own characteristics and patterns, we design a modality-specific extractor that takes the data of its modality as input and projects the samples into a \(d_c\)-dimensional subspace, thereby aligning the dimensions of the different modalities. Each feature extractor \(\{f_{m_k}(\cdot)\!:\mathbb{R}^{d_{m_k}}\to\mathbb{R}^{d_c}\}_{k=1}^{K}\) consists of a two-layer fully connected network with Leaky ReLU activation, followed by a one-layer fully connected network for classification.
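As an illustration, the following PyTorch sketch shows one such modality-specific extractor; the hidden width and the two-class classification head are assumptions made for the example rather than details taken from the paper.

```python
import torch
import torch.nn as nn

class ModalityExtractor(nn.Module):
    """Sketch of one modality-specific extractor f_{m_k}: R^{d_mk} -> R^{d_c}.

    Two fully connected layers with Leaky ReLU follow the description above;
    the hidden width (set to d_c here) and the classification head size are
    illustrative assumptions.
    """
    def __init__(self, d_mk: int, d_c: int, n_classes: int = 2):
        super().__init__()
        # Two-layer fully connected network with Leaky ReLU activation
        self.encoder = nn.Sequential(
            nn.Linear(d_mk, d_c),
            nn.LeakyReLU(),
            nn.Linear(d_c, d_c),
            nn.LeakyReLU(),
        )
        # One-layer fully connected network for classification
        self.classifier = nn.Linear(d_c, n_classes)

    def forward(self, x):
        h = self.encoder(x)            # aligned d_c-dimensional representation
        return h, self.classifier(h)   # features and class logits
```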

Multimodal adversarial-regularized encoder

Adversarial networks [36, 37] have demonstrated effectiveness in aligning different data distributions. Since the features are heterogeneous across modalities and each modality provides information distinct from the others, we develop a multimodal adversarial-regularized encoder to eliminate feature heterogeneity and reduce distributional divergence. As shown in Fig. 1b, we construct two competitive modules: the modal discriminator and the feature extractor. The modal discriminator \(d(\cdot)\) aims to distinguish the modality of a feature, while the feature extractor \(f(\cdot)\) attempts to fool it. By training in this adversarial manner, we obtain aligned distributions from all modalities through a competitive loss \(\mathscr{L}_{d,f}\) (Eq. (1)) that is minimized over \(d(\cdot)\) but maximized over \(f(\cdot)\).

$$\mathscr{L}_{d,f}=\sum_{k=1}^{K}\frac{1}{N}\sum_{i=1}^{N}L_s\left[d\left(f_{m_k}\left(\boldsymbol{x}_i^{m_k}\right)\right),\boldsymbol{z}_i^{m_k}\right].$$

(1)

where \(L_s[\cdot,\cdot]\) is the squared loss, and \(\mathbf{z}_i^{m_k}\) is the one-hot modality label of \(\mathbf{x}_i^{m_k}\). In addition, there is a classification loss, as shown in Eq. (2).

$$\mathscr{L}_f=\sum_{k=1}^{K}\frac{1}{N}\sum_{i=1}^{N}L_c\left[f\left(\boldsymbol{x}_i^{m_k}\right),\boldsymbol{y}_i\right]+\tau\left(\parallel f\parallel^2\right)$$

(2)

where \(L_c[\cdot,\cdot]\) is the cross-entropy loss and \(\tau\) is a positive regularization parameter. By minimizing \(\mathscr{L}_f\), the feature extractor \(f(\cdot)\) is encouraged to produce more discriminative representations. Combining the two losses, we have

$$\min_{f}\max_{d}\;\mathscr{L}_f-\beta\mathscr{L}_{d,f}$$

(3)

where \(\beta\) is the trade-off parameter between the classification loss and the modal discriminator loss.

To alleviate the problem that adversarial optimization of the \(\mathscr{L}_{d,f}\) term may lead to vanishing gradients if \(f(\cdot)\) and \(d(\cdot)\) are not well synchronized, we adopt the invert label loss [38] as defined in Eq. (4):

$$\hat{\mathscr{L}}_{d,f}=\sum_{k=1}^{K}\frac{1}{N}\sum_{i=1}^{N}L_s\left[d\left(f_{m_k}\left(\boldsymbol{x}_i^{m_k}\right)\right),\hat{\boldsymbol{z}}_i^{m_k}\right].$$

(4)

where \(\hat{\mathbf{z}}_i^{m_k}\) is the inverted one-hot modality label of \(\mathbf{x}_i^{m_k}\). Thus, the objective in Eq. (3) can be reformulated as

$$\min_{f}\;\mathscr{L}_f+\beta\hat{\mathscr{L}}_{d,f},\qquad \min_{d}\;\mathscr{L}_{d,f}$$

(5)
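To make the competitive objective concrete, the sketch below computes the discriminator loss of Eq. (1) and the invert-label loss of Eq. (4) for one batch; the squared form of \(L_s\) follows the text, while the discriminator interface and the exact form of the inverted label for more than two modalities are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def adversarial_losses(extractors, discriminator, xs, K):
    """Sketch of Eq. (1) and Eq. (4) for one batch.

    xs            : list of K tensors, xs[k] holding the batch of modality m_k
    extractors    : list of K ModalityExtractor modules (see sketch above)
    discriminator : module mapping d_c-dim features to K modality logits
    """
    loss_d, loss_inv = 0.0, 0.0
    for k, x_mk in enumerate(xs):
        feats, _ = extractors[k](x_mk)                     # f_{m_k}(x^{m_k})
        pred = torch.softmax(discriminator(feats), dim=1)  # d(f_{m_k}(x^{m_k}))
        z = F.one_hot(torch.full((x_mk.size(0),), k, dtype=torch.long), K).float()
        z_inv = (1.0 - z) / max(K - 1, 1)   # inverted modality label (uniform over the
                                            # other modalities -- an assumption for K > 2)
        loss_d += F.mse_loss(pred, z)       # squared loss L_s against true labels, Eq. (1)
        loss_inv += F.mse_loss(pred, z_inv) # invert-label loss, Eq. (4)
    return loss_d, loss_inv
```

In the alternating scheme of Eq. (5), `loss_d` would drive the discriminator update and `loss_inv` would enter the extractor update in place of the maximization over \(f(\cdot)\).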

Graph neural networks for multi-graph fusion

For the multimodal features learned by the multimodal adversarial-regularized encoder, we apply a learnable cosine similarity [39], Eq. (6), to learn an inductive patient similarity graph, as follows,

$$A_{ij}=\frac{\left(W_A f_i\right)^{T} W_A f_j}{\parallel W_A f_i\parallel\,\parallel W_A f_j\parallel}$$

(6)

where \(A_{ij}\) is the learned similarity between patients \(i\) and \(j\), \(W_A\) is a learnable embedding matrix, and \(f_i\) and \(f_j\) are the aligned features from the feature extractors \(\{f_{m_k}(\mathbf{x}_i^{m_k})\}_{k=1}^{K}\). We also employ a threshold \(\theta\) to constrain the similarity strength between nodes.
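A minimal sketch of this learnable similarity, assuming \(W_A\) is realized as a bias-free linear projection and \(\theta\) acts as a hard threshold that zeroes weak edges:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableCosineGraph(nn.Module):
    """Sketch of Eq. (6): A_ij = cos(W_A f_i, W_A f_j), pruned by threshold theta."""
    def __init__(self, d_c: int, theta: float = 0.5):
        super().__init__()
        self.W_A = nn.Linear(d_c, d_c, bias=False)   # learnable projection W_A (assumed square)
        self.theta = theta

    def forward(self, feats):                        # feats: (N, d_c) aligned features
        proj = F.normalize(self.W_A(feats), dim=1)   # W_A f_i / ||W_A f_i||
        A = proj @ proj.t()                          # pairwise cosine similarities
        # keep only edges whose similarity exceeds theta (assumed hard cut-off)
        return A * (A > self.theta).float()
```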

For each modality, we learn an adjacency matrix, giving \(\{A_k\}_{k=1}^{K}\), to capture patient relationships (a.k.a. similarity) in each modality, and then combine all patient graphs into one graph whose adjacency matrix \(\mathscr{A}\) is obtained by a weighted sum, \(\mathscr{A}=\sum_{k=1}^{K} w_k A_k\). We obtain the fused feature representation \(\hat{f}\) by concatenating the aligned features \(\{f_{m_k}(\mathbf{x}_i^{m_k})\}_{k=1}^{K}\). Based on the learned graph structure \(\mathscr{A}\in\mathbb{R}^{N\times N}\) and the fused feature \(\hat{f}\), we apply \(\text{S}^2\)GC [35] and a one-layer MLP for the downstream ASD diagnosis task.
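The fusion step could then be expressed as follows; the modality weights \(w_k\) are shown here as plain scalars passed in by the caller, although in the actual model they may be learnable.

```python
import torch

def fuse_graphs_and_features(adjs, feats, weights):
    """Weighted-sum fusion of per-modality graphs plus feature concatenation.

    adjs    : list of K (N, N) adjacency matrices A_k
    feats   : list of K (N, d_c) aligned feature matrices f_{m_k}(x^{m_k})
    weights : list of K scalars w_k
    """
    A_fused = sum(w * A for w, A in zip(weights, adjs))  # script-A = sum_k w_k A_k
    f_hat = torch.cat(feats, dim=1)                      # concatenated fused features (K * d_c per patient)
    return A_fused, f_hat
```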

A spectral convolution of a graph signal \(x\) with a filter \(g_\theta\) is defined as \(g_\theta \star x = U g_\theta U^{T} x\), where \(U\) is the matrix of eigenvectors of the normalized graph Laplacian \(L=\mathrm{I}-\mathrm{D}^{-\frac{1}{2}}\mathscr{A}\mathrm{D}^{-\frac{1}{2}}\) and \(\mathrm{D}\) is the diagonal degree matrix. By the renormalization trick, we replace \(\mathrm{D}^{-\frac{1}{2}}\mathscr{A}\mathrm{D}^{-\frac{1}{2}}\) with the normalized version \(\widetilde{\mathrm{T}}=\widetilde{\mathrm{D}}^{-\frac{1}{2}}\widetilde{\mathscr{A}}\widetilde{\mathrm{D}}^{-\frac{1}{2}}=\left(\mathrm{D}+\mathrm{I}_n\right)^{-\frac{1}{2}}\left(\mathscr{A}+\mathrm{I}_n\right)\left(\mathrm{D}+\mathrm{I}_n\right)^{-\frac{1}{2}}\). Motivated by the Markov Diffusion Kernel [40], \(\text{S}^2\)GC includes self-loops, and its final output is defined as follows,

$$\hat{Y}=\mathrm{softmax}\left(\frac{1}{C}\sum_{c=1}^{C}\left(\left(1-\alpha\right)\widetilde{\boldsymbol{T}}^{c}\boldsymbol{X}+\alpha\boldsymbol{X}\right)\boldsymbol{W}\right)$$

(7)

where \(\boldsymbol{W}\) represents the network parameters and \(\alpha\) trades off the self-information of a node against that of its consecutive neighborhoods. In the experiments, \(\alpha\) ranges between 0 and 1 and is typically set to 0.05. The term \(\widetilde{\mathbf{T}}^{c}\mathbf{X}\) is computed iteratively as \(\widetilde{\mathbf{T}}\cdot(\widetilde{\mathbf{T}}\cdot(\cdots(\widetilde{\mathbf{T}}\mathbf{X})\cdots))\), i.e., the multiplication by \(\widetilde{\mathbf{T}}\) is applied up to \(C\) times.
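The propagation in Eq. (7) can be sketched with dense matrices as below; the default number of hops \(C\) is an assumption made for illustration.

```python
import torch

def s2gc_logits(A, X, W, C: int = 16, alpha: float = 0.05):
    """Sketch of Eq. (7): softmax( (1/C) * sum_{c=1}^{C} ((1-alpha) T~^c X + alpha X) W ).

    A : (N, N) fused adjacency matrix (script-A)
    X : (N, d) node features, W : (d, n_classes) learnable weight matrix
    """
    N = A.size(0)
    A_tilde = A + torch.eye(N)                   # add self-loops: script-A + I_n
    d_inv_sqrt = A_tilde.sum(dim=1).pow(-0.5)    # (D + I_n)^{-1/2} as a vector
    T = d_inv_sqrt.unsqueeze(1) * A_tilde * d_inv_sqrt.unsqueeze(0)  # normalized T~

    out, TX = torch.zeros_like(X), X
    for _ in range(C):
        TX = T @ TX                              # T~^c X, computed iteratively
        out = out + (1 - alpha) * TX + alpha * X
    return torch.softmax((out / C) @ W, dim=1)
```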

Then, we build a final classifier to identify ASD. The classifier is composed of two graph convolution layers with ReLU activation, followed by a fully connected layer. The loss function for this graph-based classification is given by Eq. (8):

$$\mathscr{L}_g=L_c\left[g\left(\mathscr{A},\hat{f}\right),\boldsymbol{y}\right]$$

(8)

where \(L_c[\cdot,\cdot]\) is the cross-entropy loss, \(g(\cdot,\cdot)\) is the GCN model, and \(\mathbf{y}\) is the one-hot label of each patient.

By combining all the aforementioned losses, we adopt the following joint loss function to guide the optimization of all modules simultaneously:

$$\mathscr{L}=\mathscr{L}_g+\eta\left(\mathscr{L}_f+\beta\hat{\mathscr{L}}_{d,f}\right)$$

(9)

where \(\eta\) and \(\beta\) are hyper-parameters to balance the loss terms.

Methodologies for training and testing

Within each training epoch, we first competitively train the feature extractor \(f(\cdot)\) and the modal discriminator \(d(\cdot)\) once, and then jointly train the graph-construction embedding \(W_A\) and the GCN \(g(\cdot,\cdot)\) once. With this training strategy, we simultaneously obtain informative patient representations and patient similarity graphs that yield high prediction accuracy.
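A possible realization of this alternating scheme, reusing the helper sketches above (`ModalityExtractor`, `adversarial_losses`, `LearnableCosineGraph`, `fuse_graphs_and_features`), is outlined below; the optimizer grouping, the uniform modality weights, and the omission of the \(\tau\) regularizer are simplifications, not details from the paper.

```python
import torch.nn.functional as F

def train_epoch(extractors, discriminator, graph_module, gcn,
                opt_f, opt_d, opt_g, xs, y, eta=1.0, beta=0.03):
    """Sketch of one DeepASD training epoch (alternating updates).

    xs : list of K per-modality feature tensors; y : patient labels.
    """
    K = len(xs)

    # Step 1 -- discriminator update: min_d L_{d,f} (Eq. (1))
    loss_d, _ = adversarial_losses(extractors, discriminator, xs, K)
    opt_d.zero_grad(); loss_d.backward(); opt_d.step()

    # Step 2 -- extractor, graph embedding W_A, and GCN update on the joint loss (Eq. (9))
    _, loss_inv = adversarial_losses(extractors, discriminator, xs, K)
    outs = [extractors[k](x) for k, x in enumerate(xs)]
    feats, logits = [o[0] for o in outs], [o[1] for o in outs]
    loss_cls = sum(F.cross_entropy(l, y) for l in logits)   # L_f (Eq. (2), tau term omitted)
    A_fused, f_hat = fuse_graphs_and_features(
        [graph_module(f) for f in feats], feats, [1.0 / K] * K)
    loss_g = F.cross_entropy(gcn(A_fused, f_hat), y)        # L_g (Eq. (8))
    loss = loss_g + eta * (loss_cls + beta * loss_inv)      # joint loss, Eq. (9)
    opt_f.zero_grad(); opt_g.zero_grad()
    loss.backward()
    opt_f.step(); opt_g.step()
    return loss.item()
```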

After selecting hyper-parameters through preliminary experiments, we optimize \(f\left(\cdot \right)\), \(d\left(\cdot \right)\), \(W_A\), \(g\left(\cdot ,\cdot \right)\) via the Adam optimizers [41] with learning rates of 0.004, 0.001, 0.001, 0.001, respectively. We empirically set \(\beta =0.03\), \(\tau =0.004\), and \(\eta =1\), and the other hyper-parameters are fine-tuned according to the dataset size.

We implement the proposed DeepASD in PyTorch. All experiments use 10-fold cross-validation to split the dataset into training and test sets, with 10% of the training set randomly selected as the validation set, so that the training, validation, and test sets are non-overlapping. Training the model for 500 epochs on the ABIDE A and ABIDE B datasets took approximately 15 min and 45 min, respectively, on a single Tesla V100 GPU.
