Max-Sliced Wasserstein Distance and its use for GANs
Ishan Deshpande, Yuan-Ting Hu, Ruoyu Sun, Ayis Pyrros, Nasir Siddiqui, Sanmi Koyejo, Zhizhen Zhao, David Forsyth, Alexander Schwing

TL;DR
This paper introduces the max-sliced Wasserstein distance, which improves sample complexity and reduces projection complexity, enabling efficient training of GANs on high-resolution images.
Contribution
It proposes the max-sliced Wasserstein distance, enhancing sliced Wasserstein with better sample and projection complexity for high-dimensional GAN training.
Findings
Max-sliced Wasserstein has superior sample complexity compared to Wasserstein.
The method enables GAN training on 256x256 images efficiently.
It reduces projection complexity in sliced Wasserstein computations.
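Table 1: Word translation precision on the MUSE bilingual dictionaries [6] for five language pairs in both directions. NN: nearest-neighbor retrieval; CSLS: cross-domain similarity local scaling.
| Method | en-es | es-en | en-fr | fr-en | en-de | de-en | en-ru | ru-en | en-zh | zh-en |
|---|---|---|---|---|---|---|---|---|---|---|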
| [6] - NN | 79.1 | 78.1 | 78.1 | 78.2 | 71.3 | 69.6 | 37.3 | 54.3 | 30.9 | 21.9 |
| [6] - CSLS | 81.7 | 83.3 | 82.3 | 82.1 | 74.0 | 72.2 | 44.0 | 59.1 | 32.5 | 31.4 |
| Max-sliced WGAN - NN | 79.6 | 79.1 | 78.2 | 78.5 | 71.9 | 69.6 | 38.4 | 58.7 | 34.9 | 25.1 |
| Max-sliced WGAN - CSLS | 82.0 | 84.1 | 82.5 | 82.3 | 74.8 | 73.1 | 44.6 | 61.7 | 35.3 | 31.9 |
University of Illinois at Urbana-Champaign; †Dupage Medical Group (Ayis Pyrros, Nasir Siddiqui)
{ythu2, ruoyus}@illinois.edu, {sanmi, zhizhenz, daf, aschwing}@illinois.edu
Abstract
Generative adversarial nets (GANs) and variational auto-encoders have significantly improved our distribution modeling capabilities, showing promise for dataset augmentation, image-to-image translation and feature learning. However, to model high-dimensional distributions, sequential training and stacked architectures are common, increasing the number of tunable hyper-parameters as well as the training time. Moreover, the sample complexity of the distance metrics remains one of the factors affecting GAN training. We first show that the recently proposed sliced Wasserstein distance has compelling sample complexity properties when compared to the Wasserstein distance. To further improve the sliced Wasserstein distance, we then analyze its ‘projection complexity’ and develop the max-sliced Wasserstein distance. This distance enjoys compelling sample complexity while reducing projection complexity, albeit at the cost of estimating a maximum. We finally illustrate that the proposed distance makes it easy to train GANs on high-dimensional images at resolutions of up to 256x256.
1 Introduction
Generative modeling capabilities have improved tremendously in the last few years, especially since the advent of deep learning-based models like generative adversarial nets (GANs) [11] and variational auto-encoders (VAEs) [17]. Instead of sampling directly from a high-dimensional distribution, GANs and VAEs use deep nets to transform a sample obtained from a simple distribution. These models have found use in dataset augmentation [31], image-to-image translation [15, 37, 21, 14, 24, 29, 35, 38], and even feature learning for inference-related tasks [9].
GANs and many of their variants formulate generative modeling as a two-player game. A ‘generator’ creates samples that resemble the ground-truth data. A ‘discriminator’ tries to distinguish between ‘artificial’ and ‘real’ samples. Both the generator and the discriminator are parametrized using deep nets and trained via stochastic gradient descent. In its original formulation [11], a GAN minimizes the Jensen-Shannon divergence between the data distribution and the probability distribution induced in the data space by the generator. Many other variants have been proposed, which use either some divergence or an integral probability metric to measure the distance between the distributions [2, 22, 12, 20, 8, 7, 27, 4, 26, 23, 13, 30]. When carefully trained, GANs are able to produce high-quality samples [28, 16, 25]. Training GANs is, however, difficult, especially on high-dimensional datasets.
The scaling difficulty of GANs may be related to one fundamental theoretical issue: the sample complexity. It is shown in [3] that the KL divergence, the Jensen-Shannon divergence and the Wasserstein distance do not generalize. In other words, the population distance cannot be approximated by an empirical distance when only a polynomial number of samples is available. To improve generalization, one popular method is to limit the discriminator class [3, 10] and interpret the training process as minimizing a neural-net distance [3].
In this work, we promote a different path that resolves the sample complexity issue. A fundamental reason for the exponential sample complexity of the Wasserstein distance is the sparsity of points in a high-dimensional space: even if two collections of points are randomly drawn from the same ball, the two collections are far away from each other. Our intuition is that projection onto a low-dimensional subspace, such as a line, mitigates this artificial distance effect. The distance between the projected samples then better reflects the true distance between the distributions.
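The intuition can be checked numerically. The following is a minimal sketch, assuming NumPy; all function names are illustrative. It draws two independent samples from the same standard Gaussian. It then compares two quantities:

- a lower bound on their full-space empirical W2 distance (the exact value would need an assignment solver, which is avoided here);
- the exact W2 between their 1-d projections onto a random unit direction.

```python
import numpy as np

def nn_lower_bound_w2(x, y):
    """Lower bound on the empirical W2 distance between two equal-size point clouds.

    Every point of x must send its mass to some point of y, so
    W2^2 >= mean_i min_j ||x_i - y_j||^2.
    """
    # Pairwise squared distances via ||x||^2 + ||y||^2 - 2 x.y, shape (n, n).
    d2 = (x ** 2).sum(1)[:, None] + (y ** 2).sum(1)[None, :] - 2.0 * x @ y.T
    return np.sqrt(np.maximum(d2.min(axis=1), 0.0).mean())

def projected_w2(x, y, w):
    """Exact empirical W2 between the 1-d projections onto unit vector w (by sorting)."""
    return np.sqrt(np.mean((np.sort(x @ w) - np.sort(y @ w)) ** 2))

rng = np.random.default_rng(0)
n = 200
results = {}
for d in (2, 32, 512):
    # Two independent samples from the SAME distribution N(0, I_d).
    x = rng.standard_normal((n, d))
    y = rng.standard_normal((n, d))
    w = rng.standard_normal(d)
    w /= np.linalg.norm(w)
    results[d] = (nn_lower_bound_w2(x, y), projected_w2(x, y, w))
    print(f"d={d:4d}  full-space W2 >= {results[d][0]:.2f}   projected W2 = {results[d][1]:.2f}")
```

The full-space distance between two samples of the same distribution grows with the dimension. The projected distance stays small, which is the effect the argument above relies on.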
We first apply this intuition to analyze the recently proposed sliced Wasserstein distance GAN, which is based on the average Wasserstein distance of the projected versions of two distributions along a few randomly picked directions [8, 20, 34]. We prove that the sliced Wasserstein distance is generalizable for Gaussian distributions (i.e., it has polynomial sample complexity), while Wasserstein distance is not, thus partially explaining why [8, 20, 34] may exhibit better behavior than the Wasserstein distance [2].
One drawback of the sliced Wasserstein distance is that it requires a large number of projection directions, since random directions lose a lot of information. To address this concern, we propose to project onto the “best direction,” along which the projected distance is maximized. We call the corresponding metric the “max-sliced Wasserstein distance,” and prove that it is also generalizable for Gaussian distributions.
Using this new metric, we are able to train GANs to generate high resolution images from the CelebA-HQ [16] and LSUN Bedrooms [36] datasets. We also achieve improved performance in other distribution matching tasks like unpaired word translation [6].
The main contributions of this paper are the following:
- We analyze in Sec. 3.1 the sample complexity of the Wasserstein and sliced Wasserstein distances. We show that for a certain class of distributions the Wasserstein distance has an exponential sample complexity, while the sliced Wasserstein distance [8, 34] has a polynomial sample complexity.
- We then study in Sec. 3.2 the projection complexity of the sliced Wasserstein distance, i.e., how the number of random projection directions affects estimation.
- We introduce the max-sliced Wasserstein distance in Sec. 3.3 to address the projection complexity issue.
- We then employ the max-sliced Wasserstein distance to train GANs in Sec. 4, demonstrating significant reduction in the number of projection directions required for the sliced-Wasserstein GAN.
2 Background
Generative modeling is the task of learning a probability distribution from a given dataset $\mathcal{D}$ of samples $x \sim P_d$ drawn from an unknown data distribution $P_d$. While this has traditionally been seen through the lens of likelihood maximization, GANs pose generative modeling as a distance minimization problem. More specifically, these approaches recommend learning the data distribution $P_d$ by finding a distribution $P_g$ that solves:
$$\min_{P_g} D(P_d, P_g), \tag{1}$$
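where $D(\cdot, \cdot)$ is some distance or divergence between distributions. Arjovsky et al. [1] proposed using the Wasserstein distance in the context of GAN formulations. The Wasserstein-$p$ distance between distributions $P_d$ and $P_g$ is defined as:
$$W_p(P_d, P_g) = \left( \inf_{\gamma \in \Pi(P_d, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\|x - y\|^p\right] \right)^{1/p}, \tag{2}$$
where $\mathcal{X}$ is the data space and $\Pi(P_d, P_g)$ is the set of all possible joint distributions on $\mathcal{X} \times \mathcal{X}$ with marginals $P_d$ and $P_g$.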
Estimating the Wasserstein distance is, however, not straightforward. Arjovsky et al. [2] used the Kantorovich-Rubinstein duality for the Wasserstein-1 distance, which states that:
$$W_1(P_d, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_d}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)], \tag{3}$$
where the supremum is over all 1-Lipschitz functions $f: \mathcal{X} \to \mathbb{R}$. The function $f$ is commonly represented via a deep net, and various ways have been suggested to enforce the Lipschitz constraint, e.g., [12].
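The duality in Eq. (3) can be illustrated in 1-d. There, W1 between two equal-size empirical distributions is exactly the mean absolute difference of the sorted samples. The following is a minimal NumPy sketch (names and distributions are illustrative). It checks weak duality: every 1-Lipschitz critic f yields a lower bound on W1, and a learned critic tries to push this bound up to the supremum.

```python
import numpy as np

def w1_1d(x, y):
    """Exact W1 between two equal-size 1-d empirical distributions (sorting)."""
    return np.mean(np.abs(np.sort(x) - np.sort(y)))

def dual_value(f, x, y):
    """Objective of the dual in Eq. (3) for a candidate critic f."""
    return np.mean(f(x)) - np.mean(f(y))

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)               # stands in for samples of P_d
y = 0.5 * rng.standard_normal(1000) + 1.0   # stands in for samples of P_g

w1 = w1_1d(x, y)
critics = {                                 # all of these are 1-Lipschitz
    "identity": lambda t: t,
    "negated": lambda t: -t,
    "abs": np.abs,
    "clipped": lambda t: np.clip(t, -1.0, 1.0),
}
for name, f in critics.items():
    # Weak duality: no 1-Lipschitz critic can exceed W1.
    assert dual_value(f, x, y) <= w1 + 1e-12
    print(f"{name:8s} dual value {dual_value(f, x, y):+.3f} <= W1 = {w1:.3f}")
```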
While Wasserstein-distance-based approaches have been successful in several complex generative tasks, they suffer from instability arising from incorrect estimation. The cause behind this was noted in [33], where it was shown that estimates of the Wasserstein distance suffer from the ‘curse of dimensionality.’ To tackle the instability and complexity, a sliced version of the Wasserstein-2 distance was employed by [8, 20, 18, 34]. It only requires estimating distances between 1-d distributions and is, therefore, more efficient. The “sliced Wasserstein-$p$ distance” [5] between distributions $P_d$ and $P_g$ is defined as
$$SW_p(P_d, P_g) = \left( \int_{\omega \in \Omega} W_p^p(P_d^\omega, P_g^\omega)\, d\omega \right)^{1/p}, \tag{4}$$
where $P_d^\omega$, $P_g^\omega$ denote the projection (i.e., marginal) of $P_d$, $P_g$ onto the direction $\omega$, and $\Omega$ is the set of all possible directions on the unit sphere. Kolouri et al. [19] have shown that the sliced Wasserstein distance satisfies the properties of non-negativity, identity of indiscernibles, symmetry, and subadditivity. Hence, it is a true metric.
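In practice, Deshpande et al. [8] approximate the sliced Wasserstein-2 distance between the distributions using samples $\mathcal{D} \sim P_d$, $\mathcal{F} \sim P_g$ and a finite number of random Gaussian directions. They replace the integration over $\Omega$ with a summation over a randomly chosen set of unit vectors $\hat{\Omega} \,\tilde{\sim}\, \mathcal{N}(0, I)$, where ‘$\tilde{\sim}$’ indicates sampling followed by normalization to unit length. With $\mathcal{F}$ (and hence $P_g$) implicitly parametrized by $\theta_g$, [8] uses the following program for generative modeling:
$$\min_{\theta_g} \frac{1}{|\hat{\Omega}|} \sum_{\omega \in \hat{\Omega}} W_2^2(\mathcal{D}^\omega, \mathcal{F}^\omega). \tag{5}$$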
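The Wasserstein-2 distance between the projected samples $\mathcal{D}^\omega$ and $\mathcal{F}^\omega$ can be computed by finding the optimal transport map. For 1-d distributions, this can be done through sorting [32], i.e.,
$$W_2^2(\mathcal{D}^\omega, \mathcal{F}^\omega) = \frac{1}{|\mathcal{D}|} \sum_{i=1}^{|\mathcal{D}|} \left( \omega^\top x_{\pi_{\mathcal{D}}(i)} - \omega^\top y_{\pi_{\mathcal{F}}(i)} \right)^2, \tag{6}$$
where $x_i \in \mathcal{D}$, $y_i \in \mathcal{F}$, $|\mathcal{D}| = |\mathcal{F}|$, and $\pi_{\mathcal{D}}$ and $\pi_{\mathcal{F}}$ are permutations that sort the projected sample sets $\mathcal{D}^\omega$ and $\mathcal{F}^\omega$ respectively. That is, $\omega^\top x_{\pi_{\mathcal{D}}(1)} \le \omega^\top x_{\pi_{\mathcal{D}}(2)} \le \dots \le \omega^\top x_{\pi_{\mathcal{D}}(|\mathcal{D}|)}$, and likewise for $\mathcal{F}$.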
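The estimator of Eq. (5) combined with the sorting-based 1-d distance of Eq. (6) fits in a few lines. The following is a minimal NumPy sketch, not the implementation of [8]. Function names are illustrative, and real and generated sample sets are assumed to have equal size.

```python
import numpy as np

def w2_sq_sorted(a, b):
    """Squared W2 between two equal-size 1-d samples, Eq. (6): sort both and
    match the i-th smallest value of a with the i-th smallest value of b."""
    return np.mean((np.sort(a) - np.sort(b)) ** 2)

def sliced_w2_sq(D, F, num_dirs, rng):
    """Monte Carlo estimate of the objective in Eq. (5): average the 1-d squared
    distance over random Gaussian directions normalized to unit length."""
    dirs = rng.standard_normal((num_dirs, D.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
    pD, pF = D @ dirs.T, F @ dirs.T          # projections, shape (n, num_dirs)
    return np.mean([w2_sq_sorted(pD[:, k], pF[:, k]) for k in range(num_dirs)])

rng = np.random.default_rng(0)
d, n = 16, 2000
shift = np.zeros(d)
shift[0] = 3.0
D = rng.standard_normal((n, d))              # "real" samples, N(0, I)
F = rng.standard_normal((n, d)) + shift      # "generated" samples, N(shift, I)
est = sliced_w2_sq(D, F, num_dirs=500, rng=rng)
print(f"sliced W2^2 estimate: {est:.3f}")
```

For these identity-covariance Gaussians, the squared projected distance along a unit direction $\omega$ is $(3\omega_1)^2$. Since $\mathbb{E}[\omega_1^2] = 1/d$, the estimate should be close to $9/16$, which is far smaller than the full squared distance of 9.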
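The program in Eq. (5), when coupled with a discriminator, was shown to work well on high-dimensional datasets. Instead of working directly with the sets $\mathcal{D}$ and $\mathcal{F}$, [8] proposed to transform them to an adversarially learnt feature space, say $\mathcal{D}^h$ and $\mathcal{F}^h$ respectively. Here the feature transform $h$ is implicitly parameterized by $\theta_d$, e.g., by using a deep net. Writing $\mathcal{D}^{h\omega}$ and $\mathcal{F}^{h\omega}$ for the projections of the features onto $\omega$, the generator, parametrized by $\theta_g$, solves
$$\min_{\theta_g} \frac{1}{|\hat{\Omega}|} \sum_{\omega \in \hat{\Omega}} W_2^2(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega}). \tag{7}$$
The adversarial feature space is learnt via a discriminator that classifies real and fake data. This discriminator can be written as $\delta = l \circ h$, where $l$ is a logistic layer. Its parameters $\theta_d$ (those of $h$ and $l$) are learnt using
$$\max_{\theta_d} \sum_{x \in \mathcal{D}} \log \delta(x) + \sum_{x \in \mathcal{F}} \log\left(1 - \delta(x)\right). \tag{8}$$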
3 Analysis and Max-Sliced Distance
In this section we provide the first analysis of the sample-complexity benefits of the sliced Wasserstein distance compared to the Wasserstein distance. We discuss how ‘projection complexity’ is a shortcoming of the sliced Wasserstein distance. As a fix, we present the max-sliced Wasserstein distance, which, as we will show, enjoys the same beneficial sample complexity as the sliced Wasserstein distance, albeit necessitating estimation of a maximum. We will then show how those results are used for training GANs.
3.1 Sample complexity of the Wasserstein and sliced Wasserstein distances
We first show the benefits of using the sliced Wasserstein distance over the Wasserstein distance. Specifically, we show that, in certain cases, estimation of the sliced Wasserstein distance has polynomial complexity, while the Wasserstein distance does not. To make this notion concrete, we introduce ‘generalizability’ of a distance:
Definition 1
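Consider a family $\mathcal{P}$ of distributions over $\mathbb{R}^d$. A distance $D$ is said to be $\mathcal{P}$-generalizable if there exists a polynomial $g$ such that the following holds for any two distributions $P_d, P_g \in \mathcal{P}$ and their empirical ensembles $\hat{P}_d, \hat{P}_g$ with size given by $g$: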
[TABLE]
With this definition, we can prove the following result:
Claim 1
Consider the family of Gaussian distributions
[TABLE]
The sliced Wasserstein-2 distance defined in Eq. (4) is generalizable with respect to this family, whereas the Wasserstein-2 distance defined in Eq. (2) is not.
Proof. See the supplementary material.
Claim 1 implies that for GAN training, under certain conditions, it is better to use the sliced Wasserstein distance as we can get a more accurate training signal with a fixed computational budget. This will result in a more stable discriminator.
Even though the sliced Wasserstein distance enjoys better sample complexity, it has limitations when a finite number of random projection directions is used. We refer to this property as ‘projection complexity’ and illustrate it in the following section. We then present our proposed method to help alleviate this problem.
3.2 Projection complexity of the sliced Wasserstein distance
We begin with a simple example that demonstrates the limitations of using $SW_2$, defined in Eq. (4), for learning distributions through gradient descent. To analyze the ‘projection complexity’ of $SW_2$, we use infinitely many samples but only finitely many directions $\omega \in \hat{\Omega}$.
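Concretely, consider two $d$-dimensional Gaussians with identity covariance. Let $P_d = \mathcal{N}(\mu_d, I)$ be the data distribution and let $P_g = \mathcal{N}(\mu_g, I)$ be the induced generator distribution, parametrized only by its mean $\mu_g$, while $\mu_d$ is a fixed unit vector. Using gradient descent on the estimated sliced Wasserstein distance $\widehat{SW}_2^2(P_d, P_g)$ between $P_d$ and $P_g$, we aim to learn $\mu_g$ so that $\mu_g = \mu_d$. Thus, the updates for $\mu_g$ are
$$\mu_g^{(t+1)} = \mu_g^{(t)} - \alpha \nabla_{\mu_g} \widehat{SW}_2^2(P_d, P_g), \tag{9}$$
where $\alpha$ is the learning rate.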
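The sliced Wasserstein distance is calculated by projecting the distributions themselves (since we use infinitely many samples) onto random directions $\omega \in \hat{\Omega}$ and comparing the projections, i.e., the marginals. Therefore, the estimated distance is
$$\widehat{SW}_2^2(P_d, P_g) = \frac{1}{|\hat{\Omega}|} \sum_{\omega \in \hat{\Omega}} W_2^2(P_d^\omega, P_g^\omega), \tag{10}$$
where $W_2(P_d^\omega, P_g^\omega)$ is the Wasserstein distance between the marginal distributions $P_d^\omega$, $P_g^\omega$. Note that each $\omega \in \hat{\Omega}$ is normalized to unit norm.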
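Intuitively, projecting the Gaussians $P_d$, $P_g$ onto any direction other than $(\mu_g - \mu_d)/\|\mu_g - \mu_d\|$ makes them appear closer than they actually are, which slows down learning. For any given unit vector $\omega$, the marginals are $P_d^\omega = \mathcal{N}(\omega^\top \mu_d, 1)$ and $P_g^\omega = \mathcal{N}(\omega^\top \mu_g, 1)$. Since these marginals differ only by a shift, $W_2^2(P_d^\omega, P_g^\omega) = \left(\omega^\top(\mu_g - \mu_d)\right)^2$. Therefore, the update equation for $\mu_g$ is
$$\mu_g^{(t+1)} = \mu_g^{(t)} - \frac{2\alpha}{|\hat{\Omega}|} \sum_{\omega \in \hat{\Omega}} \omega \omega^\top \left(\mu_g^{(t)} - \mu_d\right). \tag{11}$$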
The updates to $\mu_g$ are particularly small for high-dimensional distributions, since a random unit-norm direction is nearly orthogonal to $\mu_g - \mu_d$ with high probability. Indeed, a uniformly random unit vector satisfies $\mathbb{E}[\omega\omega^\top] = I/d$. Therefore, $\mu_g$ approaches $\mu_d$ very slowly. We verify this effect empirically in Fig. 1 by experimenting with different numbers of random projections, and find that using the sliced Wasserstein distance results in very slow convergence. This problem is further aggravated as the dimension of the distributions increases.
It is intuitively clear that the aforementioned problem can easily be solved by choosing $\omega^* = (\mu_g - \mu_d)/\|\mu_g - \mu_d\|$ as the projection direction. This results in larger updates and, consequently, faster convergence. We also verify this intuition empirically. We repeat the same experiment of learning $\mu_g$, but this time use only the single projection direction $\omega^*$, and report the result in Fig. 1. By simply using the important projection direction, we achieve fast convergence of the mean.
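The contrast in Fig. 1 can be reproduced in spirit with a population-level sketch. Because the sample count is infinite, W2 along each direction has the closed form $|\omega^\top(\mu_g - \mu_d)|$. The following assumptions are illustrative and are not the paper's experimental settings:

- NumPy is available.
- $\mu_d = e_1$ and $\mu_g$ starts at zero.
- Fresh random directions are drawn at every step.
- Dimension, step count and learning rate are as shown in the code.

```python
import numpy as np

def run(d, num_dirs, steps=200, lr=0.1, use_best=False, seed=0):
    """Gradient descent on the population sliced W2^2 of Eq. (10) for
    P_d = N(mu_d, I), P_g = N(mu_g, I); returns the final ||mu_g - mu_d||."""
    rng = np.random.default_rng(seed)
    mu_d = np.zeros(d)
    mu_d[0] = 1.0                      # fixed unit-norm data mean
    mu_g = np.zeros(d)                 # generator mean, the only learned parameter
    for _ in range(steps):
        diff = mu_g - mu_d
        gap = np.linalg.norm(diff)
        if gap < 1e-12:                # converged; also avoids dividing by zero below
            break
        if use_best:
            dirs = (diff / gap)[None, :]                   # the single direction omega*
        else:
            dirs = rng.standard_normal((num_dirs, d))     # fresh random directions
            dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
        # Eq. (11): the gradient of mean_w (w^T diff)^2 is (2 / |dirs|) * sum_w w w^T diff.
        grad = 2.0 * dirs.T @ (dirs @ diff) / len(dirs)
        mu_g = mu_g - lr * grad
    return np.linalg.norm(mu_g - mu_d)

for k in (10, 100, 1000):
    print(f"{k:5d} random directions: final gap {run(d=100, num_dirs=k):.3f}")
print(f"    omega* only:          final gap {run(d=100, num_dirs=1, use_best=True):.2e}")
```

Adding random directions reduces the variance of each step but not its expected size, because $\mathbb{E}[\omega\omega^\top] = I/d$. The gap therefore stays large for every number of directions, while the single direction $\omega^*$ shrinks it by a factor of $1 - 2\alpha$ per step.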
Considering this example, it is evident that some projection directions are more meaningful than others. Therefore, GAN training should benefit from including such directions when comparing distributions. This observation motivates the max-sliced Wasserstein distance which we discuss next.
3.3 Max-sliced Wasserstein distance
In this section we introduce the max-sliced Wasserstein distance and illustrate that it fixes the ‘projection complexity’ concern. We also prove that the max-sliced Wasserstein distance enjoys the same sample-complexity as the sliced Wasserstein distance, i.e., we are not trading one benefit for another.
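As noted in Sec. 3.2, it is useful to include the most meaningful projection direction. Formally, for the aforementioned example of $P_d$ and $P_g$, we want to use the direction $\omega^*$ that satisfies
$$\omega^* = \arg\max_{\omega \in \Omega} W_2^2(P_d^\omega, P_g^\omega). \tag{12}$$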
Comparing distributions along such a direction can, in fact, be shown to be a proper distance. We call it the ‘max-sliced Wasserstein distance’ and define it as follows:
Definition 2
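Let $\Omega$ be the set of all directions on the unit sphere. Then, the max-sliced Wasserstein-2 distance between distributions $P_d$ and $P_g$ is defined as:
$$\text{max-}SW_2(P_d, P_g) = \max_{\omega \in \Omega} W_2(P_d^\omega, P_g^\omega). \tag{13}$$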
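For general samples the maximum in Eq. (13) has no closed form. The following minimal NumPy sketch (names illustrative) compares two candidate estimates of it:

- the best of many random directions, as a sliced estimator would draw them;
- the normalized difference of sample means.

For Gaussians with equal covariance, the second is the exact population maximizer: every marginal pair differs only by the shift $\omega^\top(\mu_g - \mu_d)$.

```python
import numpy as np

def w2_1d(a, b):
    """Empirical W2 between equal-size 1-d samples (sorting)."""
    return np.sqrt(np.mean((np.sort(a) - np.sort(b)) ** 2))

def best_direction(D, F, candidates):
    """Largest projected W2 over a finite set of unit directions.

    This only lower-bounds max-SW2, whose maximum is over the whole sphere."""
    vals = np.array([w2_1d(D @ w, F @ w) for w in candidates])
    k = int(np.argmax(vals))
    return vals[k], candidates[k]

rng = np.random.default_rng(0)
d, n = 64, 4000
delta = np.zeros(d)
delta[0] = 2.0                                   # ||mu_g - mu_d|| = 2
D = rng.standard_normal((n, d))                  # samples from P_d = N(0, I)
F = rng.standard_normal((n, d)) + delta          # samples from P_g = N(delta, I)

# 1000 random unit directions, as a sliced estimator would use.
rand_dirs = rng.standard_normal((1000, d))
rand_dirs /= np.linalg.norm(rand_dirs, axis=1, keepdims=True)
best_random, _ = best_direction(D, F, rand_dirs)

# Normalized difference of sample means (population maximizer for equal covariances).
m = F.mean(0) - D.mean(0)
mean_dir_value, _ = best_direction(D, F, (m / np.linalg.norm(m))[None, :])

print(f"best of 1000 random directions: {best_random:.3f}")
print(f"mean-difference direction:      {mean_dir_value:.3f}  (population max-SW2 = 2)")
```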
As stated in the following claim, it can easily be shown that max-$SW_2$ is a valid distance.
Claim 2
The max-sliced Wasserstein-2 distance defined in Eq. (13) is a well defined distance between distributions.
Proof. See supplementary material.
We can also show that the max-sliced Wasserstein distance has polynomial sample complexity:
Claim 3
Consider the family of Gaussian distributions
[TABLE]
The max-sliced Wasserstein-2 (max-$SW_2$) distance is generalizable with respect to this family.
Proof. See the supplementary material.
Since it is a valid metric, we can directly use the max-sliced Wasserstein distance for learning distributions.
By definition, the max-sliced Wasserstein distance overcomes the limitation discussed in Sec. 3.2. However, it requires estimating a maximum, which is harder than estimating the expectation of a random variable. In the following section, we discuss how the max-sliced Wasserstein distance can be estimated and used in a GAN-like setting.
3.4 Max-sliced GAN
In this section, we discuss our approach that uses the max-sliced Wasserstein distance to train a GAN. We also discuss how we approximate the max-sliced Wasserstein distance in practice. Since we use max-$SW_2$, we achieve significant savings in the number of projection directions needed compared to [8].
Intuitively, we want to project data into a space where real samples can easily be differentiated from artificially generated points. To this end, we work with an adversarially learnt feature space, i.e., we use the penultimate layer of a discriminator network. In this feature space, we minimize the max-sliced Wasserstein distance max-$SW_2$. As will be discussed later in this section, finding the actual maximum is hard, and we therefore resort to approximating it.
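Let again $P_d$ denote the data distribution and let $P_g$ refer to the induced generator distribution. Further, let the discriminator be represented as $\omega^\top h$, where $\omega$ denotes the weights of a fully connected layer and $h$ represents the feature space we are interested in. Finally, let $\mathcal{D}^h$ and $\mathcal{F}^h$ represent the two empirical distributions in this feature space, and let $\mathcal{D}^{h\omega}$, $\mathcal{F}^{h\omega}$ denote their projections onto $\omega$. Then, we would like to solve
$$\min_{\theta_g} \max_{\omega \in \Omega} W_2^2(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega}), \tag{14}$$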
where $\Omega$ is the set of all normalized directions. In general, there is no easy way to solve
$$\max_{\omega \in \Omega} W_2^2(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega}) \tag{15}$$
even if the parameters of the feature transform $h$ are fixed. This is because computing the Wasserstein distance in the 1-dimensional case requires sorting, i.e., solving a minimization problem over matchings. Hence the program given in Eq. (15) is a saddle-point objective. Each part can be solved exactly when the variables of the other are fixed: for a fixed matching, the objective is a quadratic form in $\omega$ that is maximized by a top eigenvector, and for a fixed $\omega$, the matching is obtained by sorting.
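The inner problem of Eq. (15) is a minimization over matchings, and sorting solves it exactly in 1-d. The following small NumPy check (names illustrative) compares the sorted matching against brute force over all $n!$ permutations for a tiny sample.

```python
import itertools
import numpy as np

def w2_sq_sorted(a, b):
    """Squared 1-d W2 via sorting: the monotone matching is optimal."""
    return np.mean((np.sort(a) - np.sort(b)) ** 2)

def w2_sq_bruteforce(a, b):
    """Minimize the matching cost over all n! permutations (only feasible for tiny n)."""
    return min(np.mean((a - b[list(p)]) ** 2)
               for p in itertools.permutations(range(len(b))))

rng = np.random.default_rng(1)
a, b = rng.standard_normal(6), rng.standard_normal(6)
print(w2_sq_sorted(a, b), w2_sq_bruteforce(a, b))
```

The two values agree. Maximizing over $\omega$ therefore means maximizing a function that is itself a minimum over matchings, which is why Eq. (15) is a saddle-point problem.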
Suppose we want to jointly find the parameters $\theta_d$ of the feature transform $h$ and the projection direction $\omega$, i.e., to solve
$$\max_{\omega \in \Omega,\, \theta_d} W_2^2(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega}) \tag{16}$$
with gradient-based methods. In that case, we also need to pay attention to the boundedness of the objective. Using regularization often proves tricky and may require separate tuning for each use case.
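A one-line derivation shows why the joint objective of Eq. (16) can be unbounded. Write $\mathcal{D}^{h\omega} = \{\omega^\top h(x) : x \in \mathcal{D}\}$ and $\mathcal{F}^{h\omega} = \{\omega^\top h(y) : y \in \mathcal{F}\}$. Suppose, as an illustrative assumption about the architecture, that the feature transform can rescale its output by some $c > 0$, e.g., through the weights of its last layer. Multiplying by a positive constant preserves the sort order, so the optimal 1-d matching $(\pi, \sigma)$ is unchanged and every matched difference is scaled by $c$:

```latex
W_2^2\big(\mathcal{D}^{(c h)\omega}, \mathcal{F}^{(c h)\omega}\big)
  = \frac{1}{|\mathcal{D}|} \sum_{i} \big( c\,\omega^\top h(x_{\pi(i)}) - c\,\omega^\top h(y_{\sigma(i)}) \big)^2
  = c^2 \, W_2^2\big(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega}\big).
```

Whenever the two projected sets differ, letting $c \to \infty$ drives the objective to infinity. The maximization over $\theta_d$ therefore has no finite solution unless $h$ is constrained or regularized.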
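To circumvent these difficulties when jointly searching for $\omega$ and $h$, we use a surrogate function $\psi$ and write the objective for the discriminator as follows:
$$\omega^*, \theta_d^* = \arg\max_{\omega,\, \theta_d} \psi(\omega, \mathcal{D}^h, \mathcal{F}^h). \tag{17}$$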
Intuitively, we want the surrogate function $\psi$ to transform the data via $h$ into a space where $\mathcal{D}$ and $\mathcal{F}$ are easy to differentiate. Moreover, we want $\omega$ to be the direction that best separates the transformed real and generated data. A variety of surrogate functions come to mind immediately, such as the log-loss as specified in Eq. (8), the hinge loss, or a moment separator with
[TABLE]
come to mind immediately.
For instance, in the case of a log-loss, the discriminator learns to classify real and fake samples, essentially performing linear logistic regression with weights $\omega$ on a learned feature representation $h$. If trained to optimality, the two distributions are well separated in the discriminator’s feature space $h$. An example is given in Fig. 2. The discriminator takes two distributions, shown in Fig. 2(a), and is trained to classify them. In doing so, the discriminator transforms them to the feature space shown in Fig. 2(b). In this simple example, we can plot the Wasserstein distance along the different projection directions, as visualized in Fig. 2(c). The discriminator’s final layer $\omega$ can be considered as a projection direction, and this direction is very close to the maximizer of the projected Wasserstein distance in the feature space.
Additionally, in this case the maximizing direction can be approximated with the normalized final-layer weights $\omega$. This is because the discriminator, trained for classification, essentially separates the distributions along $\omega$. If we compute the Wasserstein-2 distance for projections onto different angles (as in Fig. 2(c)), we see that the maximum distance is achieved close to the projection direction from the discriminator. We next assess: ‘how close?’
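A toy 2-d version of the Fig. 2 experiment can be run directly. This is a minimal NumPy sketch under several assumptions:

- The feature transform $h$ is the identity.
- The discriminator is plain logistic regression fitted by gradient ascent.
- The data are two unit-variance Gaussian blobs.

It measures the angle between the discriminator's weight direction and the direction that maximizes the empirical projected W2 over a grid of angles.

```python
import numpy as np

def w2_1d(a, b):
    """Empirical W2 between equal-size 1-d samples (sorting)."""
    return np.sqrt(np.mean((np.sort(a) - np.sort(b)) ** 2))

def unit(t):
    """Unit vector in the plane at angle t."""
    return np.array([np.cos(t), np.sin(t)])

rng = np.random.default_rng(0)
n = 5000
real = rng.standard_normal((n, 2)) + np.array([2.0, 1.0])
fake = rng.standard_normal((n, 2))

# Logistic regression on identity features h(x) = x, fitted by gradient ascent on
# the log-likelihood of the labels real = 1, fake = 0 (the log-loss surrogate).
X = np.vstack([real, fake])
y = np.concatenate([np.ones(n), np.zeros(n)])
w, b = np.zeros(2), 0.0
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
    w += 0.1 * X.T @ (y - p) / len(y)
    b += 0.1 * np.mean(y - p)
w_disc = w / np.linalg.norm(w)

# Projected W2 over a grid of angles, in the spirit of Fig. 2(c); angles in [0, pi)
# suffice because w and -w give the same projected distance.
angles = np.linspace(0.0, np.pi, 720, endpoint=False)
dists = np.array([w2_1d(real @ unit(t), fake @ unit(t)) for t in angles])
w_best = unit(angles[np.argmax(dists)])

cos_sim = abs(w_disc @ w_best)
print(f"|cos| between discriminator weights and best projection: {cos_sim:.3f}")
```

With equal isotropic covariances, both directions align with the mean difference, so the cosine is close to 1. With less symmetric data, the two directions can diverge.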
While the log-loss and the other surrogate functions seem intuitive, we bound the maximal sub-optimality, in terms of the max-sliced Wasserstein distance, for the special case of the moment separator given in Eq. (18) and an identity transform $h$:
Claim 4
For the surrogate function $\psi$ given in Eq. (18), $h$ the identity, and $\omega$ computed as specified in Eq. (17), we obtain
[TABLE]
for a lower bound , where is the difference of dataset means.
Proof. See the supplementary material.
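The following small example illustrates why such a sub-optimality bound is needed. It assumes that the moment separator compares first moments along $\omega$, i.e., $\psi = \omega^\top(\text{mean of } \mathcal{D}^h - \text{mean of } \mathcal{F}^h)$; this exact form is an assumption here. With $h$ the identity, $\psi$ is maximized over unit vectors by the normalized mean difference (Cauchy-Schwarz). When the covariances also differ, that direction need not maximize the projected W2.

The sketch uses the closed form for 1-d Gaussians, $W_2^2 = (m_1 - m_2)^2 + (s_1 - s_2)^2$, and assumes NumPy.

```python
import numpy as np

def projected_w2_gauss(w, mu1, cov1, mu2, cov2):
    """Closed-form W2 between the 1-d marginals of two Gaussians along unit vector w.

    The marginals are N(w^T mu_i, w^T cov_i w), and for 1-d Gaussians
    W2^2 = (m1 - m2)^2 + (s1 - s2)^2."""
    s1, s2 = np.sqrt(w @ cov1 @ w), np.sqrt(w @ cov2 @ w)
    return np.sqrt((w @ (mu1 - mu2)) ** 2 + (s1 - s2) ** 2)

mu_d, cov_d = np.array([1.0, 0.0]), np.diag([1.0, 1.0])
mu_g, cov_g = np.array([0.0, 0.0]), np.diag([1.0, 9.0])

# Maximizer of the assumed first-moment separator w^T (mu_d - mu_g) over unit w.
w_mom = (mu_d - mu_g) / np.linalg.norm(mu_d - mu_g)
w2_mom = projected_w2_gauss(w_mom, mu_d, cov_d, mu_g, cov_g)

# Brute-force max-SW2 over a fine grid of directions (feasible in 2-d only).
angles = np.linspace(0.0, np.pi, 3601)
vals = np.array([projected_w2_gauss(np.array([np.cos(t), np.sin(t)]), mu_d, cov_d, mu_g, cov_g)
                 for t in angles])

print(f"W2 along moment-separator direction: {w2_mom:.3f}")   # 1.000
print(f"max-SW2 (grid search):               {vals.max():.3f}")  # 2.000
```

Here the separator's direction recovers only half of max-$SW_2$, because the largest discrepancy lies in the variances rather than in the means.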
To summarize, training the discriminator for classification provides a rich feature space which can be utilized for faster training. We note that the discriminator might be trained to obtain such features in a more explicit manner, but we leave this to future research.
3.5 Max-sliced GAN algorithm
We summarize the resulting training process in Alg. 1. In every iteration, we draw a set of samples $\mathcal{D}$ and $\mathcal{F}$ from the true and fake distributions. We then optimize the parameters $\omega$ and $\theta_d$ of the discriminator for a fixed number of iterations (a hyper-parameter) to optimize a surrogate objective $\psi$. Next, we compute the Wasserstein-2 distance between the projected outputs of the discriminator’s feature transform, i.e., $W_2^2(\mathcal{D}^{h\omega}, \mathcal{F}^{h\omega})$. The generator is trained to minimize this distance. In our experiments, we choose $\psi$ to be the binary classification loss.
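The loop can be sketched on a toy problem. This is not the authors' code; it is a NumPy sketch under strong simplifying assumptions:

- The generator is a learnable mean shift $\mu_g$ applied to standard Gaussian noise.
- The feature transform $h$ is the identity.
- The discriminator is a logistic layer, warm-started across iterations and trained for `k` steps per generator step.
- All hyper-parameters are illustrative.

Shifting $\mu_g$ moves every projected fake sample by the same amount, so the sort order (the optimal matching) is unchanged. The gradient of the squared W2 is therefore $2\,\mathrm{mean}(f - a)\,\omega$. The hypothetical `random_direction` flag replaces the discriminator direction with one random direction as a crude sliced baseline.

```python
import numpy as np

def train_toy_max_sliced(d=128, steps=100, batch=1024, k=5, lr_d=0.5, lr_g=0.25,
                         random_direction=False, seed=0):
    """Toy version of the training loop: returns the final ||mu_g - mu_d||."""
    rng = np.random.default_rng(seed)
    mu_d = np.zeros(d)
    mu_d[0] = 3.0                      # data mean, unknown to the generator
    mu_g = np.zeros(d)                 # generator parameters: a learnable mean shift
    w, b = np.zeros(d), 0.0            # logistic layer on identity features h(x) = x
    for _ in range(steps):
        real = rng.standard_normal((batch, d)) + mu_d
        fake = rng.standard_normal((batch, d)) + mu_g
        X = np.vstack([real, fake])
        y = np.concatenate([np.ones(batch), np.zeros(batch)])
        for _ in range(k):             # discriminator: gradient ascent on the log-likelihood
            p = 1.0 / (1.0 + np.exp(-(X @ w + b)))
            w += lr_d * X.T @ (y - p) / len(y)
            b += lr_d * np.mean(y - p)
        if random_direction:           # crude sliced baseline: one random direction
            omega = rng.standard_normal(d)
        else:                          # max-sliced: use the discriminator's last layer
            omega = w.copy()
        omega /= np.linalg.norm(omega) + 1e-12
        a = np.sort(real @ omega)      # sorted projected real samples
        f = np.sort(fake @ omega)      # sorted projected fake samples
        # Generator step on W2^2 = mean((f - a)^2): the gradient w.r.t. mu_g
        # is 2 * mean(f - a) * omega because the matching does not change.
        mu_g -= lr_g * 2.0 * np.mean(f - a) * omega
    return np.linalg.norm(mu_g - mu_d)

print("max-sliced direction:   ", round(train_toy_max_sliced(), 3))
print("single random direction:", round(train_toy_max_sliced(random_direction=True), 3))
```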
4 Experiments
In this section, we present results to demonstrate the effectiveness of the max-sliced Wasserstein distance and the computational benefits it offers over the sliced Wasserstein distance. We show quantitative results on unpaired word translation [6], and qualitative and quantitative results on image generation tasks using the CelebA-HQ [16] and the LSUN Bedrooms [36] datasets.
4.1 Word Translation without Parallel Data
We evaluate the effectiveness of the max-sliced GAN on unsupervised word translation tasks, i.e., without paired/parallel data [6]. This allows us to quantitatively compare different methods.
The setting of this experiment is as follows. We are given embeddings of words from two languages, say $X$ (source) and $Y$ (target). We want to learn an orthogonal transformation $W$ that maps the source embeddings $X$ to $Y$, i.e.:
[TABLE]
The current state-of-the-art [6] employs a GAN-like [11] adversary to learn the transformation. That is, the transformation is learned by minimizing the Jensen-Shannon divergence between $WX$ and $Y$. We instead minimize the max-sliced Wasserstein distance to learn $W$.
We follow the training method and evaluation in [6] and report word translation precision by computing the retrieval precision@k on the MUSE bilingual dictionaries [6]. During testing, 1,500 queries are tested and 200k words of the target language are taken into account. We compare our method with [6] and present results for five language pairs, in both directions, in Tab. 1. In Tab. 1, ‘NN’ denotes the use of nearest neighbors to build the dictionary after training the transformation $W$, and ‘CSLS’ denotes the use of cross-domain similarity local scaling [6]. Our method with CSLS outperforms the baseline in all tested language pairs. This demonstrates the competitiveness of our method with established GAN frameworks.
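As defined in [6], CSLS rescores cosine similarity to penalize ‘hub’ words:

$\mathrm{CSLS}(Wx, y) = 2\cos(Wx, y) - r_T(Wx) - r_S(y)$

Here $r_T(Wx)$ is the mean cosine similarity between $Wx$ and its $K$ nearest target embeddings, and $r_S(y)$ is the mean cosine similarity between $y$ and its $K$ nearest mapped source embeddings. The following minimal NumPy sketch contrasts NN and CSLS retrieval. It builds a dense similarity matrix, so it only suits small vocabularies. The helper names are hypothetical, and the default $K = 10$ is an assumption.

```python
import numpy as np

def normalize(v):
    """Scale every row to unit length."""
    return v / np.linalg.norm(v, axis=1, keepdims=True)

def csls(src_mapped, tgt, k=10):
    """CSLS scores between mapped source embeddings (rows) and target embeddings (cols)."""
    cos = normalize(src_mapped) @ normalize(tgt).T     # shape (n_src, n_tgt)
    r_t = np.sort(cos, axis=1)[:, -k:].mean(axis=1)    # source word -> its k nearest targets
    r_s = np.sort(cos, axis=0)[-k:, :].mean(axis=0)    # target word -> its k nearest sources
    return 2.0 * cos - r_t[:, None] - r_s[None, :]

def translate(src_mapped, tgt, use_csls=True, k=10):
    """Index of the retrieved target word for every mapped source word."""
    if use_csls:
        scores = csls(src_mapped, tgt, k)
    else:
        scores = normalize(src_mapped) @ normalize(tgt).T   # plain nearest neighbor (cosine)
    return scores.argmax(axis=1)

# Toy check: source word i should translate to target word perm[i].
rng = np.random.default_rng(0)
tgt = rng.standard_normal((50, 32))
perm = rng.permutation(50)
src_mapped = tgt[perm] + 0.01 * rng.standard_normal((50, 32))
pred = translate(src_mapped, tgt)
print("precision@1:", np.mean(pred == perm))
```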
4.2 Image Generation
In this section, we present results on the task of image generation. Using the max-sliced Wasserstein distance, we train a GAN on the CelebA-HQ [16] and LSUN Bedrooms [36] datasets for images of resolution 256x256. We compare with the sliced Wasserstein GAN [8].
Samples generated by each trained model are presented in Fig. 3 and Fig. 4. The results of the max-sliced Wasserstein GAN are shown in Fig. 3(a) and Fig. 4(a). We train the sliced Wasserstein GAN with 100, 1000, and 10000 random projections. The corresponding results are shown in Fig. 3(b), Fig. 3(c), and Fig. 3(d) for CelebA-HQ, and in Fig. 4(b), Fig. 4(c), and Fig. 4(d) for LSUN. Using just one projection direction, the max-sliced Wasserstein GAN produces results that are comparable to or better than those of the sliced Wasserstein GAN, even when the latter uses 10000 projections. This significantly reduces the computational complexity as well as the memory footprint of the model.
We used a simple extension of the popular DCGAN architecture for the generator and discriminator. Two extra strided (transposed) convolutional layers are added to the generator and the discriminator to scale to 256x256. We do not use any special normalization or initialization to train the models. Specific details are given in the supplementary material.
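The resolution arithmetic behind the extra layers can be checked directly. This sketch assumes a common DCGAN-style configuration, which the text does not specify:

- a 4x4 starting feature map;
- transposed convolutions with kernel 4, stride 2 and padding 1.

Under these assumptions each layer doubles the spatial size, since $(s - 1) \cdot 2 - 2 + 4 = 2s$. Four such layers reach 64x64, and two extra layers reach 256x256.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output spatial size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 4                  # assumed DCGAN-style 4x4 starting feature map
sizes = [size]
for _ in range(6):        # four DCGAN upsampling layers plus the two extra ones
    size = conv_transpose_out(size)
    sizes.append(size)
print(sizes)              # [4, 8, 16, 32, 64, 128, 256]
```

The discriminator mirrors this with strided convolutions, halving the resolution at each of its layers.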
5 Conclusion
In this paper, we analyzed the Wasserstein and sliced Wasserstein distances and developed a simple yet effective training strategy for generative adversarial nets based on the max-sliced Wasserstein distance. We showed that this distance enjoys a better sample complexity than the Wasserstein distance, and a better projection complexity than the sliced Wasserstein distance. We developed a method to approximate it using a surrogate loss, and also analyzed the approximation error for one such surrogate. Empirically, we showed that the discussed approach is able to learn high-dimensional distributions. The method requires orders of magnitude fewer projection directions than the sliced Wasserstein GAN, even though both work in a similar distance space.
Acknowledgments: This work is supported in part by NSF under Grant No. 1718221, Samsung, and 3M. We thank NVIDIA for providing GPUs used for this work.
- [1] M. Arjovsky and L. Bottou. Towards principled methods for training generative adversarial networks. In ICLR, 2017.
- [2] M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN. In ICML, 2017.
- [3] S. Arora, R. Ge, Y. Liang, T. Ma, and Y. Zhang. Generalization and equilibrium in generative adversarial nets (GANs). In ICML, 2017.
- [4] D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks. arXiv preprint arXiv:1703.10717, 2017.
- [5] N. Bonneel, J. Rabin, G. Peyré, and H. Pfister. Sliced and Radon Wasserstein barycenters of measures. Journal of Mathematical Imaging and Vision, 2015.
- [6] A. Conneau, G. Lample, M. Ranzato, L. Denoyer, and H. Jegou. Word translation without parallel data. In ICLR, 2018.
- [7] R. Wang, A. Cully, H. J. Chang, and Y. Demiris. MAGAN: Margin adaptation for generative adversarial networks. arXiv preprint arXiv:1704.03817, 2017.
- [8] I. Deshpande, Z. Zhang, and A. Schwing. Generative modeling using the sliced Wasserstein distance. In CVPR, 2018.
