Gradient Masked Averaging for Federated Learning
Irene Tenison, Sai Aravind Sreeramadas, Vaikkunth Mugunthan, Edouard Oyallon, Irina Rish, Eugene Belilovsky

TL;DR
This paper introduces a gradient masked averaging method for federated learning that improves model generalization across heterogeneous, non-i.i.d. datasets by focusing on invariant mechanisms and reducing information loss during aggregation.
Contribution
It proposes a novel gradient masked averaging technique as a drop-in replacement for standard averaging in federated learning, enhancing performance on diverse client data.
Findings
Consistent performance improvements across multiple FL algorithms.
Enhanced generalization on non-i.i.d. and real-world datasets.
Robustness to data heterogeneity and imbalance.
Abstract
Federated learning (FL) is an emerging paradigm that permits a large number of clients with heterogeneous data to coordinate learning of a unified global model without the need to share data with one another. A major challenge in federated learning is the heterogeneity of data across clients, which can degrade the performance of standard FL algorithms. Standard FL algorithms involve averaging of model parameters or gradient updates to approximate the global model at the server. However, we argue that in heterogeneous settings, averaging can result in information loss and lead to poor generalization due to the bias induced by dominant client gradients. We hypothesize that to generalize better across non-i.i.d. datasets, the algorithms should focus on learning the invariant mechanism that is constant while ignoring spurious mechanisms that differ across clients. Inspired from recent works…
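The abstract excerpt stops before the method is described, so the following is only an illustrative sketch of how a sign-agreement mask could stand in for plain averaging at the server. The function name `masked_average`, the threshold `tau`, and the soft-mask rule (scale low-agreement coordinates by their agreement score) are assumptions made for this example and are not confirmed by the text above.

```python
import numpy as np


def masked_average(client_updates, tau=0.4):
    """Average client updates, damping coordinates where clients disagree on sign.

    client_updates: array-like of shape (num_clients, num_params).
    tau: hypothetical agreement threshold in [0, 1] (an assumption of this sketch).
    """
    updates = np.asarray(client_updates, dtype=float)

    # Per-coordinate sign agreement: 1.0 when every client pushes the same
    # direction, 0.0 when the signs cancel out exactly.
    agreement = np.abs(np.sign(updates).mean(axis=0))

    # Coordinates with enough agreement keep full weight; the rest are
    # scaled down by their agreement score (assumed soft-mask rule).
    mask = np.where(agreement >= tau, 1.0, agreement)

    # Plain averaging would return updates.mean(axis=0) unchanged.
    return mask * updates.mean(axis=0)


if __name__ == "__main__":
    # Three hypothetical clients, two parameters: they agree on the first
    # coordinate and disagree on the second.
    demo = [[2.0, 3.0], [4.0, -1.0], [6.0, -1.0]]
    print("plain average :", np.mean(demo, axis=0))
    print("masked average:", masked_average(demo))
```

Because the function takes the same stacked updates that plain averaging would and returns a vector of the same shape, it could be swapped in at the aggregation step without changing the client side, which is the sense in which the summary calls the technique a drop-in replacement.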
Taxonomy
Topics: Privacy-Preserving Technologies in Data · Statistical Methods and Inference · Advanced Graph Neural Networks
