Pruning-Aware Merging for Efficient Multitask Inference
Xiaoxi He, Dawei Gao, Zimu Zhou, Yongxin Tong, Lothar Thiele

TL;DR
This paper introduces Pruning-Aware Merging (PAM), a method to merge and prune neural networks for multitask inference on resource-limited devices, significantly reducing computation costs across task combinations.
Contribution
The paper proposes a novel heuristic merging scheme, PAM, that considers future pruning to optimize multitask network efficiency, outperforming existing merging methods.
Findings
PAM achieves up to 4.87x less computation than no-merging baseline.
PAM outperforms state-of-the-art merging schemes by up to 2.01x.
The method is effective across different datasets and architectures.
| Pruning | Tasks | Accuracy B1 | Accuracy B2 | Accuracy PAM | FLOPs B1 | FLOPs B2 | FLOPs PAM |
|---|---|---|---|---|---|---|---|
| P1 | A | 95.42% | 95.30% | 94.67% | 28.34 | 52.58 | 28.49 |
| | B | 96.30% | 96.40% | 95.70% | 28.34 | 52.58 | 26.16 |
| | A&B | 95.86% | 95.85% | 95.19% | 56.69 | 52.58 | 48.68 |
| P2 | A | 95.82% | 95.73% | 95.70% | 18.64 | 31.19 | 18.65 |
| | B | 96.46% | 96.72% | 96.38% | 18.64 | 31.19 | 18.65 |
| | A&B | 96.14% | 96.22% | 96.04% | 37.27 | 31.19 | 26.48 |

| Pruning | Tasks | Accuracy B1 | Accuracy B2 | Accuracy PAM | FLOPs B1 | FLOPs B2 | FLOPs PAM |
|---|---|---|---|---|---|---|---|
| P1 | A | 89.45% | 89.09% | 89.60% | 4.52 | 7.3 | 4.48 |
| | B | 87.81% | 87.69% | 88.00% | 4.32 | 7.3 | 4.49 |
| | A&B | 88.63% | 88.39% | 88.80% | 8.85 | 7.3 | 4.70 |
| P2 | A | 90.34% | 90.27% | 90.36% | 153.13 | 243.20 | 155.82 |
| | B | 88.84% | 88.74% | 88.76% | 152.65 | 243.20 | 155.84 |
| | A&B | 89.59% | 89.51% | 89.56% | 305.78 | 243.20 | 156.74 |

| Pruning | Tasks | Accuracy B1 | Accuracy B2 | Accuracy PAM | FLOPs B1 | FLOPs B2 | FLOPs PAM |
|---|---|---|---|---|---|---|---|
| P1 | A | 89.77% | 89.49% | 89.87% | 7.96 | 12.66 | 7.94 |
| | B | 82.81% | 82.82% | 82.14% | 7.91 | 12.66 | 7.95 |
| | C | 83.20% | 82.68% | 83.30% | 7.94 | 12.66 | 7.94 |
| | D | 85.74% | 86.45% | 86.03% | 7.58 | 12.66 | 7.93 |
| | E | 87.10% | 86.52% | 86.90% | 7.87 | 12.66 | 7.93 |
| | A&B | 86.29% | 86.16% | 86.00% | 15.87 | 12.66 | 7.98 |
| | A&C | 86.48% | 86.09% | 86.59% | 15.90 | 12.66 | 7.97 |
| | A&D | 87.75% | 87.97% | 87.95% | 15.54 | 12.66 | 7.97 |
| | A&E | 88.44% | 88.01% | 88.39% | 15.84 | 12.66 | 7.96 |
| | B&C | 83.00% | 82.75% | 82.72% | 15.85 | 12.66 | 7.98 |
| | B&D | 84.28% | 84.64% | 84.09% | 15.49 | 12.66 | 7.97 |
| | B&E | 84.95% | 84.67% | 84.52% | 15.79 | 12.66 | 7.97 |
| | C&D | 84.47% | 84.57% | 84.66% | 15.52 | 12.66 | 7.96 |
| | C&E | 85.15% | 84.60% | 85.10% | 15.81 | 12.66 | 7.96 |
| | D&E | 86.42% | 86.49% | 86.47% | 15.45 | 12.66 | 7.96 |
| | A&B&C | 85.26% | 85.00% | 85.10% | 23.81 | 12.66 | 8.01 |
| | A&B&D | 86.11% | 86.25% | 86.01% | 23.45 | 12.66 | 8.01 |
| | A&B&E | 86.56% | 86.28% | 86.30% | 23.75 | 12.66 | 8.00 |
| | A&C&D | 86.24% | 86.21% | 86.40% | 23.48 | 12.66 | 8.00 |
| | A&C&E | 86.69% | 86.23% | 86.69% | 23.78 | 12.66 | 7.99 |
| | A&D&E | 87.54% | 87.49% | 87.60% | 23.42 | 12.66 | 7.99 |
| | B&C&D | 83.92% | 83.98% | 83.82% | 23.43 | 12.66 | 8.01 |
| | B&C&E | 84.37% | 84.01% | 84.11% | 23.73 | 12.66 | 8.00 |
| | B&D&E | 85.22% | 85.26% | 85.02% | 23.37 | 12.66 | 8.00 |
| | C&D&E | 85.35% | 85.22% | 85.41% | 23.39 | 12.66 | 7.99 |
| | A&B&C&D | 85.38% | 85.36% | 85.34% | 31.39 | 12.66 | 8.04 |
| | A&B&C&E | 85.72% | 85.38% | 85.55% | 31.69 | 12.66 | 8.03 |
| | A&B&D&E | 86.35% | 86.32% | 86.23% | 31.33 | 12.66 | 8.03 |
| | A&C&D&E | 86.45% | 86.29% | 86.53% | 31.36 | 12.66 | 8.02 |
| | B&C&D&E | 84.71% | 84.62% | 84.59% | 31.31 | 12.66 | 8.03 |
| | A&B&C&D&E | 85.72% | 85.59% | 85.65% | 39.27 | 12.66 | 8.06 |
| P2 | A | 89.57% | 89.38% | 89.24% | 22.91 | 36.33 | 23.28 |
| | B | 81.96% | 83.15% | 83.39% | 23.16 | 36.33 | 23.29 |
| | C | 82.96% | 81.61% | 82.10% | 22.93 | 36.33 | 23.28 |
| | D | 85.04% | 85.12% | 85.29% | 21.16 | 36.33 | 23.27 |
| | E | 86.43% | 85.81% | 85.57% | 21.29 | 36.33 | 23.27 |
| | A&B | 85.76% | 86.27% | 86.31% | 46.07 | 36.33 | 23.32 |
| | A&C | 86.26% | 85.50% | 85.67% | 45.84 | 36.33 | 23.31 |
| | A&D | 87.31% | 87.25% | 87.27% | 44.07 | 36.33 | 23.30 |
| | A&E | 88.00% | 87.60% | 87.41% | 44.20 | 36.33 | 23.30 |
| | B&C | 82.46% | 82.38% | 82.75% | 46.08 | 36.33 | 23.31 |
| | B&D | 83.50% | 84.14% | 84.34% | 44.31 | 36.33 | 23.31 |
| | B&E | 84.19% | 84.48% | 84.48% | 44.45 | 36.33 | 23.31 |
| | C&D | 84.00% | 83.37% | 83.69% | 44.09 | 36.33 | 23.30 |
| | C&E | 84.69% | 83.71% | 83.83% | 44.22 | 36.33 | 23.30 |
| | D&E | 85.74% | 84.47% | 85.43% | 42.45 | 36.33 | 23.29 |
| | A&B&C | 84.83% | 84.71% | 84.91% | 68.99 | 36.33 | 23.34 |
| | A&B&D | 85.52% | 85.88% | 85.97% | 67.22 | 36.33 | 23.34 |
| | A&B&E | 85.99% | 86.11% | 86.07% | 67.36 | 36.33 | 23.34 |
| | A&C&D | 85.86% | 85.37% | 85.54% | 67.00 | 36.33 | 23.33 |
| | A&C&E | 86.32% | 85.60% | 85.64% | 67.13 | 36.33 | 23.32 |
| | A&D&E | 87.01% | 86.77% | 86.70% | 65.36 | 36.33 | 23.32 |
| | B&C&D | 83.32% | 83.29% | 83.59% | 67.24 | 36.33 | 23.34 |
| | B&C&E | 83.78% | 83.52% | 83.69% | 67.37 | 36.33 | 23.33 |
| | B&D&E | 84.48% | 84.69% | 84.75% | 65.60 | 36.33 | 23.33 |
| | C&D&E | 84.81% | 84.18% | 84.32% | 65.38 | 36.33 | 23.32 |
| | A&B&C&D | 84.88% | 84.82% | 85.00% | 90.15 | 36.33 | 23.37 |
| | A&B&C&E | 85.23% | 84.99% | 85.07% | 90.28 | 36.33 | 23.36 |
| | A&B&D&E | 85.75% | 85.87% | 85.87% | 88.51 | 36.33 | 23.36 |
| | A&C&D&E | 86.00% | 85.48% | 85.55% | 88.29 | 36.33 | 23.35 |
| | B&C&D&E | 84.10% | 83.92% | 84.09% | 88.53 | 36.33 | 23.36 |
| | A&B&C&D&E | 85.19% | 85.01% | 85.12% | 111.44 | 36.33 | 23.39 |

| Model | Tasks | Accuracy B1 | Accuracy B2 | Accuracy PAM | FLOPs B1 | FLOPs B2 | FLOPs PAM |
|---|---|---|---|---|---|---|---|
| ResNet-18 | A | 89.83% | 89.30% | 89.93% | 5.72 | 8.84 | 4.78 |
| | B | 88.25% | 88.20% | 88.36% | 5.72 | 8.84 | 4.83 |
| | A&B | 89.04% | 88.75% | 89.15% | 11.44 | 8.84 | 6.40 |
| ResNet-34 | A | 89.99% | 89.70% | 90.05% | 8.43 | 12.11 | 6.94 |
| | B | 88.44% | 88.98% | 88.42% | 8.43 | 12.11 | 6.94 |
| | A&B | 89.22% | 89.34% | 89.24% | 16.86 | 12.11 | 10.29 |

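The headline numbers in the Findings are FLOPs ratios, baseline divided by PAM, and can be checked against the tables above. In our reading, "up to 4.87x" comes from the five-task (LFW) table with pruning P1 and all five tasks executed (Baseline 1 vs. PAM), and "up to 2.01x" from the first results table with P1 and task B only (Baseline 2 vs. PAM); these are the largest ratios we find in the tables. The helper name `flops_reduction` is ours.

```python
# FLOPs copied from the result tables: (baseline FLOPs, PAM FLOPs).
ROWS = {
    # Five-task table, pruning P1, tasks A&B&C&D&E, Baseline 1 (no merging) vs. PAM
    "P1, A&B&C&D&E, B1 vs. PAM": (39.27, 8.06),
    # First results table, pruning P1, task B only, Baseline 2 (MTZ) vs. PAM
    "P1, task B, B2 vs. PAM": (52.58, 26.16),
}


def flops_reduction(baseline: float, pam: float) -> float:
    """How many times fewer FLOPs PAM needs than the baseline."""
    return baseline / pam


for name, (base, pam) in ROWS.items():
    print(f"{name}: {flops_reduction(base, pam):.2f}x")
```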
| Model/Dataset | Task | Accuracy | FLOPs |
|---|---|---|---|
| LeNet-5/Fashion-MNIST | A | 96.05% | 106.42 |
| | B | 96.37% | 106.42 |
| VGG-16/CelebA | A | 90.28% | 3112.20 |
| | B | 89.03% | 3112.20 |
| VGG-16/LFW | A | 90.23% | 3110.12 |
| | B | 84.15% | 3110.12 |
| | C | 85.03% | 3110.12 |
| | D | 86.62% | 3110.12 |
| | E | 87.44% | 3110.12 |
| ResNet-18/CelebA | A | 90.56% | 994.00 |
| | B | 88.91% | 994.00 |
| ResNet-34/CelebA | A | 90.42% | 1115.06 |
| | B | 88.70% | 1115.06 |
Xiaoxi He, ETH Zürich, Zürich, Switzerland
Dawei Gao, SKLSDE & BDBC, Beihang University, Beijing, China
Zimu Zhou, Singapore Management University, Singapore
Yongxin Tong, SKLSDE & BDBC, Beihang University, Beijing, China
Lothar Thiele, ETH Zürich, Zürich, Switzerland
Abstract.
Many mobile applications demand selective execution of multiple correlated deep learning inference tasks on resource-constrained platforms. Given a set of deep neural networks, each pre-trained for a single task, it is desired that executing arbitrary combinations of tasks yields minimal computation cost. Pruning each network separately yields suboptimal computation cost due to task relatedness. A promising remedy is to merge the networks into a multitask network to eliminate redundancy across tasks before network pruning. However, pruning a multitask network merged by existing network merging schemes cannot minimise the computation cost of every task combination, because these schemes do not account for the subsequent pruning. To this end, we theoretically identify the conditions under which pruning a multitask network minimises the computation of all task combinations. On this basis, we propose Pruning-Aware Merging (PAM), a heuristic network merging scheme to construct a multitask network that approximates these conditions. The merged network is then ready to be further pruned by existing network pruning methods. Evaluations with different pruning schemes, datasets, and network architectures show that PAM achieves up to 4.87× less computation against the baseline without network merging, and up to 2.01× less computation against the baseline with a state-of-the-art network merging scheme.
Keywords: Deep Learning; Network Pruning; Multitask Inference
KDD ’21: Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, August 14–18, 2021, Singapore, Singapore. Copyright ACM, 2021. DOI: 10.1145/1122445.1122456. CCS concepts: Computing methodologies → Neural networks.
1. Introduction
Deep neural networks that can run locally on resource-constrained devices hold potential for various emerging applications such as autonomous drones and social robots (Fang et al., 2018; Lee and Nirjon, 2020). These applications often simultaneously perform a set of correlated inference tasks based on the current context to deliver accurate and adaptive services. Although deep neural networks pre-trained for individual tasks are readily available (LeCun et al., 1998; Simonyan and Zisserman, 2014), deploying multiple such networks easily overwhelms the resource budget.
To support these applications on low-resource platforms, we investigate efficient multitask inference. Given a set of correlated inference tasks and deep neural networks (each network pre-trained for an individual task), we aim to minimise the computation cost when any subset of tasks is performed at inference time.
One naive solution to efficient multitask inference is to prune each network for individual tasks separately. A deep neural network is typically over-parameterised (Denil et al., 2013). Network pruning (Dai et al., 2018; Deng et al., 2020; Gao et al., 2020; Molchanov et al., 2019; Sze et al., 2017) can radically reduce the number of operations within a network without accuracy loss in the inference task. This solution, however, is only optimal if a single task is executed at a time. When multiple correlated tasks run concurrently, this solution cannot save computation by exploiting task relatedness and sharing intermediate results among networks.
A more promising solution framework is “merge & prune”, which merges multiple networks into a multitask network before pruning it (Fig. 1). A few pioneering studies (Chou et al., 2018; He et al., 2018) have explored network merging schemes to eliminate the redundancy among multiple networks pre-trained for correlated tasks. However, pruning a multitask network merged via these schemes can only minimise computation cost when all tasks are executed at the same time.
In this paper, we propose Pruning-Aware Merging (PAM), a new network merging scheme for efficient multitask inference. By applying existing network pruning methods on the multitask network merged by PAM, the computation cost when performing any subset of tasks can be reduced. Extensive experiments show that “PAM & Prune” consistently achieves solid advantages over the state-of-the-art network merging scheme across tasks, datasets, network architectures and pruning methods.
Our main contributions and results are as follows:
- We theoretically show that pruning a multitask network may not simultaneously minimise the computation cost of all task combinations in the network. We then identify conditions under which minimising the computation of all task combinations via network pruning becomes feasible. To the best of our knowledge, this is the first explicit analysis of the applicability of network pruning to multitask networks.
- We propose Pruning-Aware Merging (PAM), a heuristic network merging scheme to construct a multitask network that approximately meets the conditions in our analysis and enables “merge & prune” for efficient multitask inference.
- We evaluate PAM with various pruning schemes, datasets and architectures. PAM achieves up to 4.87× less computation cost against the baseline without network merging, and up to 2.01× less computation cost against the baseline with the state-of-the-art network merging scheme (He et al., 2018).
In the rest of this paper, we review related work in Sec. 2, introduce our problem statement in Sec. 3, theoretical analysis in Sec. 4 and our solution in Sec. 5. We present the evaluations of our methods in Sec. 6 and finally conclude in Sec. 7.
2. Related Work
Our work is related to the following categories of research.
Network Pruning. Network pruning reduces the number of operations in a deep neural network without loss in accuracy (Deng et al., 2020; Sze et al., 2017). Unstructured pruning removes unimportant weights (Dong et al., 2017; Gao et al., 2020; Guo et al., 2016). However, customised hardware (Han et al., 2016) is required to exploit such irregular sparse connections for acceleration. Structured pruning enforces sparsity at the granularity of channels/filters/neurons (Dai et al., 2018; Li et al., 2017; Molchanov et al., 2019; Wen et al., 2016). The resulting sparsity is well suited to acceleration on general-purpose processors. Prior pruning proposals implicitly assume a single task in the given network. We identify the challenges of pruning a multitask network and propose a network merging scheme such that pruning the merged multitask network minimises the computation cost of all task combinations in the network.
Multitask Networks. A multitask network can be either constructed from scratch via Multi-Task Learning (MTL) or merged from multiple networks pre-trained for individual tasks. MTL jointly trains multiple tasks for better generalisation (Zhang and Yang, 2017), while we focus on the computation cost of running multiple tasks at inference time. Network merging schemes (Chou et al., 2018; He et al., 2018) aim to construct a compact multitask network from networks pre-trained for individual tasks. Both MTZ (He et al., 2018) and NeuralMerger (Chou et al., 2018) enforce weight sharing among networks to reduce their overall storage. In contrast, we account for the computation cost of a multitask network. Although constructing a multitask network using these schemes (Chou et al., 2018; He et al., 2018) and pruning it via existing pruning methods can reduce the computation when all tasks are concurrently executed, they cannot minimise the computation cost for every combination of tasks.
3. Problem Statement
We define and analyse our problem based on the graph representation of neural networks. The graph representation reflects the computation cost of neural networks (see below) and facilitates an information-theoretic understanding of network pruning (see Sec. 4). Fig. 2 shows important notations used throughout this paper. For ease of illustration, we explain our analysis using two tasks. Extensions to more than two tasks are in Sec. 5.4.
3.1. Graph Representation of Neural Networks
Task. Consider three sets of random variables , , and . Task outputs , a prediction of , by learning the conditional distribution . Task outputs , a prediction of , by learning .
Single-Task Network. For task , a neural network without feedback loops can be represented by an acyclic directed graph . Each vertex represents a neuron. There is an edge between two vertices if the corresponding neurons are connected. The vertex set can be categorised into three types of nodes: source, internal and sink nodes. / is the indegree/outdegree of a vertex .
- •
Source node set represents the input layer. Each source node represents an input neuron and outputs a random variable . The output of the input layer is the input random variable set .
- •
Internal nodes represent the hidden neurons. The output of each hidden neuron is generated by calculating the weighted sum of its inputs and then applying an activation function.
- •
Sink node set represents the output layer. Each sink node represents an output neuron and the output is calculated in the same way as the hidden neurons. The output of the output layer is the prediction of ground-truth labels .
We organise the hidden neurons of into layers by Algorithm 1. denotes the outgoing neighbours of the vertex set . Algorithm 1 can organise any acyclic single-task network into layers, and the layer outputs satisfy the Markov property.
Multitask Network. For task and , a multitask network without feedback loops can be represented by an acyclic directed graph . All paths from the input neurons to the output neurons for task form a subgraph (see Fig. 2(c)), which is in effect the same as a single-task network. When only task is performed, only is activated. Subgraph is defined similarly. We also organise vertices of and into layers with Algorithm 1. Layer outputs of and are denoted as and . Suppose and have respectively and hidden layers. We assume w.l.o.g.. Then the -th layer output of is defined as with . As shown in Fig. 2(b), consists of three sets of neurons: , and .
Remarks. The above definitions have two benefits. (i) The computation cost of a neural network is an increasing function of the size of the graph, i.e., the number of edges plus vertices. Reducing the computation cost of the network is transformed into removing edges or vertices in the graph. (ii) For a single-task network with hidden layers, its layer outputs form a Markov chain: . All layer outputs in a multitask network also form a Markov chain. The Markov property allows an information theoretical analysis on neural networks (Saxe et al., 2018; Tishby and Zaslavsky, 2015).
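Remark (i) and the subgraph definition can be made concrete on a toy graph. The sketch below uses hypothetical vertex names (inputs `x0`, `x1`; a shared hidden neuron `s`; task-specific hidden neurons `a` and `b`; outputs `yA`, `yB`). It extracts the subgraph formed by all paths from the inputs to a chosen set of outputs and measures its size as vertices plus edges, the proxy for computation cost used here. It only illustrates the definitions and is not the paper's Algorithm 1.

```python
from collections import defaultdict


def reachable(adj, starts):
    """All vertices reachable from `starts` by following the adjacency lists `adj`."""
    seen, stack = set(starts), list(starts)
    while stack:
        v = stack.pop()
        for w in adj[v]:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return seen


def task_subgraph(edges, sources, sinks):
    """Vertices and edges lying on some path from `sources` to `sinks`."""
    fwd, bwd = defaultdict(list), defaultdict(list)
    for u, v in edges:
        fwd[u].append(v)
        bwd[v].append(u)
    # A vertex is on such a path iff it is reachable from a source and reaches a sink.
    verts = reachable(fwd, sources) & reachable(bwd, sinks)
    sub_edges = [(u, v) for (u, v) in edges if u in verts and v in verts]
    return verts, sub_edges


def graph_size(verts, edges):
    """Number of vertices plus edges, used as a proxy for computation cost."""
    return len(verts) + len(edges)


EDGES = [("x0", "s"), ("x1", "s"), ("x0", "a"), ("x1", "b"),
         ("s", "yA"), ("s", "yB"), ("a", "yA"), ("b", "yB")]
INPUTS = ["x0", "x1"]

for name, outputs in {"task A": ["yA"], "task B": ["yB"], "A and B": ["yA", "yB"]}.items():
    print(name, graph_size(*task_subgraph(EDGES, INPUTS, outputs)))
```

On this toy graph each single-task subgraph has size 10, while running both tasks costs 15 rather than 20, because the shared neuron `s` and its input edges are computed once.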
3.2. Problem Definition
Given two single-task networks and pre-trained for task and , we aim to construct a multitask network such that pruning on can minimise the number of vertices and edges in , and while preserving inference accuracy on and . To ensure minimal computation of any subset of tasks, we need to minimise the number of vertices and edges in any subgraph. For two tasks, corresponds to running task and concurrently; () corresponds to running task () only. Next, we show why it is difficult to optimise all subgraphs simultaneously.
4. Theoretical Understanding
This section analyses the challenges of pruning a multitask network and identifies conditions under which minimising the computation cost of all task combinations via pruning becomes feasible (Theorem 3). Proofs are in Appendix A.
4.1. Why Pruning a Single-task Network Works
Pruning a single-task network reduces the computation cost of a neural network while retaining task inference accuracy by suppressing redundancy in the network (Deng et al., 2020; Sze et al., 2017). From the information theoretical perspective (Saxe et al., 2018; Tishby and Zaslavsky, 2015), since the layer outputs form a Markov chain, the inference accuracy for a given task is positively correlated to the task related information transmitted through the network at each layer, measured by . All other information is irrelevant for the task. Hence the redundancy within a single-task network can be defined as below.
Definition 1.
For the -th layer in the single-task neural network , the redundancy of the layer is defined as .
measures the maximal amount of information the layer can express. measures the amount of task related information in the layer output. By definition, .
Remarks. is positively correlated to the number of vertices and incoming edges of the -th layer. Therefore, in a well trained network where can no longer increase, the computation cost can be minimised by reducing .
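The symbols in Definition 1 were lost in extraction. Assuming the redundancy is the layer entropy minus the task-related mutual information, R = H(L_i) − I(L_i; Y), which matches the two quantities described above, it equals the conditional entropy H(L_i | Y) and is therefore non-negative. The sketch below computes these quantities in nats for a discrete layer output from a hypothetical joint probability table; real activations are continuous and would need an estimator instead.

```python
import numpy as np


def entropy(p: np.ndarray) -> float:
    """Shannon entropy in nats of a probability array of any shape."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())


def layer_redundancy(joint: np.ndarray):
    """joint[l, y] = P(L = l, Y = y) for a discrete layer output L and label Y.

    Returns (H(L), I(L; Y), R) with R = H(L) - I(L; Y), which equals H(L | Y).
    """
    h_l = entropy(joint.sum(axis=1))  # marginal entropy of the layer output
    h_y = entropy(joint.sum(axis=0))  # marginal entropy of the label
    mi = h_l + h_y - entropy(joint)   # I(L; Y) = H(L) + H(Y) - H(L, Y)
    return h_l, mi, h_l - mi


# Layer output copies a balanced binary label: no redundancy.
print(layer_redundancy(np.array([[0.5, 0.0], [0.0, 0.5]])))
# Layer output = (label, independent coin flip): log 2 nats of redundancy.
print(layer_redundancy(np.array([[0.25, 0.0], [0.25, 0.0], [0.0, 0.25], [0.0, 0.25]])))
```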
Accordingly, pruning a single-task network can be formalised as an optimisation problem
[TABLE]
where controls the trade-off between inference accuracy and computation cost.
Remarks. Existing pruning methods implicitly assume a single-task network. That is, they are all designed to solve optimisation problem (1), even though the concrete strategies vary. We now show the problems that occur when these pruning methods are applied to a multitask network.
4.2. Why Pruning a Multitask Network Fails
As mentioned in Sec. 3.2, we aim to minimise the computation cost of any subset of tasks, which is a multi-objective optimisation problem. As we will show below, existing network pruning methods are unable to handle these objectives simultaneously.
We first define redundancy when performing two tasks at the same time, similarly to Definition 1.
Definition 2.
For a multitask network , the redundancy of its -th layer is .
Following the above definitions of redundancy, our objective in Sec. 3.2 is equivalent to minimising the redundancy in as well as in its two subgraphs and , which leads to the following three-objective optimisation (still, we assume w.l.o.g.):
[TABLE]
Reducing , and decreases the number of vertices and edges in , and , respectively. are parameters to control the trade-off between computation cost and inference accuracy, as well as to balance task and .
To solve optimisation problem (2) with prior network pruning methods, we observe two problems.
Problem 1: The first two objectives in (2) may conflict. This is because reducing may decrease (proofs in Appendix A.1). In other words, when pruning subgraph , it is possible that some information related to task A is removed from the shared vertices between and . Hence decreases and the inference accuracy of task deteriorates.
Problem 2: It is unclear how to minimise the third objective in (2). As mentioned in Sec. 4.1, most pruning methods are designed with a single-task network in mind. It is unknown how to apply them to a multitask network with architecture in Fig. 2 (a).
4.3. When Pruning a Multitask Network Works
The two problems in Sec. 4.2 show that not all multitask networks can be pruned for efficient multitask inference. However, a multitask network can be effectively pruned if it meets the conditions stated by the following theorem.
Theorem 3.
If , the conditions below are satisfied:
[TABLE]
where is the co-information (Bell, 2003), then the three-objective optimisation problem (2) can be reduced to two non-conflicting optimisation problems that can be solved independently:
[TABLE]
Each of the two optimisation problems (4) is in effect a single-task pruning problem like optimisation problem (1), which can be effectively solved by prior pruning proposals.
Remarks. Theorem 3 provides important guidelines for designing the network merging scheme for our problem in Sec. 3.2. Specifically, if and can be merged into a multitask network such that conditions (3) are satisfied, we can simply apply existing network pruning on the two subgraphs and to minimise the computation cost when performing any subset of tasks.
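Conditions (3) are stated in terms of co-information (Bell, 2003) and conditional mutual information. For three discrete variables, co-information is I(X;Y;Z) = I(X;Y) − I(X;Y|Z), which by inclusion-exclusion equals H(X) + H(Y) + H(Z) − H(X,Y) − H(X,Z) − H(Y,Z) + H(X,Y,Z); unlike mutual information it can be negative. The sketch below evaluates both quantities from a joint probability table. Which layer outputs and labels play the roles of X, Y and Z in Theorem 3 is not recoverable from this text, so the example uses generic variables.

```python
import numpy as np


def entropy(p: np.ndarray) -> float:
    """Shannon entropy in nats of a probability array of any shape."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())


def co_information(p: np.ndarray) -> float:
    """Co-information I(X;Y;Z) of a joint table p[x, y, z], in nats."""
    return (entropy(p.sum(axis=(1, 2))) + entropy(p.sum(axis=(0, 2)))
            + entropy(p.sum(axis=(0, 1)))
            - entropy(p.sum(axis=2)) - entropy(p.sum(axis=1)) - entropy(p.sum(axis=0))
            + entropy(p))


def conditional_mi(p: np.ndarray) -> float:
    """I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z) for p[x, y, z], in nats."""
    return (entropy(p.sum(axis=1)) + entropy(p.sum(axis=0))
            - entropy(p) - entropy(p.sum(axis=(0, 1))))


# Z = X xor Y with independent fair bits X, Y: negative co-information (synergy).
xor = np.zeros((2, 2, 2))
for x in (0, 1):
    for y in (0, 1):
        xor[x, y, x ^ y] = 0.25

# X = Y = Z, one fair bit: positive co-information (shared information).
copy = np.zeros((2, 2, 2))
copy[0, 0, 0] = copy[1, 1, 1] = 0.5

print(co_information(xor), conditional_mi(xor))  # -log 2 and log 2
print(co_information(copy))                      # log 2
```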
5. Pruning-Aware Merging
Based on the above analysis, we propose Pruning-Aware Merging (PAM), a novel network merging scheme that constructs a multitask network from pre-trained single-task networks. PAM approximately meets the conditions in Theorem 3 so that the merged multitask network can be effectively pruned for efficient multitask inference.
5.1. PAM Workflow
Given two single-task networks and pre-trained for task and (), PAM constructs a multitask network with the steps below (see Fig. 3).
1. Assign , as , and use the same inputs.
2. For , regroup the neurons from and into , and by the regrouping algorithm in Sec. 5.2.
3. Take over the output layer for task : . For , take over the remaining layers from : .
4. Reconnect the neurons as in Fig. 3. If a connection existed before merging, it keeps its original weight. Otherwise it is initialised to zero.
5. Finetune on and to learn the newly added connections. For the shared connections, . The gradients are first calculated separately on and , and then averaged before weight updating.
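Steps 2 to 5 can be sketched for one layer with NumPy. The sketch assumes each merged neuron is tagged with its origin network (`"A"`, `"B"`, or `"X"` for the shared inputs), its index in that network's layer, and its group (`"a"` for A-only, `"s"` for shared, `"b"` for B-only). The allowed connections are derived from the subgraph definition in Sec. 3.1 rather than from Fig. 3, which is not reproduced here: a neuron feeding a shared neuron must itself reach both outputs, so shared neurons read only from shared neurons. Connections that existed before merging keep their weights and new ones start at zero (Step 4); on shared connections the two task gradients are averaged (Step 5). All function names are hypothetical.

```python
import numpy as np

# (current group, previous group) pairs allowed by the path definition.
ALLOWED = {("a", "a"), ("a", "s"), ("s", "s"), ("b", "b"), ("b", "s")}


def merge_layer(W_A, W_B, prev, cur):
    """Merged weights for one layer; prev/cur are lists of (origin, index, group).

    W_A[j, k] / W_B[j, k] are the original weights from neuron k of the previous
    layer to neuron j of this layer in network A / B. Returns (W, mask), where
    mask marks connections that exist in the merged network.
    """
    W = np.zeros((len(cur), len(prev)))
    mask = np.zeros((len(cur), len(prev)), dtype=bool)
    for i, (o_c, j_c, g_c) in enumerate(cur):
        src = W_A if o_c == "A" else W_B
        for k, (o_p, j_p, g_p) in enumerate(prev):
            if (g_c, g_p) not in ALLOWED:
                continue
            mask[i, k] = True
            if o_p in (o_c, "X"):  # connection existed before merging: keep weight
                W[i, k] = src[j_c, j_p]
            # otherwise it is a new connection and stays initialised to zero
    return W, mask


def group_masks(cur, mask):
    """Masks of shared connections and of connections inside G_A and G_B."""
    g = np.array([grp for (_, _, grp) in cur])[:, None]
    shared = mask & (g == "s")
    in_A = mask & ((g == "a") | (g == "s"))
    in_B = mask & ((g == "b") | (g == "s"))
    return shared, in_A, in_B


def averaged_step(W, mask, shared, grad_A, grad_B, lr):
    """One SGD step; grad_A is zero outside G_A and grad_B is zero outside G_B.

    On non-shared connections only one gradient is non-zero, so their sum is
    that task's gradient; on shared connections the two are averaged.
    """
    grad = np.where(shared, 0.5 * (grad_A + grad_B), grad_A + grad_B)
    return W - lr * grad * mask
```

For example, with both networks having two first-layer neurons and groups `[("A",0,"a"), ("A",1,"s"), ("B",0,"s"), ("B",1,"b")]`, `merge_layer` simply stacks the two first-layer weight matrices, since every neuron reads from the shared inputs; the zero-initialised connections only appear from the second hidden layer on.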
Now the multitask network is ready to be pruned. From Theorem 3, we can apply network pruning on the two subgraphs and independently and achieve a minimal computation cost for all combinations of tasks. However, since we only approximate the conditions in (3), pruning and is not perfectly independent in practice. Hence we prune and in an alternating manner to balance between task and .
5.2. Regrouping Algorithm
The core of PAM is the regrouping algorithm in the second step in Sec. 5.1. It regroups the neurons from and into three sets: , and , such that the conditions (3) in Theorem 3 are satisfied. However, it is computation-intensive to estimate the co-information and conditional mutual information in (3) precisely. We rely on the following theorem to approximate the conditions.
Theorem 1.
The conditions in (3) can be achieved by minimising , , and maximising , .
Remarks. and describe the “misplaced” information, i.e., the information that is useful for one task but contained in neurons that are not connected to the outputs of this task. Such information is therefore redundant and needs to be minimised. and measure the “relevant” information, i.e., the information useful for one task and contained in neurons connected to this task. Note that this information cannot simply be maximised on its own, because it includes the information that is useful for both tasks. Achieving the conditions in (3) requires simultaneously minimising the “misplaced” information and maximising the “relevant” information. The proof of Theorem 1 is in Sec. A.3.
Based on Theorem 1, we propose an algorithm to regroup the neurons such that conditions (3) are approximately met. It constructs the largest possible set and from all the neurons in and while and remain close to zero, such that and are approximately maximised. To estimate and , we use a Kullback–Leibler-based mutual information upper bound estimator from (Kolchinsky and Tracey, 2017).
Algorithm 2 gives the pseudocode for regrouping the neurons such that the conditions in Theorem 3 are approximately met. Central to Algorithm 2 is a greedy search in Lines 5-8 and 10-13. In Lines 5-8, we search for the largest possible set of neurons while remains approximately zero (smaller than a pre-defined threshold ), such that is approximately maximised. Similarly, in Lines 10-13, we approximately maximise while keeping close to zero. According to Theorem 1, the conditions in Theorem 3 are then approximately met.
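The greedy search in Lines 5-8 and 10-13 of Algorithm 2 (not reproduced here) can be sketched in plain Python. This is a minimal sketch under assumptions: `mi_A(S)` and `mi_B(S)` are hypothetical callables returning an estimate of the information a neuron set S carries about task A or B, the greedy step adds the neuron that keeps the misplaced information smallest, and neurons left over after both searches become shared. The optional `start` arguments let a run with a larger threshold reuse the sets from a run with a smaller one, as in Step 2 of the threshold-tuning procedure in this section.

```python
from typing import Callable, Hashable, List, Sequence, Tuple

Neuron = Hashable
Estimator = Callable[[List[Neuron]], float]


def greedy_exclusive(candidates: Sequence[Neuron], misplaced: Estimator, eps: float,
                     start: Sequence[Neuron] = ()) -> Tuple[List[Neuron], List[Neuron]]:
    """Grow a task-exclusive neuron set while its misplaced information stays below eps."""
    chosen = [n for n in start if n in candidates]  # warm start, if any
    remaining = [n for n in candidates if n not in chosen]
    while remaining:
        # Try every candidate and keep the one that adds the least misplaced information.
        score, best = min(((misplaced(chosen + [n]), n) for n in remaining),
                          key=lambda t: t[0])
        if score >= eps:
            break
        chosen.append(best)
        remaining.remove(best)
    return chosen, remaining


def regroup(neurons, mi_A: Estimator, mi_B: Estimator, eps: float,
            start_A=(), start_B=()):
    """Split the neurons of one merged layer into (A-only, shared, B-only) lists."""
    only_A, rest = greedy_exclusive(neurons, mi_B, eps, start_A)  # little info on B
    only_B, shared = greedy_exclusive(rest, mi_A, eps, start_B)   # little info on A
    return only_A, shared, only_B
```

With a toy additive estimator (each neuron carries a fixed amount of information about each task), the A-only group collects the neurons with little task-B information until the threshold is reached. The KL-based estimator actually used for PAM is not additive, so this only illustrates the control flow.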
Practical Issue: How to Estimate Mutual Information. We use a Kullback–Leibler-based mutual information upper bound estimator from (Kolchinsky and Tracey, 2017) to estimate the upper bounds of and . Since the upper bounds are approximate, it is impossible to request them to be exactly zero. Hence, we use a threshold parameter to keep and close to zero.
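A minimal NumPy sketch of a pairwise-KL entropy bound in the style of Kolchinsky and Tracey (2017), applied to layer activations the way Saxe et al. (2018) use it. Assumptions: each sample's activation vector is treated as the mean of an isotropic Gaussian with a hypothetical noise variance `noise_var`, so the KL divergence between two components is the squared distance divided by 2·`noise_var`; the label term is estimated as H(T) − Σ_c p(c) H(T | Y = c) with each entropy replaced by its bound, which makes it an estimate built from upper bounds rather than a strict bound on I(T; Y). The constant noise entropy cancels in that difference and is dropped. Function names are ours.

```python
import numpy as np


def _logsumexp_rows(a: np.ndarray) -> np.ndarray:
    """Numerically stable log(sum(exp(a), axis=1))."""
    m = a.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=1, keepdims=True)))[:, 0]


def kl_entropy_upper(h: np.ndarray, noise_var: float) -> float:
    """Pairwise-KL upper bound on H(h + noise) minus the constant noise entropy.

    h has shape (n_samples, n_units); KL_ij = ||h_i - h_j||^2 / (2 * noise_var).
    """
    n = h.shape[0]
    sq = ((h[:, None, :] - h[None, :, :]) ** 2).sum(axis=-1)
    return float(-np.mean(_logsumexp_rows(-sq / (2.0 * noise_var)) - np.log(n)))


def mi_with_labels(h: np.ndarray, y, noise_var: float = 0.1) -> float:
    """Estimate I(h; y) = H(h) - sum_c p(c) H(h | y = c) with the bound above."""
    y = np.asarray(y)
    h_cond = 0.0
    for c in np.unique(y):
        sel = y == c
        h_cond += sel.mean() * kl_entropy_upper(h[sel], noise_var)
    return kl_entropy_upper(h, noise_var) - h_cond


h = np.array([[0.0], [0.0], [100.0], [100.0]])  # two well-separated clusters
print(mi_with_labels(h, [0, 0, 1, 1], noise_var=1.0))  # labels match clusters
print(mi_with_labels(h, [0, 1, 0, 1], noise_var=1.0))  # labels ignore clusters
```

With labels that match the two clusters the estimate is log 2 nats; with labels independent of the clusters it is 0.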
Practical Issue: How to Tune Threshold . The parameter affects the performance of “PAM & prune”. A larger results in more neurons in and and fewer shared neurons in . In this case, the multitask network after “PAM & prune” is less efficient when both tasks are executed concurrently, but more efficient when only one task is executed (similar to “baseline 1 & prune”). Conversely, a smaller results in more shared neurons. In this case, the multitask network after “PAM & prune” is less efficient when only one task is executed, but more efficient when both tasks are executed concurrently (similar to “baseline 2 & prune”).
The parameter can be empirically tuned as follows:
1. Execute Algorithm 2 with a small .
2. Increase the value of slightly and rerun Algorithm 2. Since Lines 5-8 and 10-13 are greedy searches, the results for the smaller in Step 1 (i.e., the already constructed neuron sets and ) can be reused, instead of starting with empty sets as in Lines 4 and 9.
3. Iterate Step 2 until a satisfactory balance among task combinations is reached. In each iteration of Step 2, we can reuse the neuron sets and from the last iteration.
The impact of is shown in Appendix C.
5.3. Extensions to ResNets
In order to support merging Residual Networks (He et al., 2016), PAM needs to be slightly modified. As illustrated in Fig. 4, the regrouping of the last layer in each residual block happens not directly after the weighted summation, but after the superposition with the shortcut connection and just before the vector is passed as input to the first layer in the next block. This input vector of the first layer in each block is also regrouped using Algorithm 2 and then pruned at a later stage. This special treatment of the last layer in each residual block is consistent with ResNet-compatible pruning methods such as (Molchanov et al., 2019), which can also prune the block outputs just before they are fed into the first layer in the next block.
5.4. Extension to Three or More Tasks
When there are tasks, we define the set of all tasks as . The merged multitask network can be divided into subgraphs , where and is a nonempty subset of tasks. Each vertex in has paths to all the outputs with . When a task combination (i.e., a subset of tasks) is executed, only subgraph is activated. Layers in are denoted as . The output layer for task combination is denoted as , which is the prediction of ground-truth labels .
Extension of Theorem 3. For any pair of non-overlapping nonempty subsets of tasks and (), define:
[TABLE]
Then Theorem 3 is extended into:
Theorem 2.
If for all with , and for any pair of non-overlapping nonempty subsets of tasks and , the following conditions are satisfied:
[TABLE]
then the computation cost of executing all task combinations can be minimised by the following non-conflicting optimisation problems that can be solved independently:
[TABLE]
Theorem 2 can be proven by recursively applying Theorem 3.
Extension of PAM. The neuron sets , and are extended to:
[TABLE]
Note that neurons in are activated iff any task is executed. Now Algorithm 2 is extended to Algorithm 3. And at step 5 of the PAM workflow in Sec. 5.1, we connect iff .
It is worth mentioning that when tasks are highly related, the numbers of neurons in with can be extremely small (as in our experiment on the LFW dataset in Appendix B). We can therefore simplify Algorithm 3 by fixing and skipping the remaining loops. Every layer in the multitask network merged by the simplified PAM contains only neuron sets with and one shared neuron set . Shared neurons in are always activated, while non-shared neurons in are activated iff task is executed.
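The bookkeeping behind the extension to N tasks can be sketched with frozensets. Each neuron group is labelled by the nonempty task subset S whose outputs it reaches, and, following the activation rule above, a group is computed whenever at least one of its tasks runs. The connection condition in the text is not recoverable here, so `may_connect` encodes our reading of the path definition: a neuron can feed a neuron of group S only if it reaches every output in S itself. The group sizes in the example are made up to mimic the simplified PAM, with one group per task plus one fully shared group.

```python
from itertools import combinations


def nonempty_subsets(tasks):
    """All 2^N - 1 task combinations, i.e. the possible neuron groups."""
    return [frozenset(c) for r in range(1, len(tasks) + 1)
            for c in combinations(tasks, r)]


def is_active(group: frozenset, combo: frozenset) -> bool:
    """A neuron group is computed iff at least one of its tasks is executed."""
    return bool(group & combo)


def may_connect(prev_group: frozenset, cur_group: frozenset) -> bool:
    """prev -> cur is allowed only if prev reaches every output that cur reaches."""
    return cur_group <= prev_group


def active_neurons(sizes: dict, combo: frozenset) -> int:
    """Number of neurons in one layer that must be computed for a task combination."""
    return sum(n for group, n in sizes.items() if is_active(group, combo))


# Hypothetical layer of a simplified-PAM network over five tasks.
sizes = {frozenset(t): 2 for t in "ABCDE"}
sizes[frozenset("ABCDE")] = 10
print(len(nonempty_subsets("ABCDE")))           # 31 possible groups
print(active_neurons(sizes, frozenset("AB")))   # 2 + 2 + 10 = 14
```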
6. Experiments
We compare different network merging schemes on whether lower computation is achieved when performing any subset of tasks.
6.1. Experiment Settings
Baselines for Network Merging. We compare PAM with two merging schemes.
- Baseline 1. It simply skips network merging in the “merge & prune” framework, so no multitask network is constructed. As mentioned in Sec. 1, this scheme optimises the pruning of single-task networks.
- Baseline 2. Pre-trained single-task networks are merged into a multitask network by MTZ (He et al., 2018), a state-of-the-art network merging scheme. Applying MTZ in “merge & prune” can minimise the computation cost of a multitask network when all tasks are executed.
Methods for Network Pruning. Since we aim to compare network merging schemes within the “merge & prune” framework, we apply the same network pruning method to the neural network(s) constructed by each merging scheme. To show that PAM works with different pruning methods, we choose two state-of-the-art structured network pruning methods: one (Dai et al., 2018) uses information-theoretic metrics (denoted as P1), and the other (Molchanov et al., 2019) uses sensitivity-based metrics (denoted as P2).
The pruning methods are applied to the neural network(s) constructed by the different merging schemes as follows. For Baseline 1, each single-task network is pruned independently. For the multitask networks constructed with Baseline 2 and PAM, we prune the subgraphs of the individual tasks in an alternating manner (e.g., task ) to balance between tasks. Since only P2 was originally designed to prune ResNets, we evaluate ResNets with P2 only.
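The alternating order can be sketched as a round robin over tasks: on each task's turn, the lowest-scoring remaining neuron in that task's subgraph is removed. This is a simplified illustration under stated assumptions: the scores stand in for whatever P1 or P2 computes, the neuron ids and memberships are hypothetical, and a real pruning pass would also fine-tune between steps and use its own stopping criterion.

```python
def alternating_prune(scores, membership, tasks, budget):
    """Round-robin structured pruning over the task subgraphs of a merged network.

    scores:     neuron id -> importance score (lower = pruned first); a stand-in
                for the chosen pruning criterion.
    membership: neuron id -> set of tasks whose subgraph contains the neuron.
    tasks:      task order for the round robin, e.g. ["A", "B"].
    budget:     total number of neurons to remove.
    Returns the list of (task, neuron) removals in order.
    """
    remaining = set(scores)
    pruned = []
    idle = 0  # consecutive turns in which a task had nothing left to prune
    turn = 0
    while len(pruned) < budget and idle < len(tasks):
        task = tasks[turn % len(tasks)]
        turn += 1
        candidates = [n for n in remaining if task in membership[n]]
        if not candidates:
            idle += 1
            continue
        idle = 0
        # Lowest score first; ties broken deterministically by neuron id.
        victim = min(candidates, key=lambda n: (scores[n], n))
        remaining.remove(victim)
        pruned.append((task, victim))
    return pruned


if __name__ == "__main__":
    scores = {"a1": 0.1, "a2": 0.5, "b1": 0.2, "s1": 0.05}
    membership = {"a1": {"A"}, "a2": {"A"}, "b1": {"B"}, "s1": {"A", "B"}}
    print(alternating_prune(scores, membership, ["A", "B"], budget=3))
```

Because a shared neuron belongs to every task's subgraph, it can be removed on any task's turn, as `s1` is here on task A's first turn.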
Datasets and Single-Task Networks. We define tasks from three datasets: Fashion-MNIST (Xiao et al., 2017), CelebA (Liu et al., 2015), and LFW (Huang et al., 2012). Fashion-MNIST and CelebA each contain two tasks, and LFW contains five tasks. We use LeNet-5 (LeCun et al., 1998) as the pre-trained single-task networks for tasks derived from Fashion-MNIST, and VGG-16 (Simonyan and Zisserman, 2014) for tasks from CelebA and LFW. We also use ResNet-18 and ResNet-34 (He et al., 2016) as pre-trained single-task networks for CelebA. See Appendix B for details of the dataset setup and for the inference accuracy and FLOPs of the pre-trained single-task networks.
Evaluation Metrics. For a given set of tasks, we aim to minimise the computation cost of all task combinations. To assess computation cost independently of hardware, we use the number of floating point operations (FLOPs) as the metric. For a fair comparison, the network(s) constructed by different merging schemes are pruned while preserving almost the same inference accuracy. To quantify the performance advantage of PAM over the baselines across all task combinations, we adopt the following two single-valued criteria:
- Average Gain. This metric measures the average computation cost reduction of “PAM & prune” over “baseline & prune” across all task combinations. For example, given two tasks and , there are three task combinations: , and . When executing these task combinations, the FLOPs of the network after “PAM & prune” are , and , respectively. After “baseline 1 & prune”, the FLOPs are , and , respectively. The average gain over baseline 1 is calculated as .
- Peak Gain. This metric measures the maximal computation cost reduction across all task combinations. Using the same example and notation as above, the peak gain over baseline 1 is calculated as .
All experiments are implemented with TensorFlow and conducted on a workstation with an Nvidia RTX 2080 Ti GPU.
6.2. Main Experiment Results
Overall Performance Gain. Fig. 5 shows the average and peak gains of PAM over the two baselines for different models (LeNet-5, VGG-16, ResNet-18, ResNet-34), datasets (Fashion-MNIST, CelebA, LFW), and pruning methods (P1, P2). The detailed FLOPs and inference accuracy for the two-task settings (Fashion-MNIST and CelebA) are listed in Table 1, Table 2, Table 3 and Table 4.
Compared with baseline 1, PAM achieves to average gain and to peak gain. Compared with baseline 2, PAM achieves to average gain and to peak gain. Overall, PAM has a significant performance advantage over both baselines across datasets and network architectures.
Effectiveness of PAM. As Fig. 5 shows, the performance gain of PAM varies across baselines and datasets. These variations in average and peak gains depend on how many neurons are shared and how many networks are merged. Fig. 6 shows how many neurons (kernels) are shared after “PAM & prune” on LeNet-5 and VGG-16.
- The more neurons shared, the higher the gain of PAM over baseline 1. “Baseline 1 & prune” can effectively reduce the computation cost when only one task is performed. However, when many neurons can be shared (see Fig. 6(b), (c), (e), and (f)), baseline 1 is sub-optimal when multiple tasks are executed simultaneously, as it cannot reduce computation by sharing neurons. This is why PAM outperforms baseline 1 by a larger margin on CelebA and LFW.
- The fewer neurons shared, the higher the gain of PAM over baseline 2. “Baseline 2 & prune” can effectively reduce the computation cost via neuron sharing when all tasks are performed simultaneously. However, when only a few neurons can be shared (see Fig. 6(a) and (d)), the multitask network merged by baseline 2 cannot shut down unnecessary neurons when not all tasks are executed, and hence yields sub-optimal computation cost. This is why PAM outperforms baseline 2 by a larger margin on Fashion-MNIST.
- The more networks merged, the higher the gain of PAM over both baselines. As the number of single-task networks (tasks) increases, “PAM & prune” can either share more neurons and yield lower computation than “baseline 1 & prune”, or shut down more unnecessary neurons and yield lower computation than “baseline 2 & prune”. Therefore, the gain of PAM over baseline 1 is significantly higher on LFW than on CelebA. This is also why the gain of PAM over baseline 2 on LFW is not much lower than on CelebA, even though LFW has the highest degree of sharing.
Takeaways. Although the performance gain of PAM varies across tasks, PAM achieves consistently solid advantages over both baselines. We may conclude that it is always preferable to use PAM for efficient multitask inference, regardless of the number of shareable neurons, the probability of executing each task combination, the network architecture, or the pruning method used after merging.
6.3. Ablation Study
This subsection presents experiments to further understand the effectiveness of PAM.
6.3.1. Impact of Task Relatedness
This study examines how task relatedness affects the performance gain PAM can achieve. The number of neurons that can be shared among pre-trained networks depends on the relatedness among tasks. An effective network merging scheme should therefore share more neurons between tasks as task relatedness increases.
Settings. We consider the 73 labels in LFW as 73 binary classification tasks, and measure the relatedness between each task pair by . We then pick four pairs of tasks with , , and bits, train four pairs of single-task VGG-16s on them, and construct four multitask networks using PAM.
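Relatedness here is reported in bits. Assuming it is the mutual information between the two binary label columns (the text elsewhere refers to estimating mutual information), a plug-in estimate from paired labels can be sketched as follows; the function name and the toy label vectors are illustrative.

```python
from collections import Counter
from math import log2


def mutual_information_bits(x, y):
    """Plug-in estimate of I(X;Y) in bits from paired discrete labels."""
    n = len(x)
    joint = Counter(zip(x, y))
    px, py = Counter(x), Counter(y)
    mi = 0.0
    for (a, b), count in joint.items():
        p_ab = count / n
        mi += p_ab * log2(p_ab / ((px[a] / n) * (py[b] / n)))
    return mi


if __name__ == "__main__":
    print(mutual_information_bits([0, 1, 0, 1], [0, 1, 0, 1]))  # identical labels: 1 bit
    print(mutual_information_bits([0, 0, 1, 1], [0, 1, 0, 1]))  # independent labels: 0 bits
```

For binary labels the estimate lies between 0 bits (unrelated tasks) and 1 bit (one label determines the other), which matches the scale on which task pairs are compared above.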
Results. Fig. 7a plots the number of shared neurons in layer f7 of these four multitask networks for different values of the tuning threshold . The multitask networks for task pairs with higher relatedness consistently share more neurons. Hence, PAM shares an increasing number of neurons between tasks as task relatedness increases.
6.3.2. Case Study: Task Inclusion
This study validates the effectiveness of PAM in an extreme yet common case of task relatedness, where task is a sub-task of task . Ideally, when the mutual information is precisely estimated and the true largest sets of task-exclusive neurons are selected, PAM should pick out exclusive neurons only for the including task and none for the sub-task.
Settings. We pick 30 labels in LFW as task and 15 of them as task . Hence task includes task . We train two single-task VGG-16s on these two tasks separately and then merge them by PAM.
Results. Fig. 7b shows the number of non-shared neurons in and in the last eight layers of the merged network (all earlier layers contain only shared neurons). Almost no neurons are selected for by Algorithm 2, which validates its effectiveness.
7. Conclusion
In this paper, we investigate network merging schemes for efficient multitask inference. Given a set of single-task networks pre-trained for individual tasks, we aim to construct a multitask network such that applying existing network pruning methods on it can minimise the computation cost when performing any subset of tasks. We theoretically identify the conditions on the multitask network, and design Pruning-Aware Merging (PAM), a heuristic network merging scheme to construct such a multitask network. The merged multitask network can then be effectively pruned by existing network pruning methods. Extensive evaluations show that pruning a multitask network constructed by PAM achieves low computation cost when performing any subset of tasks in the network.
Appendix
Appendix A Proofs
A.1. Proof of Problem 1 in Sec. 4.2
Problem 1 occurs because of the lemma below.
Lemma 0.
Reducing may decrease .
Proof.
We decompose :
[TABLE]
where is the co-information (Bell, 2003). From Definition 1, we have:
[TABLE]
For the last term, we have:
[TABLE]
Hence, includes . Reducing may decrease . ∎
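The decomposition above relies on co-information (Bell, 2003). For three discrete variables X, Y, Z (generic names, not the paper's symbols) it can be written as I(X;Y;Z) = I(X;Y) - I(X;Y|Z) = H(X) + H(Y) + H(Z) - H(X,Y) - H(X,Z) - H(Y,Z) + H(X,Y,Z); it is positive when the variables carry redundant information and can be negative under synergy. The plug-in estimator below is an illustrative sketch for discrete samples, not the estimator used in the paper.

```python
from collections import Counter
from math import log2


def entropy(samples):
    """Plug-in Shannon entropy (bits) of a list of hashable samples."""
    n = len(samples)
    return -sum((c / n) * log2(c / n) for c in Counter(samples).values())


def co_information(xs, ys, zs):
    """Three-variable co-information via the inclusion-exclusion of joint entropies."""
    def h(*cols):
        return entropy(list(zip(*cols)))

    return (h(xs) + h(ys) + h(zs)
            - h(xs, ys) - h(xs, zs) - h(ys, zs)
            + h(xs, ys, zs))


if __name__ == "__main__":
    # Redundancy: three copies of the same fair bit -> +1 bit.
    print(co_information([0, 1], [0, 1], [0, 1]))
    # Synergy: Z = X xor Y with independent fair bits -> -1 bit.
    x, y = [0, 0, 1, 1], [0, 1, 0, 1]
    print(co_information(x, y, [a ^ b for a, b in zip(x, y)]))
```

The sign is what matters for the lemma's argument: a positive co-information term means part of the information counted in one mutual-information term is also counted in another, so reducing one can decrease the other.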
A.2. Proof of Theorem 3
Proof.
We show that the conditions in Theorem 3 resolve (i) Problem 1 in Sec. 4.2 and (ii) Problem 2 in Sec. 4.2.
Solving Problem 1 in Sec. 4.2. From (11) we have the following if :
[TABLE]
is not in . Hence is unaffected when is reduced. is included in . Thus minimising will not reduce with a proper . All of the above equations still hold if we swap and . Consequently, if = [math], the first two objectives in optimisation problem (2) become non-conflicting.
Solving Problem 2 in Sec. 4.2. We first decompose as in Table 5. Then from (30), we have
[TABLE]
Further,
[TABLE]
This is a loose upper bound. However, since , and are lower bounded by [math], it suffices to show that when , minimising and will minimise .
In summary, when
[TABLE]
the optimisation problem (2) is reduced to two non-conflicting optimisation problems (4). ∎
A.3. Proof of Theorem 1
Proof.
First, for co-information between four random variables, we have from (Bell, 2003):
[TABLE]
Therefore, the first condition in Theorem 3, i.e., = [math], is achieved by minimising and to [math].
For the second condition in Theorem 3, i.e., , we have:
[TABLE]
Given and , is constant. The second condition in Theorem 3 is achieved by minimising to [math] and maximising to .
The same holds if we swap and . The third condition in Theorem 3, i.e., , is achieved by minimising and maximising . ∎
Appendix B Detailed Dataset Setup
Fashion-MNIST. The Fashion-MNIST dataset (https://github.com/f-rumblefish/Multi-Label-Fashion-MNIST) contains training images and test images with a resolution of . Each image contains four fashion product images randomly selected from Fashion-MNIST (Xiao et al., 2017). The 10 categories of fashion products are treated as 10 binary classification problems, which we divide into two groups (5/5) to form tasks and . On each task we train a LeNet-5, a commonly used architecture for Fashion-MNIST.
CelebA. The CelebA dataset (http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) contains over thousand celebrity face images labelled with attributes. The attributes are divided into two groups (/) to form tasks and . The dataset is divided into training and test sets containing 80% and 20% of the samples, respectively. Input images are resized to . On each task we train slightly modified VGG-16 models, a commonly used single-task network architecture on CelebA. The width of the fully connected layers in VGG-16 is changed to 512. The convolutional layers are initialised with weights pre-trained for imdb-wiki (Rothe et al., 2018), and we use the same pre-processing steps.
LFW. The Labeled Faces in the Wild (LFW) dataset (http://vis-www.cs.umass.edu/lfw/) contains over 13,000 face photographs collected from the web. Each face photo is associated with 73 attributes (Kumar et al., 2009). We randomly split the 73 labels into four groups of 15 labels each and one group of 13 labels, and each group of labels forms a single task. The dataset is divided into training and test sets containing 80% and 20% of the samples, respectively. As for CelebA, input images are resized to . On each task we train slightly modified VGG-16 models, a commonly used single-task network architecture on LFW. The width of the fully connected layers in VGG-16 is changed to 128. The convolutional layers are initialised with weights pre-trained for imdb-wiki (Rothe et al., 2018), and we use the same pre-processing steps.
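The random 15/15/15/15/13 split of the 73 LFW attributes can be reproduced with a seeded shuffle, as sketched below. The seed, the integer label ids and the function name are hypothetical, since the text does not state how the split was drawn.

```python
import random


def split_labels(num_labels=73, group_sizes=(15, 15, 15, 15, 13), seed=0):
    """Randomly partition label ids 0..num_labels-1 into groups (one task per group)."""
    if sum(group_sizes) != num_labels:
        raise ValueError("group sizes must add up to the number of labels")
    labels = list(range(num_labels))
    random.Random(seed).shuffle(labels)  # seeded so the split is reproducible
    groups, start = [], 0
    for size in group_sizes:
        groups.append(sorted(labels[start:start + size]))
        start += size
    return groups


if __name__ == "__main__":
    tasks = split_labels()
    print([len(g) for g in tasks])  # [15, 15, 15, 15, 13]
```

Using a private `random.Random` instance keeps the split independent of any other code that draws from the global random state.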
Table 6 summarises the inference accuracy and FLOPs of the pre-trained single-task networks.
Appendix C Visualisation of Algorithm 2
Fig. 8 illustrates the iterations of Lines 19-22 and 24-27 in Algorithm 2 by showing and against the number of iterations. Here we use the f7 layer of VGG-16 trained and merged on the CelebA dataset as an example. The tuning parameter is set to infinity in order to show all possible cases of the iterations. From Fig. 8, we can observe three phases:
- (1) In the first phase, and remain small, indicating that the selected and provide little information about the other task.
- (2) In the second phase, and start to increase, as it is impossible to add more neurons to and while keeping and close to zero.
- (3) In the third phase, and start to saturate, as the newly joined neurons contain mostly information already included in the existing and .
In practice, the parameter tuned as remains small, and the iterations in Algorithm 2 as well as Algorithm 3 usually stop at the end of the first phase or the beginning of the second phase.
References
- Anthony J Bell. 2003. The co-information lattice. In International Workshop on Independent Component Analysis and Blind Signal Separation: ICA. IEEE Press, Piscataway, NJ, USA.
- Yi-Min Chou, Yi-Ming Chan, Jia-Hong Lee, Chih-Yi Chiu, and Chu-Song Chen. 2018. Unifying and merging well-trained deep neural networks for inference stage. In IJCAI. Morgan Kaufmann, Burlington, MA, USA, 2049–2056.
- Bin Dai, Chen Zhu, Baining Guo, and David Wipf. 2018. Compressing neural networks using the variational information bottleneck. In ICML. ACM, New York, NY, USA, 1143–1152.
- Lei Deng, Guoqi Li, Song Han, Luping Shi, and Yuan Xie. 2020. Model compression and hardware acceleration for neural networks: a comprehensive survey. Proc. IEEE 108, 4 (2020), 485–532.
- Misha Denil, Babak Shakibi, Laurent Dinh, Nando De Freitas, et al. 2013. Predicting parameters in deep learning. In NeurIPS. Curran Associates Inc., Red Hook, NY, USA, 2148–2156.
- Xin Dong, Shangyu Chen, and Sinno Pan. 2017. Learning to prune deep neural networks via layer-wise optimal brain surgeon. In NeurIPS. Curran Associates Inc., Red Hook, NY, USA, 4860–4874.
- Biyi Fang, Xiao Zeng, and Mi Zhang. 2018. NestDNN: resource-aware multi-tenant on-device deep learning for continuous mobile vision. In MobiCom. ACM, New York, NY, USA, 115–127.
