Entropy
  • Article
  • Open Access

31 October 2024

An Empirical Study of Self-Supervised Learning with Wasserstein Distance

1 Machine Learning and Data Science Unit, Okinawa Institute of Science and Technology, Okinawa 904-0412, Japan
2 Center for Advanced Intelligence Project, RIKEN, Tokyo 103-0027, Japan
3 Department of Intelligence Science and Technology, Kyoto University, Kyoto 606-8501, Japan
4 Paris-Saclay Ecole Normale Superieure, 75005 Paris, France
This article belongs to the Special Issue Entropy in Real-World Datasets and Its Impact on Machine Learning II

Abstract

In this study, we consider the problem of self-supervised learning (SSL) utilizing the 1-Wasserstein distance on a tree structure (a.k.a. the Tree-Wasserstein distance (TWD)), where TWD is defined as the L1 distance between two tree-embedded vectors. In SSL methods, the cosine similarity is often utilized as an objective function; the Wasserstein distance, by contrast, has not been well studied in this role. Training with the Wasserstein distance is numerically challenging. Thus, this study empirically investigates strategies for optimizing SSL with the Wasserstein distance and identifies a stable training procedure. More specifically, we evaluate the combination of two types of TWD (total variation and ClusterTree) with several probability models, including the softmax function, the ArcFace probability model, and simplicial embedding. We propose a simple yet effective Jeffrey divergence-based regularization method to stabilize optimization. Through empirical experiments on STL10, CIFAR10, CIFAR100, and SVHN, we find that a simple combination of the softmax function and TWD performs significantly worse than standard SimCLR. Moreover, a simple combination of TWD and SimSiam fails to train the model. We find that model performance depends on the combination of TWD and probability model, and that Jeffrey divergence regularization helps in model training. Finally, we show that an appropriate combination of TWD and probability model outperforms cosine similarity-based representation learning.

1. Introduction

Unsupervised learning is a widely studied topic, and includes autoencoders [] and variational autoencoders (VAEs) []. Self-supervised learning (SSL) algorithms, including SimCLR [], Bootstrap Your Own Latent (BYOL) [], MoCo [,], SwAV [], SimSiam [], and DINO [], can also be regarded as unsupervised learning methods.
One of the main self-supervised approaches adopts contrastive learning, in which two data points are systematically generated from a common data source and lower-dimensional representations are found by maximizing the similarity between positive pairs while minimizing the similarity between negative pairs. Depending on the context, positive and negative pairs can be defined differently. For example, in SimCLR [], positive pairs correspond to images generated by independently applying different visual transformations, such as rotation and cropping, to the same image. In multimodal learning, by contrast, positive pairs are defined as corresponding examples in different modalities, such as images and text []. The flexibility in formulating positive and negative pairs also makes contrastive learning widely applicable beyond the image domain. Contrastive learning is a powerful pre-training method because SSL does not require label information and can be trained on a large number of unlabeled data points.
In addition to contrastive learning-based SSL, non-contrastive approaches, such as BYOL [], SwAV [], and SimSiam [], have been widely used. The fundamental concept of non-contrastive approaches is the use of momentum and/or stop-gradient techniques to prevent mode collapse, as opposed to relying on negative sampling. Many of these approaches employ the negative cosine similarity as a loss function. However, only a limited number of SSL methods utilize distributional measures such as the cross-entropy, as exemplified by DINO [] and simplicial embedding [].
In this paper, leveraging the idea of distribution measures, for the first time we empirically investigate SSL performance using the Wasserstein distance. The Wasserstein distance, a widely adopted optimal transport-based distance for measuring distributional discrepancies, is useful in various machine learning tasks, including generative adversarial networks [], document classification [,], image matching [], and algorithmic fairness [,]. The 1-Wasserstein distance is also known as the earth mover’s distance (EMD) and the word mover’s distance (WMD) [].
In this study, we consider an SSL framework with the 1-Wasserstein distance under a tree metric (i.e., the Tree-Wasserstein distance (TWD)) [,]. TWD includes the sliced Wasserstein distance [,] and the total variation as special cases and can be represented as the $\ell_1$ distance between two vectors. Because TWD is a non-differentiable function, learning simplicial representations through back-propagation of TWD is challenging. Moreover, because the Wasserstein distance is computed from probability vectors, and several representations of probability vectors exist, it is difficult to determine which is most suitable for SSL training. Hence, we investigate combinations of probability models and TWD structures. Specifically, we consider the total variation and the ClusterTree as TWD structures and show that the total variation is equivalent to a robust variant of TWD. In terms of the probability representations, we propose the combined use of the softmax function, an ArcFace-based probability model [], and simplicial embedding (SEM) []. Finally, to stabilize training, we propose a Jeffrey divergence-based regularization. Through SSL experiments, we find that the standard softmax formulation with back-propagation yields poor results. In particular, in the non-contrastive SSL case, a simple combination of the Wasserstein distance and the softmax function fails to train the model. For the total variation, the ArcFace-based model performs well. By contrast, SEM is suitable for ClusterTree, whereas ArcFace-based models achieve only modest performance there. Moreover, the proposed regularization significantly outperforms its non-regularized counterparts.
Contribution: The contributions of this study are summarized below:
  • We propose, for the first time, using the Tree-Wasserstein distance (TWD) for self-supervised learning, including SimCLR and SimSiam.
  • We investigate the combination of probability models and TWD (total variation and ClusterTree). We find that the ArcFace model with prior information is suited for total variation, while SEM [] is suited for ClusterTree models.
  • We propose a robust variant of TWD (RTWD) and show that RTWD is equivalent to total variation.
  • We propose the Jeffrey divergence regularization for TWD minimization, and find that the regularization significantly stabilizes training.
  • We demonstrate that the combination of TWD and probability models obtains better performance in self-supervised training on CIFAR10, STL10, and SVHN than cosine similarity in SimCLR experiments, while there remains room to further improve performance on CIFAR100.

3. Background

3.1. Self-Supervised Learning Methods

SimCLR []: Given $n$ input vectors $\{\boldsymbol{x}_i\}_{i=1}^{n}$ with $\boldsymbol{x}_i \in \mathbb{R}^{d}$, define the data transformation functions $\boldsymbol{u}^{(1)} = \phi_1(\boldsymbol{x}) \in \mathbb{R}^{d}$ and $\boldsymbol{u}^{(2)} = \phi_2(\boldsymbol{x}) \in \mathbb{R}^{d}$. In the context of image applications, $\boldsymbol{u}^{(1)}$ and $\boldsymbol{u}^{(2)}$ can be understood as two image transformations of a given image: translation, rotation, blurring, etc. The neural network model is denoted as $f_\theta(\boldsymbol{u}) \in \mathbb{R}^{d_{\mathrm{out}}}$, where $\theta$ is a learnable parameter.
SimCLR attempts to train the model by learning features such that $\boldsymbol{z}^{(1)} = f_\theta(\boldsymbol{u}^{(1)})$ and $\boldsymbol{z}^{(2)} = f_\theta(\boldsymbol{u}^{(2)})$ are close after the feature mapping, while ensuring that both are distant from the feature map of $\boldsymbol{u}'$, where $\boldsymbol{u}'$ is a negative sample generated from a different input image. To this end, the InfoNCE loss [] is employed in the SimCLR model:
$$\ell_{\mathrm{InfoNCE}}\big(\boldsymbol{z}_i^{(1)}, \boldsymbol{z}_i^{(2)}\big) = -\log \frac{\exp\big(\mathrm{sim}(\boldsymbol{z}_i^{(1)}, \boldsymbol{z}_i^{(2)})/\tau\big)}{\bar{Z}} = -\mathrm{sim}(\boldsymbol{z}_i^{(1)}, \boldsymbol{z}_i^{(2)})/\tau + \log(\bar{Z}),$$
where $\bar{Z} = \sum_{k=1}^{2R} \delta_{k \neq i} \exp\big(\mathrm{sim}(\boldsymbol{z}_i^{(1)}, \tilde{\boldsymbol{z}}_k)/\tau\big)$ is the normalizer, $R$ is the batch size, and $\mathrm{sim}(\boldsymbol{z}, \boldsymbol{z}')$ is a similarity function that takes a higher value when $\boldsymbol{z}$ and $\boldsymbol{z}'$ are similar and a smaller value when they are dissimilar. $\tau$ is the temperature parameter, and $\delta_{k \neq i}$ is a delta function that takes the value 1 when $k \neq i$ and 0 otherwise. In contrastive learning, we aim to minimize the InfoNCE loss. To achieve an optimal solution, we need to maximize the similarity $\mathrm{sim}(\boldsymbol{z}_i^{(1)}, \boldsymbol{z}_i^{(2)})$ while minimizing $\log(\bar{Z})$. The first term makes $\boldsymbol{z}_i^{(1)}$ and $\boldsymbol{z}_i^{(2)}$ as similar as possible. The second term is a log-sum-exp function, which for small $\tau$ can be interpreted as
$$\log(\bar{Z}) = \log \sum_{k=1}^{2R} \delta_{k \neq i} \exp\big(\mathrm{sim}(\boldsymbol{z}_i^{(1)}, \tilde{\boldsymbol{z}}_k)/\tau\big) \approx \max_{k \neq i}\, \mathrm{sim}(\boldsymbol{z}_i^{(1)}, \tilde{\boldsymbol{z}}_k)/\tau.$$
By minimizing $\log(\bar{Z})$, we make $\boldsymbol{z}_i^{(1)}$ dissimilar to the negative samples $\tilde{\boldsymbol{z}}_k$. Because this term minimizes the maximum similarity between the input $\boldsymbol{z}_i^{(1)}$ and its negative samples, it pushes $\boldsymbol{z}_i^{(1)}$ and its negative samples apart.
In SimCLR, the parameters are learned by minimizing the InfoNCE loss:
$$\hat{\theta} := \mathop{\mathrm{argmin}}_{\theta} \sum_{i=1}^{n} \ell_{\mathrm{InfoNCE}}\big(f_\theta(\boldsymbol{u}_i^{(1)}), f_\theta(\boldsymbol{u}_i^{(2)})\big).$$
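A minimal PyTorch sketch of the InfoNCE computation described above is given below; it assumes cosine similarity and treats view $k+R$ as the positive of sample $k$. The function name `info_nce` and the masking trick are our own illustration, not the SimCLR reference implementation.

```python
import torch
import torch.nn.functional as F

def info_nce(z1, z2, tau=0.07):
    """InfoNCE sketch; z1, z2 are (R, d) feature batches of the two views."""
    R = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)    # (2R, d), unit-norm rows
    sim = z @ z.t() / tau                                  # cosine similarities / tau
    mask = torch.eye(2 * R, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))             # exclude k = i (delta_{k != i})
    # the positive of sample i is its other view: i <-> i + R
    targets = torch.cat([torch.arange(R) + R, torch.arange(R)]).to(z.device)
    # cross_entropy = -sim(positive)/tau + log sum_k exp(sim_k/tau)
    return F.cross_entropy(sim, targets)
```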
SimSiam []: SimSiam is a non-contrastive learning method; it does not use negative sampling to prevent mode collapse. In place of negative sampling, SimSiam employs a stop-gradient technique. Specifically, the loss function is given by
$$L_{\mathrm{SimSiam}}(\theta) = \frac{1}{2} L_1(\theta) + \frac{1}{2} L_2(\theta), \quad L_1(\theta) = -\frac{1}{n}\sum_{i=1}^{n} \frac{h(\boldsymbol{z}_i^{(1)})^\top \bar{\boldsymbol{z}}_i^{(2)}}{\|h(\boldsymbol{z}_i^{(1)})\|_2\, \|\bar{\boldsymbol{z}}_i^{(2)}\|_2}, \quad L_2(\theta) = -\frac{1}{n}\sum_{i=1}^{n} \frac{h(\boldsymbol{z}_i^{(2)})^\top \bar{\boldsymbol{z}}_i^{(1)}}{\|h(\boldsymbol{z}_i^{(2)})\|_2\, \|\bar{\boldsymbol{z}}_i^{(1)}\|_2},$$
where $h(\cdot)$ is the MLP head, $\boldsymbol{z}_i^{(v)}$ is the latent variable of the $v$-th view, and $\bar{\boldsymbol{z}}_i^{(v)} = \mathrm{StopGradient}(\boldsymbol{z}_i^{(v)})$ is the latent variable with a stop gradient.
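The loss above is straightforward to express in PyTorch, where `detach()` plays the role of the stop-gradient; the sketch below is a simplified illustration under that assumption (the encoder and the predictor `h` are left abstract, and the names are ours).

```python
import torch.nn.functional as F

def simsiam_loss(p1, p2, z1, z2):
    """Symmetric negative cosine similarity with stop-gradient.
    p1 = h(z1), p2 = h(z2) are predictor outputs; detach() implements StopGradient."""
    l1 = -F.cosine_similarity(p1, z2.detach(), dim=1).mean()
    l2 = -F.cosine_similarity(p2, z1.detach(), dim=1).mean()
    return 0.5 * l1 + 0.5 * l2
```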

3.2. p-Wasserstein Distance

The $p$-Wasserstein distance between two discrete measures $\mu = \sum_{i=1}^{\bar{n}} a_i \delta_{\boldsymbol{x}_i}$ and $\mu' = \sum_{j=1}^{\bar{m}} a'_j \delta_{\boldsymbol{y}_j}$ is given by
$$W_p(\mu, \mu') = \left( \min_{\Pi \in U(\mu, \mu')} \sum_{i=1}^{\bar{n}} \sum_{j=1}^{\bar{m}} \pi_{ij}\, d(\boldsymbol{x}_i, \boldsymbol{y}_j)^p \right)^{1/p},$$
where $U(\mu, \mu')$ denotes the set of transport plans, $U(\mu, \mu') = \{\Pi \in \mathbb{R}_+^{\bar{n} \times \bar{m}} : \Pi \mathbf{1}_{\bar{m}} = \boldsymbol{a},\ \Pi^\top \mathbf{1}_{\bar{n}} = \boldsymbol{a}'\}$. The Wasserstein distance can be computed using a linear program. However, because this includes an optimization problem, computing the Wasserstein distance at each iteration is computationally expensive.
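For concreteness, the linear program can be solved directly with an off-the-shelf LP solver; the sketch below uses `scipy.optimize.linprog` and is only meant to illustrate the formulation. In practice, dedicated optimal transport solvers are far more efficient, which is part of the motivation for the tree-based closed form discussed next.

```python
import numpy as np
from scipy.optimize import linprog

def wasserstein_lp(a, a_prime, D, p=1):
    """Discrete p-Wasserstein distance via a linear program.
    a: (n,) source weights, a_prime: (m,) target weights, D: (n, m) ground distances."""
    n, m = D.shape
    c = (D ** p).ravel()                                   # objective: sum_ij pi_ij d_ij^p
    A_rows = np.kron(np.eye(n), np.ones((1, m)))           # row sums: Pi 1_m = a
    A_cols = np.kron(np.ones((1, n)), np.eye(m))           # column sums: Pi^T 1_n = a'
    A_eq = np.vstack([A_rows, A_cols])
    b_eq = np.concatenate([a, a_prime])
    res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return res.fun ** (1.0 / p)
```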

3.3. 1-Wasserstein Distance with Tree Metric (Tree-Wasserstein Distance)

Another 1-Wasserstein distance is based on trees [,]. The 1-Wasserstein distance between two probability distributions $\mu = \sum_{i=1}^{N_{\mathrm{leaf}}} a_i \delta_{\boldsymbol{x}_i}$ and $\mu' = \sum_{j=1}^{N_{\mathrm{leaf}}} a'_j \delta_{\boldsymbol{y}_j}$ with a tree metric is defined as
$$W_{\mathcal{T}}(\mu, \mu') = \min_{\Pi \in U(\mu, \mu')} \sum_{i=1}^{N_{\mathrm{leaf}}} \sum_{j=1}^{N_{\mathrm{leaf}}} \pi_{ij}\, d_{\mathcal{T}}(\boldsymbol{x}_i, \boldsymbol{y}_j),$$
where $d_{\mathcal{T}}(\boldsymbol{x}, \boldsymbol{y})$ is the length of the shortest path between $\boldsymbol{x}$ and $\boldsymbol{y}$ on the tree and $N_{\mathrm{leaf}}$ is the number of leaf nodes. TWD can further be represented in the following closed form []:
$$W_{\mathcal{T}}(\mu, \mu') = \sum_{e \in E} w_e\, \big| \mu(\Gamma(v_e)) - \mu'(\Gamma(v_e)) \big|,$$
where $e$ is an edge index, $w_e \in \mathbb{R}_+$ is the weight of edge $e$, $v_e$ is the $e$-th node index, and $\mu(\Gamma(v_e))$ is the total mass of the subtree rooted at $v_e$. This closed-form solution can in turn be written as an L1 distance []:
$$W_{\mathcal{T}}(\mu, \mu') = \big\| \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a} - \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}' \big\|_1,$$
where $\boldsymbol{B} \in \{0, 1\}^{N_{\mathrm{node}} \times N_{\mathrm{leaf}}}$ is the tree parameter with $[\boldsymbol{B}]_{i,j} = 1$ if node $i$ is an ancestor of leaf node $j$ and zero otherwise, $N_{\mathrm{node}}$ is the total number of nodes in the tree, and $\boldsymbol{w} \in \mathbb{R}_+^{N_{\mathrm{node}}}$ is the vector of edge weights.
For illustration, we provide two examples of the $\boldsymbol{B}$ matrix: a tree of depth one and a ClusterTree, as shown in Figure 1. If all edge weights satisfy $w_1 = w_2 = \cdots = w_N = \frac{1}{2}$ in the left panel of Figure 1, then $\boldsymbol{B} = \boldsymbol{I}$. Substituting this into the TWD, we obtain
$$W_{\mathcal{T}}(\mu, \mu') = \frac{1}{2} \|\boldsymbol{a} - \boldsymbol{a}'\|_1 = \|\boldsymbol{a} - \boldsymbol{a}'\|_{\mathrm{TV}}.$$
Figure 1. The left tree corresponds to the total variation if we set the weights as $w_i = \frac{1}{2}, \forall i$. The right tree is a ClusterTree (2 classes).
Thus, the total variation is a special case of TWD. In this setting, the shortest-path distance in the tree corresponds to the Hamming distance. Note that Raginsky et al. [] also show that the 1-Wasserstein distance with the Hamming metric $d(\boldsymbol{x}, \boldsymbol{y}) = \delta_{\boldsymbol{x} \neq \boldsymbol{y}}$ is equivalent to the total variation (Proposition 3.4.1 in Raginsky et al. []).
The key advantage of the tree-based approach is that the Wasserstein distance is written in closed form, which is computationally efficient. A chain is a special case of a tree; thus, the widely employed sliced Wasserstein distance is also a special case of TWD (Figure 2). Moreover, it has been empirically reported that TWD-based and Sinkhorn-based approaches perform similarly in multilabel classification tasks [].
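The closed form makes TWD a cheap, vectorized operation. The NumPy sketch below (our own illustration, with hypothetical variable names) evaluates the L1 form and checks the total-variation special case $\boldsymbol{B} = \boldsymbol{I}$, $w_i = 1/2$.

```python
import numpy as np

def tree_wasserstein(a, a_prime, B, w):
    """Closed-form TWD: || diag(w) B a - diag(w) B a' ||_1.
    B: (N_node, N_leaf) 0/1 ancestor matrix, w: (N_node,) edge weights,
    a, a_prime: (N_leaf,) probability vectors."""
    return np.abs(w * (B @ (a - a_prime))).sum()

# Total variation as a special case: B = I and w_i = 1/2
N = 4
a = np.array([0.1, 0.2, 0.3, 0.4])
b = np.array([0.25, 0.25, 0.25, 0.25])
tv = tree_wasserstein(a, b, np.eye(N), np.full(N, 0.5))
assert np.isclose(tv, 0.5 * np.abs(a - b).sum())
```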
Figure 2. Tree for the sliced Wasserstein distance with $N_{\mathrm{leaf}} = 3$. The left figure is a chain and the right figure is the tree representation of the chain with internal nodes ($w_4 = w_5 = w_6 = 0$).

4. SSL with 1-Wasserstein Distance

In this section, we first formulate SSL using TWD. We then introduce ArcFace-based probability models and Jeffrey divergence regularization.

4.1. SimCLR with Tree Wasserstein Distance

Let $\boldsymbol{a}$ and $\boldsymbol{a}'$ be the embedding vectors of $\boldsymbol{x}$ and $\boldsymbol{x}'$ (i.e., $\mathbf{1}^\top \boldsymbol{a} = 1$ and $\mathbf{1}^\top \boldsymbol{a}' = 1$), with $\mu = \sum_j a_j \delta_{\boldsymbol{e}_j}$ and $\mu' = \sum_j a'_j \delta_{\boldsymbol{e}_j}$, respectively. Here, $\boldsymbol{e}_j$ is the virtual embedding corresponding to $a_j$ or $a'_j$, and $\boldsymbol{e}_j$ is assumed to be unavailable in our problem setup. The main idea of this paper is to adopt the negative Wasserstein distance between $\mu$ and $\mu'$ as the similarity score for SimCLR:
$$\mathrm{sim}(\mu, \mu') = -W_{\mathcal{T}}(\mu, \mu').$$
We assume that $\boldsymbol{B}$ and $\boldsymbol{w}$ are given; that is, both the tree structure and the edge weights are known. In particular, we consider the trees shown in Figure 1.
Following the original design of the InfoNCE loss and substituting the negative Wasserstein distance as the similarity score, we obtain the following simplified loss function:
$$\hat{\theta} := \mathop{\mathrm{argmin}}_{\theta} \sum_{i=1}^{n} \left[ W_{\mathcal{T}}(\mu_i^{(1)}, \mu_i^{(2)})/\tau + \log \sum_{k=1}^{2N} \delta_{k \neq i} \exp\big(-W_{\mathcal{T}}(\mu_i^{(1)}, \mu_k^{(2)})/\tau\big) \right],$$
where $\tau > 0$ is the temperature parameter of the InfoNCE loss. Although we mainly focus on the InfoNCE loss, the proposed negative Wasserstein distance can also be used as a similarity measure in other contrastive losses, e.g., Barlow Twins.
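As a concrete illustration, the sketch below computes a batch of TWD-based logits and feeds them to a cross-entropy loss. It simplifies the objective above by using only cross-view negatives, and names such as `twd_matrix` are our own rather than the authors' code.

```python
import torch
import torch.nn.functional as F

def twd_matrix(A1, A2, B, w):
    """Pairwise TWD between rows of A1 and A2 ((R, N_leaf) probability vectors):
    W_T(mu_i, mu_k) = || diag(w) B (a_i - a_k) ||_1."""
    Z1 = A1 @ B.t() * w          # rows are (diag(w) B a_i)^T, shape (R, N_node)
    Z2 = A2 @ B.t() * w
    return torch.cdist(Z1, Z2, p=1)

def twd_info_nce(A1, A2, B, w, tau=0.1):
    """InfoNCE with sim(mu, mu') = -W_T(mu, mu')."""
    logits = -twd_matrix(A1, A2, B, w) / tau
    targets = torch.arange(A1.shape[0], device=A1.device)
    return F.cross_entropy(logits, targets)
```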

4.2. SimSiam with Tree Wasserstein Distance

Here, we consider the combination of SimSiam and TWD. The loss function of the proposed approach is expressed as
$$L_{\mathrm{TWD\text{-}SimSiam}}(\theta) = \frac{1}{2} L_1(\theta) + \frac{1}{2} L_2(\theta), \quad L_1(\theta) = \frac{1}{n}\sum_{i=1}^{n} W_{\mathcal{T}}\big(\mu_i^{(1)}, \bar{\mu}_i^{(2)}\big), \quad L_2(\theta) = \frac{1}{n}\sum_{i=1}^{n} W_{\mathcal{T}}\big(\bar{\mu}_i^{(1)}, \mu_i^{(2)}\big),$$
where $\bar{\mu}_i^{(v)}$ denotes the distribution obtained with a stop gradient. The difference from the original SimSiam is that our formulation employs the Wasserstein distance, whereas the original formulation uses cosine similarity.
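Under the same assumptions as the earlier sketches (tree given by `B` and `w`, `detach()` as the stop-gradient), the loss can be written as follows; this is our reading of the formulation, not the authors' implementation.

```python
import torch

def twd_simsiam_loss(A1, A2, B, w):
    """Symmetric TWD loss with stop-gradient; A1, A2 are (n, N_leaf) probability outputs."""
    def twd(P, Q):
        # per-sample || diag(w) B (p - q) ||_1, averaged over the batch
        return (torch.abs((P - Q) @ B.t()) * w).sum(dim=1).mean()
    l1 = twd(A1, A2.detach())
    l2 = twd(A1.detach(), A2)
    return 0.5 * l1 + 0.5 * l2
```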

4.3. Robust Variant of Tree Wasserstein Distance

In our setup, it is difficult to estimate the tree structure $\boldsymbol{B}$ and the edge weights $\boldsymbol{w}$ because the embedding vectors $\boldsymbol{e}_1, \boldsymbol{e}_2, \ldots, \boldsymbol{e}_{d_{\mathrm{out}}}$ are unavailable. To address this problem, we consider a robust estimation of the Wasserstein distance, in the spirit of the subspace-robust Wasserstein distance (SRWD) [], for TWD. The key idea of SRWD is to solve the optimal transport problem in the subspace in which the distance is maximized. In the TWD case, we can analogously solve the optimal transport problem under the maximal shortest-path distance. Specifically, for a given $\boldsymbol{B}$, we propose the robust TWD (RTWD) as follows:
$$\mathrm{RTWD}(\mu, \mu') = \frac{1}{2} \min_{\Pi \in U(\mu, \mu')} \max_{\boldsymbol{w} \in \mathcal{B}} \sum_{i=1}^{N_{\mathrm{leaf}}} \sum_{j=1}^{N_{\mathrm{leaf}}} \pi_{ij}\, d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j),$$
where $\mathcal{B} = \{\boldsymbol{w} \in \mathbb{R}_+^{N_{\mathrm{node}}} : \boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}\ \text{and}\ \boldsymbol{w} \geq \mathbf{0}\}$ and $d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j)$ is the shortest-path distance between $\boldsymbol{e}_i$ and $\boldsymbol{e}_j$ embedded in a tree $\mathcal{T}$. The constraint implies that, for each leaf node $j$, the weights of its ancestor nodes are non-negative and sum to one.
Proposition 1.
The robust variant of TWD (RTWD) is equivalent to the total variation:
$$\mathrm{RTWD}(\mu, \mu') = \|\boldsymbol{a} - \boldsymbol{a}'\|_{\mathrm{TV}},$$
where $\|\boldsymbol{a} - \boldsymbol{a}'\|_{\mathrm{TV}} = \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1$ denotes the total variation.
Proof.
Let $\boldsymbol{B} \in \{0, 1\}^{N \times N_{\mathrm{leaf}}} = [\boldsymbol{b}_1, \boldsymbol{b}_2, \ldots, \boldsymbol{b}_{N_{\mathrm{leaf}}}]$ with $\boldsymbol{b}_i \in \{0, 1\}^{N}$, where $N = N_{\mathrm{node}}$. The shortest-path distance between leaves $i$ and $j$ can be represented as []
$$d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j) = \boldsymbol{w}^\top (\boldsymbol{b}_i + \boldsymbol{b}_j - 2\, \boldsymbol{b}_i \circ \boldsymbol{b}_j),$$
where $\circ$ denotes the element-wise product. That is, $d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j)$ is a linear function of $\boldsymbol{w}$ for a given $\boldsymbol{B}$, and the constraints on $\boldsymbol{w}$ and $\Pi$ are convex. Thus, strong duality holds, and we obtain the following representation using the minimax theorem [,]:
$$\mathrm{RTWD}(\mu, \mu') = \frac{1}{2} \max_{\boldsymbol{w} :\, \boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1},\ \boldsymbol{w} \geq \mathbf{0}}\ \min_{\Pi \in U(\boldsymbol{a}, \boldsymbol{a}')} \sum_{i=1}^{N_{\mathrm{leaf}}} \sum_{j=1}^{N_{\mathrm{leaf}}} \pi_{ij}\, \boldsymbol{w}^\top(\boldsymbol{b}_i + \boldsymbol{b}_j - 2\, \boldsymbol{b}_i \circ \boldsymbol{b}_j) = \frac{1}{2} \max_{\boldsymbol{w} :\, \boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1},\ \boldsymbol{w} \geq \mathbf{0}} \big\| \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}') \big\|_1,$$
where we used $W_{\mathcal{T}}(\mu, \mu') = \min_{\Pi \in U(\boldsymbol{a}, \boldsymbol{a}')} \sum_{i=1}^{N_{\mathrm{leaf}}} \sum_{j=1}^{N_{\mathrm{leaf}}} \pi_{ij}\, d_{\mathcal{T}}(\boldsymbol{e}_i, \boldsymbol{e}_j) = \|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\|_1$.
Without loss of generality, we set the root weight $w_0 = 0$. First, we rewrite the norm $\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\|_1$ as
$$\big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\big\|_1 = \sum_{j=1}^{N} w_j \Big| \sum_{k \in [N_{\mathrm{leaf}}],\, k \in \mathrm{de}(j)} (a_k - a'_k) \Big|,$$
where $\mathrm{de}(j)$ denotes the set of descendants of node $j \in [N]$ (including itself). Using the triangle inequality, we obtain
$$\big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\big\|_1 \leq \sum_{j=1}^{N} w_j \sum_{k \in [N_{\mathrm{leaf}}],\, k \in \mathrm{de}(j)} |a_k - a'_k| = \sum_{k \in [N_{\mathrm{leaf}}]} |a_k - a'_k| \sum_{j \in [N],\, j \in \mathrm{pa}(k)} w_j,$$
where $\mathrm{pa}(k)$ is the set of ancestors of leaf $k$ (including itself). Rewriting the constraint $\boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}$ as $\sum_{j \in [N],\, j \in \mathrm{pa}(k)} w_j = 1$ for every $k \in [N_{\mathrm{leaf}}]$, we obtain
$$\big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\big\|_1 \leq \sum_{k \in [N_{\mathrm{leaf}}]} |a_k - a'_k| = \|\boldsymbol{a} - \boldsymbol{a}'\|_1.$$
This inequality holds for any feasible weight vector $\boldsymbol{w}$. Moreover, the vector with $w_j = 1$ if $j \in [N_{\mathrm{leaf}}]$ and $w_j = 0$ otherwise satisfies the constraint $\boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}$ and attains the bound:
$$\big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} (\boldsymbol{a} - \boldsymbol{a}')\big\|_1 = \sum_{k=1}^{N_{\mathrm{leaf}}} |a_k - a'_k| = \|\boldsymbol{a} - \boldsymbol{a}'\|_1.$$
Hence, the maximum over feasible $\boldsymbol{w}$ equals $\|\boldsymbol{a} - \boldsymbol{a}'\|_1$, and therefore $\mathrm{RTWD}(\mu, \mu') = \frac{1}{2}\|\boldsymbol{a} - \boldsymbol{a}'\|_1 = \|\boldsymbol{a} - \boldsymbol{a}'\|_{\mathrm{TV}}$. This completes the proof of the proposition. □
Based on this proposition, RTWD is equivalent to the total variation and does not depend on the tree structure $\boldsymbol{B}$. That is, if we do not have prior information about the tree structure, using the total variation is a reasonable choice.

4.4. Probability Models

In this section, we discuss several choices of probability models for InfoNCE loss and SimSiam loss.
Softmax: The embedded vector with the softmax function is given by
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \mathrm{Softmax}(f_\theta(\boldsymbol{x})),$$
where $f_\theta(\boldsymbol{x})$ is the output of a neural network model.
Simplicial Embedding: Lavoie et al. [] proposed a simple yet efficient simplicial embedding method. Assume that the output dimensionality of the neural network model is $d_{\mathrm{out}}$. SEM then applies a softmax function to each $V$-dimensional block of $f_\theta(\boldsymbol{x})$, yielding $L = d_{\mathrm{out}}/V$ probability vectors. The $\ell$-th softmax output is thus defined as follows:
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \big[\boldsymbol{a}_\theta^{(1)}(\boldsymbol{x})^\top, \boldsymbol{a}_\theta^{(2)}(\boldsymbol{x})^\top, \ldots, \boldsymbol{a}_\theta^{(L)}(\boldsymbol{x})^\top\big]^\top \quad \text{with} \quad \boldsymbol{a}_\theta^{(\ell)}(\boldsymbol{x}) = \mathrm{Softmax}\big(f_\theta^{(\ell)}(\boldsymbol{x})\big)/L,$$
where $f_\theta^{(\ell)}(\boldsymbol{x}) \in \mathbb{R}^V$ is the $\ell$-th block of the neural network output. We divide the softmax output by $L$ because $\boldsymbol{a}_\theta(\boldsymbol{x})$ must satisfy the sum-to-one constraint. Note that the softmax function can be regarded as a special case of simplicial embedding (with $L = 1$). In simplicial embedding, the softmax function is applied separately to each subset of the elements. For example, if $d_{\mathrm{out}} = 10$ and $V = 5$, the softmax function is applied to each of the two five-dimensional vectors, and the results are then concatenated.
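As an illustration of the block-wise construction (our own sketch; the reshape-based implementation and the function name are assumptions, not the authors' code):

```python
import torch
import torch.nn.functional as F

def simplicial_embedding(f_out, V):
    """SEM sketch: split the d_out-dimensional output into L = d_out / V blocks,
    apply a softmax to each V-dimensional block, and rescale by 1/L so that the
    concatenated vector sums to one."""
    batch, d_out = f_out.shape
    L = d_out // V
    blocks = f_out.view(batch, L, V)
    probs = F.softmax(blocks, dim=-1) / L
    return probs.reshape(batch, d_out)   # each row sums to 1
```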
ArcFace model (AF): In comparison to SEM, we propose employing the ArcFace probability model []. The ArcFace model combines cosine similarity with the softmax function:
$$\boldsymbol{a}_\theta(\boldsymbol{x}) = \mathrm{Softmax}\big(\boldsymbol{K}^\top \bar{f}_\theta(\boldsymbol{x})/\eta\big),$$
where $\boldsymbol{K} = [\boldsymbol{k}_1, \boldsymbol{k}_2, \ldots, \boldsymbol{k}_{d_{\mathrm{prob}}}] \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{prob}}}$ is a learnable parameter, $\bar{f}_\theta(\boldsymbol{x})$ is the normalized output of the model ($\bar{f}_\theta(\boldsymbol{x})^\top \bar{f}_\theta(\boldsymbol{x}) = 1$), and $\eta$ is the temperature parameter. Note that AF has a structure similar to attention in transformers [,]. The key difference from the original notion of attention in transformers is the normalization of the key matrix $\boldsymbol{K}$ and the query vector $\bar{f}_\theta(\boldsymbol{x})$.
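A minimal sketch of this probability model follows; the tensor shapes, the normalization of both keys and query, and the name `arcface_probability` are our assumptions.

```python
import torch
import torch.nn.functional as F

def arcface_probability(f_out, K, eta=0.1):
    """ArcFace-style probability model: a = Softmax(K^T f_bar / eta)."""
    f_bar = F.normalize(f_out, dim=1)      # (batch, d_out), unit-norm rows (query)
    K_bar = F.normalize(K, dim=0)          # (d_out, d_prob), unit-norm columns (keys)
    logits = f_bar @ K_bar / eta           # cosine similarities scaled by 1/eta
    return F.softmax(logits, dim=1)        # (batch, d_prob) probability vectors
```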
AF with Positional Encoding: If one adds one more linear layer to the AF model and then applies the softmax function, the output becomes similar to that of the standard softmax model. Here, we instead propose replacing the key matrix with a normalized positional encoding matrix ($\boldsymbol{k}_i^\top \boldsymbol{k}_i = 1, \forall i$):
$$\boldsymbol{k}_i = \bar{\boldsymbol{k}}_i / \|\bar{\boldsymbol{k}}_i\|_2,$$
where $[\bar{\boldsymbol{k}}_i]_{2j} = \sin\big(i/10000^{2j/d_{\mathrm{out}}}\big)$ and $[\bar{\boldsymbol{k}}_i]_{2j+1} = \cos\big(i/10000^{2j/d_{\mathrm{out}}}\big)$.
AF with Discrete Cosine Transform Matrix: Another natural approach is to utilize an orthogonal matrix as $\boldsymbol{K}$. We therefore propose adopting a discrete cosine transform (DCT) [] matrix as $\boldsymbol{K}$, where the DCT is widely used for image data compression. The DCT matrix is expressed as follows []:
$$[\boldsymbol{k}_i]_j = \begin{cases} \sqrt{1/d_{\mathrm{out}}} & (i = 0), \\ \sqrt{2/d_{\mathrm{out}}}\, \cos\Big(\dfrac{\pi (2j + 1) i}{2 d_{\mathrm{out}}}\Big) & (1 \leq i \leq d_{\mathrm{out}} - 1). \end{cases}$$
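Both fixed key matrices can be constructed in a few lines. The NumPy sketch below is our own illustration under stated assumptions (columns as keys, `d_out` even for the sinusoidal case); it is not the authors' implementation.

```python
import numpy as np

def positional_encoding_keys(d_out, d_prob):
    """Sinusoidal key vectors k_i (columns), normalized to unit L2 norm; assumes d_out even."""
    K = np.zeros((d_out, d_prob))
    i = np.arange(d_prob)                                  # key index
    for j in range(d_out // 2):
        K[2 * j, :] = np.sin(i / 10000 ** (2 * j / d_out))
        K[2 * j + 1, :] = np.cos(i / 10000 ** (2 * j / d_out))
    return K / np.linalg.norm(K, axis=0, keepdims=True)

def dct_keys(d_out):
    """Orthonormal DCT-II matrix; entry (i, j) follows the formula above."""
    i = np.arange(d_out)[:, None]
    j = np.arange(d_out)[None, :]
    K = np.sqrt(2.0 / d_out) * np.cos(np.pi * (2 * j + 1) * i / (2 * d_out))
    K[0, :] = np.sqrt(1.0 / d_out)
    return K                                               # K @ K.T == I (orthonormal)
```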
One of the contributions of this study is the finding that combining positional encoding and the DCT matrix with the ArcFace model significantly boosts performance, whereas the standard ArcFace model without these additions performs similarly to the softmax function.

4.5. Jeffrey Divergence Regularization

We empirically observed that optimizing the loss function described above is extremely challenging. In particular, the L1 distance is not differentiable at zero. Figure 3b illustrates the learning curve for standard optimization with the softmax function model.
Figure 3. InfoNCE loss and Top-1 (Train) comparisons on STL10 dataset.
To stabilize optimization, we propose including the Jeffrey divergence (JD) as a regularization term. JD is an upper bound of the square of the 1-Wasserstein distance.
Proposition 2.
For $\boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}$ and probability vectors $\boldsymbol{a}_i$ and $\boldsymbol{a}_j$, we have
$$W_{\mathcal{T}}^2(\mu_i, \mu_j) \leq \mathrm{JD}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big),$$
where
$$\mathrm{JD}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big) = \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big) + \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i\big)$$
is the Jeffrey divergence.
Proof.
If $\boldsymbol{B}^\top \boldsymbol{w} = \mathbf{1}$ and $\boldsymbol{a}$ is a probability vector (such that $\|\boldsymbol{a}\|_1 = 1$), then
$$\mathbf{1}^\top \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a} = 1,$$
so $\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}$ is itself a probability vector. Then, using Pinsker's inequality, we derive the following inequalities:
$$W_{\mathcal{T}}(\mu_i, \mu_j) = \big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i - \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big\|_1 \leq \sqrt{2\, \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big)}$$
and
$$W_{\mathcal{T}}(\mu_i, \mu_j) = \big\|\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j - \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i\big\|_1 \leq \sqrt{2\, \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i\big)}.$$
Squaring both bounds and averaging them yields
$$W_{\mathcal{T}}^2(\mu_i, \mu_j) \leq \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j\big) + \mathrm{KL}\big(\mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_j \,\|\, \mathrm{diag}(\boldsymbol{w}) \boldsymbol{B} \boldsymbol{a}_i\big). \qquad \square$$
This result indicates that minimizing the symmetric KL divergence (i.e., the Jeffrey divergence) also minimizes an upper bound on the tree-Wasserstein distance. Because the Jeffrey divergence is smooth, computing the gradient of the upper bound is easier. For brevity, we write $W_{\mathcal{T}}(\mu^{(1)}, \mu^{(2)}) = W_{\mathcal{T}}(\boldsymbol{a}^{(1)}, \boldsymbol{a}^{(2)})$.
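To make the regularized objective concrete, the sketch below adds $\lambda$ times the Jeffrey divergence between the tree-embedded vectors $\mathrm{diag}(\boldsymbol{w})\boldsymbol{B}\boldsymbol{a}$ to the per-pair TWD. The exact placement of the regularizer inside the SSL loss is our assumption, and a small $\epsilon$ is added for numerical stability.

```python
import torch

def jeffrey_divergence(P, Q, eps=1e-12):
    """Symmetric KL (Jeffrey) divergence between rows of P and Q (probability vectors)."""
    kl_pq = (P * (torch.log(P + eps) - torch.log(Q + eps))).sum(dim=1)
    kl_qp = (Q * (torch.log(Q + eps) - torch.log(P + eps))).sum(dim=1)
    return kl_pq + kl_qp

def regularized_twd(A1, A2, B, w, lam=0.1):
    """Per-pair TWD plus lambda * JD on the tree-embedded vectors diag(w) B a."""
    P = A1 @ B.t() * w           # rows sum to 1 when B^T w = 1
    Q = A2 @ B.t() * w
    twd = torch.abs(P - Q).sum(dim=1)
    return (twd + lam * jeffrey_divergence(P, Q)).mean()
```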
Note that Frogner et al. [] considered a multilabel classification problem with a regularized Wasserstein loss and proposed Kullback–Leibler divergence-based regularization to stabilize training. We derive the Jeffrey divergence regularization from an upper bound on the TWD, and JD regularization includes simple KL divergence-based regularization as a special case. Moreover, we employ JD regularization in SSL frameworks, which has not been extensively studied.

5. Experiments

This section evaluates SSL methods with different probability models.

5.1. Performance Comparison for SimCLR

For all experiments, we employed the Resnet18 model with an output dimension of $d_{\mathrm{out}} = 256$ and implemented all methods on top of a standard SimCLR implementation (https://github.com/sthalles/SimCLR (accessed on 7 July 2023)). We used the Adam optimizer and set the learning rate to 0.0003, the weight decay parameter to $10^{-4}$, and the temperature $\tau$ to 0.07. For the proposed method, we compared two variants of TWD: total variation and ClusterTree (Figure 1). As probability models, we assessed the conventional softmax function, the ArcFace model (AF), and simplicial embedding (SEM) [], and set the temperature parameter of the probability models to 0.1 in all experiments. For SEM, we set $L = 16$ and $V = 16$.
We also evaluated JD regularization, with the regularization parameter $\lambda = 0.1$ in all experiments. For reference, we compared against cosine similarity as the similarity function of SimCLR. For all approaches, we used the KNN classifier of the scikit-learn package (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html (accessed on 7 July 2023)), where the number of nearest neighbors was set to $K = 50$. We used the L1 distance for the Wasserstein-based models and cosine similarity for the non-probability-based models. All experiments were run on A6000 GPUs. We ran all experiments three times and report the average scores.
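The KNN evaluation protocol can be summarized as follows; this is a sketch under our assumptions (the feature-extraction step and function names are ours), and only the scikit-learn classifier usage with $K = 50$ and the L1/cosine metrics reflects the setup described above.

```python
from sklearn.neighbors import KNeighborsClassifier

def knn_eval(train_feats, train_labels, test_feats, test_labels, k=50, metric="manhattan"):
    """KNN top-1 accuracy on frozen SSL features.
    metric='manhattan' (L1) for Wasserstein-based models, 'cosine' for real-valued embeddings."""
    clf = KNeighborsClassifier(n_neighbors=k, metric=metric)
    clf.fit(train_feats, train_labels)
    return clf.score(test_feats, test_labels)
```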
Figure 3 illustrates the training loss and top-1 accuracy for the three methods: cosine + real-valued embedding, TV + softmax, and TV + AF (DCT). This experiment revealed that the convergence speed of the loss function was nearly identical across all methods. Regarding the training top-1 accuracy, cosine + real-valued embedding achieves the highest accuracy, followed by the softmax function, and AF (DCT) lags. This behavior is expected because real-valued embeddings offer the most flexibility, followed by softmax, with AF models exhibiting the least freedom. For all methods based on the TWD, JD regularization significantly aids the training process, particularly in the case of the softmax function. However, for AF (DCT), the improvement was relatively small. This is likely because AF (DCT) can also be considered a form of regularization.
Table 1 presents the experimental results for the test classification accuracy using KNN. The first observation is that the simple implementation of the conventional softmax function performs poorly (approximately 10 points lower) compared to cosine similarity. As expected, AF, which has only one more layer than the simple softmax model, performs similarly to softmax. Compared to softmax and AF, AF (PE) and AF (DCT) significantly improve the classification accuracy for the total variation and ClusterTree cases. However, for the ClusterTree case, AF (PE) achieves better classification performance, whereas the AF (DCT) improvement over the softmax model is limited. In the ClusterTree case, SEM improves significantly in combination with ClusterTree and regularization. One potential reason for the performance improvement of the TV + AF (DCT) and ClusterTree + SEM combinations is that AF (DCT) utilizes the orthonormal DCT transform of the learned representation, while SEM and ClusterTree each carry structure of their own. This means that each element of the final probability vector $\boldsymbol{a}_\theta$ can be uncorrelated for AF (DCT). As a result, the tree structure may not provide significant information, and the total variation (i.e., each leaf node connected to the root node) might be the best fit for this probability representation. Additionally, a cluster-like structure may conflict with the DCT-based representation. In contrast, SEM has an inherent structure and is computed without the DCT transformation (it learns a sum-to-one vector on subtrees). Therefore, the ClusterTree structure and SEM can be a good match.
Table 1. KNN classification results with a Resnet18 backbone. In this experiment, we set the number of neighbors to $K = 50$ and computed the average classification accuracy over three runs. Note that the Wasserstein distance with $\boldsymbol{B} = \boldsymbol{I}_{d_{\mathrm{out}}}$ is equivalent to the total variation.
Overall, the proposed method performs better than cosine similarity computed on the probability-vector (rather than real-valued) embeddings when the number of classes is relatively small (i.e., STL10, CIFAR10, and SVHN). By contrast, the performance of the proposed method degrades on CIFAR100, and the results for ClusterTree are particularly poor. Because the Wasserstein distance can be minimized even if the model cannot overfit, it is natural for the Wasserstein distance to make mistakes when the number of classes is large. Note that on CIFAR100, simplicial representations degrade performance for both the cosine and TWD loss functions, and the degradation appears to come from the softmax operation. Moreover, the total variation is a robust measure, and learning with the total variation generally yields models that are resilient to noise. In our self-supervised setting, it is likely that representations of similar classes become mixed, leading to performance degradation. Since the proposed method performs well on CIFAR10, we believe this could be the reason for the performance issues on larger datasets. To address this, it may be beneficial to use other types of regularizers or larger deep learning models.
Next, we evaluated the Jeffrey divergence regularization. Surprisingly, simple regularization dramatically improves the classification performance of all the probability models. These results support the idea that the main problem with Wasserstein distance-based representation learning is its numerical instability.
Among the methods, the proposed AF (DCT) + JD with total variation achieves the highest classification accuracy, comparable to the cosine similarity result, and provides more than a 10% improvement over the naive implementation with the softmax function. Moreover, the probability models combined with cosine similarity tend to yield lower classification accuracy than their combinations with TWD. Based on our empirical study, we recommend the TWD (TV) + AF models or TWD (ClusterTree) + SEM for probability-based representation-learning tasks.

5.2. Performance Comparison for SimSiam

Next, we evaluated the performance in a non-contrastive setup. For all experiments, we utilized the Resnet18-Cifar-Variant1 model with an output dimension of $d_{\mathrm{out}} = 2048$ and implemented all methods based on a standard SimSiam framework (https://github.com/PatrickHua/SimSiam). The optimization was performed using the SGD optimizer with a base learning rate of 0.03, a weight decay parameter of 0.00005, a momentum parameter of 0.9, a batch size of 512, and a fixed number of epochs set to 800. For the proposed method, we employed the total variation as the loss function, along with the softmax function and the ArcFace model (AF). The temperature parameter was set to 0.1 in all experiments. Additionally, we assessed JD regularization with the regularization parameter $\lambda$ set to 0.1 across all experiments. A100 GPUs were used for all experiments, and each experiment was run three times, with the reported results being the average scores.
We compared the proposed methods, TWDSimSiam (softmax + JD) and TWDSimSiam (AF + JD), with the original SimSiam method, which employs a cosine similarity loss. We observe that learning the total variation with softmax encounters numerical issues, even with JD regularization (see Figure 4a,c). Conversely, the AF + JD combination proved successful in training the models, as shown in Figure 4b,c. One potential reason for the failure of TWD with softmax is that the total variation can easily become zero because the softmax function is applied without feature normalization. For TWDSimSiam (AF + JD), the normalization within the AF model prevents convergence to a poor local minimum. From a performance standpoint, as shown in Table 2, cosine similarity and total variation (TV) yield comparable results. However, a key contribution of this study is a practical approach to enhancing training stability when incorporating the Wasserstein distance, specifically through the total variation. This finding has potential utility in various SSL tasks.
Figure 4. TWD loss for SimSiam models.
Table 2. SimSiam evaluation with CIFAR10 dataset.

5.3. Effect of Number of Nearest Neighbors

In this section, we assess the performance of the KNN model by varying the number of nearest neighbors and setting K to 10 or 50. The results for K = 10 are presented in Table 3, and Table 4 illustrates a comparison of the best models across different nearest neighbor values. Our experiments revealed that utilizing K = 50 tends to enhance the performance, and the relative order of the results remains consistent, regardless of the number of nearest neighbors.
Table 3. KNN classification results with a Resnet18 backbone. In this experiment, we set the number of neighbors to $K = 10$ and computed the average classification accuracy over three runs. Note that the Wasserstein distance with $\boldsymbol{B} = \boldsymbol{I}_{d_{\mathrm{out}}}$ is equivalent to the total variation.
Table 4. KNN classification accuracy with different number of neighbors.

5.4. Effect of the Regularization Parameter for Jeffrey Divergence

In this experiment, we evaluated model performance by varying the regularization parameter, denoted as λ . The results indicate a noteworthy improvement in performance with the introduction of regularization parameters. However, as shown in Table 5, it was observed that the performance did not exhibit significant changes across different values of λ , and setting λ = 0.1 yielded favorable results.
Table 5. KNN classification results with a Resnet18 backbone. In this experiment, we set the number of neighbors to $K = 50$ and computed the average classification accuracy over three runs.

6. Conclusions

This study investigates SSL using TWD. We empirically evaluate several benchmark datasets and find that a simple combination of the softmax function and TWD performs poorly. To address this, we employ simplicial embedding [] and ArcFace-based models [] as probability models. Moreover, to mitigate the intricacies of optimizing TWD, we incorporate an upper bound on the squared 1-Wasserstein distance as a regularization technique. Overall, the combination of ArcFace and DCT outperforms its cosine similarity counterpart. Finally, we find that the combination of TWD (ClusterTree) and SEM also yields favorable performance.
There are several potential future directions for our work. Firstly, improving representation learning for larger classes could involve employing larger models and/or introducing new regularization techniques. Secondly, integrating the proposed probability representation into other SSL models such as DINO [] could enhance our understanding of model performance across different learning tasks. Lastly, while we have empirically studied self-supervised learning with Wasserstein distance, the theoretical properties remain unclear. Therefore, investigating these theoretical properties represents another promising research direction.

Author Contributions

Conceptualization, M.Y., Y.T., G.H., H.Z. and Y.-H.T.; Methodology, M.Y., Y.T., G.H. and D.S.; Formal analysis, M.Y.; Writing—original draft, M.Y. and H.Z.; Writing—review & editing, Y.T., K.M.D., H.Z. and Y.-H.T.; Visualization, M.Y.; Funding acquisition, M.Y. All authors have read and agreed to the published version of the manuscript.

Funding

M.Y. was supported by MEXT KAKENHI Grant Number 24K03004. Y.T. was supported by MEXT KAKENHI Grant Number 23KJ1336. K.M.D. was funded by the Gatsby Charitable Foundation.

Institutional Review Board Statement

Not applicable.

Data Availability Statement

All the data used in the study is publicly accessible.

Conflicts of Interest

The authors declare no conflicts of interest.

References

  1. Kramer, M.A. Nonlinear principal component analysis using autoassociative neural networks. AIChE J. 1991, 37, 233–243. [Google Scholar] [CrossRef]
  2. Kingma, D.P.; Welling, M. Auto-encoding variational bayes. arXiv 2013, arXiv:1312.6114. [Google Scholar]
  3. Chen, X.; Fan, H.; Girshick, R.; He, K. Improved baselines with momentum contrastive learning. arXiv 2020, arXiv:2003.04297. [Google Scholar]
  4. Grill, J.B.; Strub, F.; Altché, F.; Tallec, C.; Richemond, P.; Buchatskaya, E.; Doersch, C.; Avila Pires, B.; Guo, Z.; Gheshlaghi Azar, M.; et al. Bootstrap your own latent—A new approach to self-supervised learning. In Proceedings of the NeurIPS, Virtual, 6–12 December 2020; pp. 21271–21284. [Google Scholar]
  5. He, K.; Fan, H.; Wu, Y.; Xie, S.; Girshick, R. Momentum contrast for unsupervised visual representation learning. In Proceedings of the CVPR, Virtual, 14–19 June 2020; pp. 9729–9738. [Google Scholar]
  6. Caron, M.; Misra, I.; Mairal, J.; Goyal, P.; Bojanowski, P.; Joulin, A. Unsupervised learning of visual features by contrasting cluster assignments. In Proceedings of the NeurIPS, Virtual, 6–12 December 2020; pp. 9912–9924. [Google Scholar]
  7. Chen, X.; He, K. Exploring simple siamese representation learning. In Proceedings of the CVPR, Virtual, 19–25 June 2021; pp. 15750–15758. [Google Scholar]
  8. Caron, M.; Touvron, H.; Misra, I.; Jégou, H.; Mairal, J.; Bojanowski, P.; Joulin, A. Emerging properties in self-supervised vision transformers. In Proceedings of the ICCV, Virtual, 11–17 October 2021; pp. 9650–9660. [Google Scholar]
  9. Jiang, Q.; Chen, C.; Zhao, H.; Chen, L.; Ping, Q.; Tran, S.D.; Xu, Y.; Zeng, B.; Chilimbi, T. Understanding and constructing latent modality structures in multi-modal representation learning. In Proceedings of the CVPR, Vancouver, BC, Canada, 18–22 June 2023; pp. 7661–7671. [Google Scholar]
  10. Lavoie, S.; Tsirigotis, C.; Schwarzer, M.; Vani, A.; Noukhovitch, M.; Kawaguchi, K.; Courville, A. Simplicial embeddings in self-supervised learning and downstream classification. In Proceedings of the ICLR, Kigali, Rwanda, 1–5 May 2023. [Google Scholar]
  11. Arjovsky, M.; Chintala, S.; Bottou, L. Wasserstein generative adversarial networks. In Proceedings of the ICML, Sydney, NSW, Australia, 6–11 August 2017; pp. 214–223. [Google Scholar]
  12. Kusner, M.; Sun, Y.; Kolkin, N.; Weinberger, K. From word embeddings to document distances. In Proceedings of the ICML, Lille, France, 6–11 July 2015; pp. 957–966. [Google Scholar]
  13. Sato, R.; Yamada, M.; Kashima, H. Re-evaluating Word Mover’s Distance. In Proceedings of the ICML, Baltimore, MD, USA, 17–23 July 2022; pp. 19231–19249. [Google Scholar]
  14. Sarlin, P.E.; DeTone, D.; Malisiewicz, T.; Rabinovich, A. Superglue: Learning feature matching with graph neural networks. In Proceedings of the CVPR, Virtual, 14–19 June 2020; pp. 4938–4947. [Google Scholar]
  15. Xian, R.; Yin, L.; Zhao, H. Fair and Optimal Classification via Post-Processing. In Proceedings of the ICML, Honolulu, HI, USA, 23–29 July 2023; pp. 37977–38012. [Google Scholar]
  16. Zhao, H. Costs and Benefits of Fair Regression. TMLR 2022, 1–22. [Google Scholar]
  17. Indyk, P.; Thaper, N. Fast image retrieval via embeddings. In Proceedings of the 3rd International Workshop on Statistical and Computational Theories of Vision, Nice, France, 12 October 2003; Volume 2, p. 5. [Google Scholar]
  18. Le, T.; Yamada, M.; Fukumizu, K.; Cuturi, M. Tree-sliced variants of wasserstein distances. In Proceedings of the NeurIPS, Vancouver, BC, Canada, 8–14 December 2019; pp. 12283–12294. [Google Scholar]
  19. Rabin, J.; Peyré, G.; Delon, J.; Bernot, M. Wasserstein Barycenter and Its Application to Texture Mixing. In Proceedings of the International Conference on Scale Space and Variational Methods in Computer Vision, Ein-Gedi, Israel, 29 May–2 June 2011; Springer: Berlin/Heidelberg, Germany, 2011; pp. 435–446. [Google Scholar]
  20. Kolouri, S.; Zou, Y.; Rohde, G.K. Sliced Wasserstein kernels for probability distributions. In Proceedings of the CVPR, Las Vegas, NV, USA, 26 June –1 July 2016; pp. 5258–5267. [Google Scholar]
  21. Deng, J.; Guo, J.; Xue, N.; Zafeiriou, S. Arcface: Additive angular margin loss for deep face recognition. In Proceedings of the CVPR, Long Beach, CA, USA, 16–20 June 2019; pp. 4690–4699. [Google Scholar]
  22. Becker, S.; Hinton, G.E. Self-organizing neural network that discovers surfaces in random-dot stereograms. Nature 1992, 355, 161–163. [Google Scholar] [CrossRef] [PubMed]
  23. Chen, T.; Kornblith, S.; Norouzi, M.; Hinton, G. A simple framework for contrastive learning of visual representations. In Proceedings of the ICML, Vienna, Austria, 12–18 July 2020; pp. 1597–1607. [Google Scholar]
  24. Oord, A.v.d.; Li, Y.; Vinyals, O. Representation learning with contrastive predictive coding. arXiv 2018, arXiv:1807.03748. [Google Scholar]
  25. Zbontar, J.; Jing, L.; Misra, I.; LeCun, Y.; Deny, S. Barlow twins: Self-supervised learning via redundancy reduction. In Proceedings of the ICML, Virtual, 18–24 July 2021; pp. 12310–12320. [Google Scholar]
  26. Gretton, A.; Bousquet, O.; Smola, A.; Schölkopf, B. Measuring statistical dependence with Hilbert-Schmidt norms. In Proceedings of the ALT, Singapore, 8–11 October 2005; pp. 63–77. [Google Scholar]
  27. Tsai, Y.H.H.; Bai, S.; Morency, L.P.; Salakhutdinov, R. A note on connecting barlow twins with negative-sample-free contrastive learning. arXiv 2021, arXiv:2104.13712. [Google Scholar]
  28. Cover, T.M.; Thomas, J.A. Elements of Information Theory; John Wiley & Sons: Hoboken, NJ, USA, 2012. [Google Scholar]
  29. Zhao, W.; Peyrard, M.; Liu, F.; Gao, Y.; Meyer, C.M.; Eger, S. MoverScore: Text generation evaluating with contextualized embeddings and earth mover distance. In Proceedings of the EMNLP-IJCNLP, Hong Kong, China, 3–7 November 2019; pp. 563–578. [Google Scholar]
  30. Yokoi, S.; Takahashi, R.; Akama, R.; Suzuki, J.; Inui, K. Word Rotator’s Distance. In Proceedings of the EMNLP, Virtual, 16–20 November 2020; pp. 2944–2960. [Google Scholar]
  31. Cuturi, M. Sinkhorn distances: Lightspeed computation of optimal transport. In Proceedings of the NIPS, Lake Tahoe, NV, USA, 5–10 December 2013; pp. 2292–2300. [Google Scholar]
  32. Kolouri, S.; Nadjahi, K.; Simsekli, U.; Badeau, R.; Rohde, G. Generalized sliced wasserstein distances. In Proceedings of the NeurIPS, Vancouver, BC, Canada, 8–14 December 2019; pp. 261–272. [Google Scholar]
  33. Mueller, J.W.; Jaakkola, T. Principal differences analysis: Interpretable characterization of differences between distributions. In Proceedings of the NIPS, Montreal, QC, Canada, 7–12 December 2015; pp. 1702–1710. [Google Scholar]
  34. Deshpande, I.; Hu, Y.T.; Sun, R.; Pyrros, A.; Siddiqui, N.; Koyejo, S.; Zhao, Z.; Forsyth, D.; Schwing, A.G. Max-Sliced Wasserstein distance and its use for GANs. In Proceedings of the CVPR, Long Beach, CA, USA, 16–20 June 2019; pp. 10648–10656. [Google Scholar]
  35. Paty, F.P.; Cuturi, M. Subspace Robust Wasserstein Distances. In Proceedings of the ICML, Long Beach, CA, USA, 9–15 June 2019; pp. 5072–5081. [Google Scholar]
  36. Evans, S.N.; Matsen, F.A. The phylogenetic Kantorovich–Rubinstein metric for environmental sequence samples. J. R. Stat. Soc. Ser. B (Stat. Methodol.) 2012, 74, 569–592. [Google Scholar] [CrossRef] [PubMed]
  37. Lozupone, C.; Knight, R. UniFrac: A new phylogenetic method for comparing microbial communities. Appl. Environ. Microbiol. 2005, 71, 8228–8235. [Google Scholar] [CrossRef] [PubMed]
  38. Sato, R.; Yamada, M.; Kashima, H. Fast Unbalanced Optimal Transport on Tree. In Proceedings of the NeurIPS, Virtual, 6–12 December 2020. [Google Scholar]
  39. Le, T.; Nguyen, T. Entropy partial transport with tree metrics: Theory and practice. In Proceedings of the AISTATS, Virtual, 13–15 April 2021; pp. 3835–3843. [Google Scholar]
  40. Takezawa, Y.; Sato, R.; Yamada, M. Supervised tree-wasserstein distance. In Proceedings of the ICML, Virtual, 18–24 July 2021; pp. 10086–10095. [Google Scholar]
  41. Takezawa, Y.; Sato, R.; Kozareva, Z.; Ravi, S.; Yamada, M. Fixed Support Tree-Sliced Wasserstein Barycenter. In Proceedings of the AISTATS, Valencia, Spain, 28–30 March 2022; pp. 1120–1137. [Google Scholar]
  42. Le, T.; Nguyen, T.; Fukumizu, K. Optimal transport for measures with noisy tree metric. In Proceedings of the AISTATS, Valencia, Spain, 2–4 May 2024; pp. 3115–3123. [Google Scholar]
  43. Chen, S.; Tabaghi, P.; Wang, Y. Learning ultrametric trees for optimal transport regression. In Proceedings of the AAAI, Buffalo, NY, USA, 3–6 June 2024; pp. 20657–20665. [Google Scholar]
  44. Houry, G.; Bao, H.; Zhao, H.; Yamada, M. Fast 1-Wasserstein distance approximations using greedy strategies. In Proceedings of the AISTATS, Valencia, Spain, 2–4 May 2024; pp. 325–333. [Google Scholar]
  45. Tong, A.Y.; Huguet, G.; Natik, A.; MacDonald, K.; Kuchroo, M.; Coifman, R.; Wolf, G.; Krishnaswamy, S. Diffusion earth mover’s distance and distribution embeddings. In Proceedings of the ICML, Virtual, 18–24 July 2021; pp. 10336–10346. [Google Scholar]
  46. Le, T.; Nguyen, T.; Phung, D.; Nguyen, V.A. Sobolev transport: A scalable metric for probability measures with graph metrics. In Proceedings of the AISTATS, Virtual, 28–30 March 2022; pp. 9844–9868. [Google Scholar]
  47. Otao, S.; Yamada, M. A linear time approximation of Wasserstein distance with word embedding selection. In Proceedings of the EMNLP, Singapore, 6–10 December 2023; pp. 15121–15134. [Google Scholar]
  48. Laouar, C.; Takezawa, Y.; Yamada, M. Large-scale similarity search with Optimal Transport. In Proceedings of the EMNLP, Singapore, 6–10 December 2023; pp. 11920–11930. [Google Scholar]
  49. Zapatero, M.R.; Tong, A.; Opzoomer, J.W.; O’Sullivan, R.; Rodriguez, F.C.; Sufi, J.; Vlckova, P.; Nattress, C.; Qin, X.; Claus, J.; et al. Trellis tree-based analysis reveals stromal regulation of patient-derived organoid drug responses. Cell 2023, 186, 5606–5619. [Google Scholar] [CrossRef] [PubMed]
  50. Backurs, A.; Dong, Y.; Indyk, P.; Razenshteyn, I.; Wagner, T. Scalable nearest neighbor search for optimal transport. In Proceedings of the ICML, Vienna, Austria, 12–18 July 2020; pp. 497–506. [Google Scholar]
  51. Dey, T.K.; Zhang, S. Approximating 1-Wasserstein Distance between Persistence Diagrams by Graph Sparsification. In Proceedings of the ALENEX, Alexandria, VA, USA, 9–10 January 2022; pp. 169–183. [Google Scholar]
  52. Yamada, M.; Takezawa, Y.; Sato, R.; Bao, H.; Kozareva, Z.; Ravi, S. Approximating 1-Wasserstein Distance with Trees. TMLR 2022, 1–9. [Google Scholar]
  53. Frogner, C.; Zhang, C.; Mobahi, H.; Araya, M.; Poggio, T.A. Learning with a Wasserstein loss. In Proceedings of the NIPS, Montreal, QC, Canada, 7–12 December 2015; pp. 2053–2061. [Google Scholar]
  54. Toyokuni, A.; Yokoi, S.; Kashima, H.; Yamada, M. Computationally Efficient Wasserstein Loss for Structured Labels. In Proceedings of the EACL: Student Research Workshop, Virtual, 19–23 April 2021; pp. 1–7. [Google Scholar]
  55. Raginsky, M.; Sason, I. Concentration of measure inequalities in information theory, communications, and coding. Found. Trends® Commun. Inf. Theory 2013, 10, 1–246. [Google Scholar] [CrossRef]
  56. Neumann, J.V. Zur theorie der gesellschaftsspiele. Math. Ann. 1928, 100, 295–320. [Google Scholar] [CrossRef]
  57. Fan, K. Minimax theorems. Proc. Natl. Acad. Sci. USA 1953, 39, 42–47. [Google Scholar] [CrossRef] [PubMed]
  58. Bahdanau, D.; Cho, K.; Bengio, Y. Neural machine translation by jointly learning to align and translate. arXiv 2014, arXiv:1409.0473. [Google Scholar]
  59. Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, Ł.; Polosukhin, I. Attention is all you need. In Proceedings of the NIPS, Long Beach, CA, USA, 4–9 December 2017; pp. 5998–6008. [Google Scholar]
  60. Ahmed, N.; Natarajan, T.; Rao, K.R. Discrete cosine transform. IEEE Trans. Comput. 1974, 100, 90–93. [Google Scholar] [CrossRef]
