Optimal Transport-based Domain Alignment as a Preprocessing Step for Federated Learning
Luiz Manella Pereira, M. Hadi Amini

TL;DR
This paper introduces an optimal transport-based preprocessing method for federated learning that aligns datasets across devices, reducing distributional discrepancies and improving model generalization with fewer communication rounds.
Contribution
It proposes a novel optimal transport approach using Wasserstein barycenters to preprocess data in federated learning, enhancing model performance and convergence.
Findings
Achieves higher generalization with fewer communication rounds.
Reduces distributional discrepancy across datasets.
Improves federated learning performance on CIFAR-10.
Abstract
Federated learning (FL) is a subfield of machine learning that avoids sharing local data with a central server, which can enhance privacy and scalability. The inability to consolidate data leads to a unique problem called dataset imbalance, where agents in a network do not have equal representation of the labels one is trying to learn to predict. In FL, fusing locally-trained models with unbalanced datasets may deteriorate the performance of global model aggregation, and reduce the quality of updated local models and the accuracy of the distributed agents' decisions. In this work, we introduce an Optimal Transport-based preprocessing algorithm that aligns the datasets by minimizing the distributional discrepancy of data along the edge devices. We accomplish this by leveraging Wasserstein barycenters when computing channel-wise averages. These barycenters are collected in a trusted…
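To make the barycenter idea concrete, here is a minimal sketch and not the authors' implementation. It assumes a simplification the abstract does not state: each image channel is reduced to a 1D empirical distribution of pixel intensities, with every image contributing the same number of pixels. Under that assumption, the Wasserstein-2 barycenter has a closed form, namely the weighted average of the sorted samples (quantile averaging). The function names `barycenter_1d` and `channelwise_barycenters` are hypothetical.

```python
import numpy as np


def barycenter_1d(samples, weights=None):
    """Wasserstein-2 barycenter of 1D empirical distributions.

    Assumption (not from the paper): every distribution has the same
    number of equally weighted points, so the barycenter is the
    weighted average of the sorted samples (quantile averaging).
    """
    samples = np.asarray(samples, dtype=float)  # shape (n_dists, n_points)
    n_dists = samples.shape[0]
    if weights is None:
        weights = np.full(n_dists, 1.0 / n_dists)  # uniform weights
    weights = np.asarray(weights, dtype=float)
    sorted_samples = np.sort(samples, axis=1)  # empirical quantiles per distribution
    return weights @ sorted_samples  # shape (n_points,)


def channelwise_barycenters(images):
    """Compute one 1D barycenter per channel.

    images: array of shape (n_images, height, width, channels).
    Returns an array of shape (channels, height * width) holding the
    sorted barycenter intensities of each channel.
    """
    images = np.asarray(images, dtype=float)
    n, h, w, c = images.shape
    flat = images.reshape(n, h * w, c)  # pixel intensities per image and channel
    return np.stack([barycenter_1d(flat[:, :, k]) for k in range(c)])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    local_images = rng.random((8, 32, 32, 3))  # stand-in for one device's RGB data
    print(channelwise_barycenters(local_images).shape)
```

The paper works with a regularized Wasserstein distance on images, so its barycenters are likely computed differently; this sketch only illustrates how averaging in Wasserstein space differs from a plain pixel-wise mean.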
Peer Reviews
Decision: Submitted to ICLR 2025
The proposed preprocessing step is interesting, simple yet effective in boosting FL learning performance.
Some concerns are as follows: 1) The technical exposition in Section 3.2 is poor. For example, I have no idea what Lines 184-185 mean. There is no definition for $\mathcal{L}_{d^P}(\cdot,\cdot)$ in Eq. (4), $\Sigma_n$ in Eq. (5), or $W_{reg}$ in Eq. (7). Why is $\lambda_{s}\in\Sigma_{n}$ in Line 215? 2) I find Lines 339-350 confusing regarding the number of epochs when the number of clients is small. In Lines 339-341, it seems to use a large number of epochs when P is small. But in Line 345 it seem
1. The author introduces a novel method using Optimal Transport to address the data heterogeneity in federated learning. 2. The method is impressive overall and seems easy to apply to different FL methods.
1. Overall presentation is below average: Many of the notations in Section 3.2 are not explained. Also, the authors should focus on the meaning and utility of OT instead of the formula derivation. Section 4 lacks description. The authors should explain the whole process in detail, using a bullet list or paragraphs with subtitles based on Figures 2 and 3. 2. Weak experiments: The comparisons are mostly with non-aligned FL methods, with only one aligned baseline (CCVR). More baselines should be included to demon
- The paper is well motivated and clearly written. - As far as I know, use of optimal transport for domain alignment is novel in federated learning.
- The paper mentions several times that the proposed preprocessing algorithm would preserve privacy. However, I could not find a clear definition of privacy notion the paper is referring to. Could the authors clarify what privacy notion they're targeting and how they compute/measure the privacy guarantees?
Taxonomy
Topics: Privacy-Preserving Technologies in Data · Domain Adaptation and Few-Shot Learning · Adversarial Robustness in Machine Learning
