A Unifying Causal Framework for Analyzing Dataset Shift-stable Learning Algorithms
Adarsh Subbaswamy, Bryant Chen, Suchi Saria

Adarsh Subbaswamy, Department of Computer Science, Johns Hopkins University
Bryant Chen, Brex Inc.
Suchi Saria, Department of Computer Science, Johns Hopkins University & Bayesian Health
(Published May 19, 2022 in the Journal of Causal Inference)
Abstract
Recent interest in the external validity of prediction models (i.e., the problem of different train and test distributions, known as dataset shift) has produced many methods for finding predictive distributions that are invariant to dataset shifts and can be used for prediction in new, unseen environments. However, these methods consider different types of shifts and have been developed under disparate frameworks, making it difficult to theoretically analyze how solutions differ with respect to stability and accuracy. Taking a causal graphical view, we use a flexible graphical representation to express various types of dataset shifts. Given a known graph of the data generating process, we show that all invariant distributions correspond to a causal hierarchy of graphical operators which disable the edges in the graph that are responsible for the shifts. The hierarchy provides a common theoretical underpinning for understanding when and how stability to shifts can be achieved, and in what ways stable distributions can differ. We use it to establish conditions for minimax optimal performance across environments, and derive new algorithms that find optimal stable distributions. Using this new perspective, we empirically demonstrate that there is a tradeoff between minimax and average performance.
1 Introduction
Statistical and machine learning (ML) predictive models are being deployed in a number of high-impact applications, including healthcare [1], law enforcement [2], and criminal justice [3]. These safety-critical applications have a high cost of failure (model errors can lead to incorrect decisions that have a profound impact on the quality of human lives), which makes it important to ensure that systems being developed and deployed for these problems behave reliably (i.e., they perform to their specification). To do so, developers are forced to reason in advance about likely sources of failure and address them prior to deployment (i.e., during model training). A key source of failure is dataset shift [4, 5]: differences between the environment in which training data was collected and the environment in which the model will be deployed that manifest as changes in the data distribution. These differences can arise due to deploying a model at a new site from which data was unavailable during training, or due to natural variations that occur over time. Failing to account for these differences can result in model predictions with worse performance (i.e., expected loss) than anticipated.
Across a number of application domains, the recent COVID-19 pandemic has demonstrated ways in which dataset shifts can induce model failures. For example, the pandemic resulted in a drastic shift in the online retail and consumer packaged goods industries: during the onset of the pandemic, the predictive algorithms powering Amazon’s supply chain failed due to the sudden increased demand for household supplies (e.g., bottled water and paper products), resulting in unprecedented item shortages and delivery delays [6].
Beyond changes to customer behavior, dataset shift has been identified as a key challenge to ensuring reliability in safety-critical domains such as healthcare (see, e.g., [5] for examples of how dataset shift can occur in medical applications). Consider the following example: long-term (e.g., 3-year) patient mortality prediction models are used to help determine which patients may need long-term support after being discharged from the hospital. In one study, the authors trained a model to predict 3-year patient mortality from electronic health record (EHR) data at a single hospital. The authors found that, for 68% of laboratory tests, the timing of the laboratory test orders was more predictive of mortality for the model than the corresponding values of those tests [7]. As a result, the model learned predictive dependencies between the time of day when a lab test was ordered and patient mortality. These dependencies are brittle: they are highly variable across hospitals because the timing of lab tests is determined by hospital-specific policies and physician-specific preferences [8, 9]. Models which have learned these brittle dependencies can experience significant deterioration in performance and become unsafe to use [10].
As another example, consider [11], in which the authors trained a model to diagnose pneumonia from chest X-rays. While the model was found to be very accurate on new patients at the medical center where it was developed, this performance deteriorated significantly when applied at new, but similar, medical centers. Their analysis showed that the model had learned dependencies between stylistic features (e.g., text, orientation, coloring) present in the X-ray and pneumonia. These associations varied widely across hospitals because the choice of stylistic features depended on the X-ray equipment, hospital policies, and technician preferences.
These examples demonstrate that dataset shifts can arise from a variety of changes (i.e., interventions) in the underlying data generating process (DGP), such as changes in behavior (e.g., shifts in clinician treatment patterns) or changes in data acquisition (e.g., new X-ray machines and settings). Preventing failures due to these kinds of shifts requires a causal understanding of the parts of the data generating process that can shift, and learning models of stable distributions that are invariant to these shifts. (Footnote 1: We will refer to distributions that are stable to dataset shifts as “shift-stable” or simply “stable”.)
Given that dataset shifts can happen in nearly every domain where predictive models are used, and given how serious the consequences of failures due to these shifts can be, it is critical to be able to ensure the reliability of models under such shifts. That is, we need to understand how a model’s behavior will change under a given set of dataset shifts. What can be said about a model’s stability (i.e., are the model’s predictions still accurate after the shifts)? In the X-ray example, a model developer wants to know: what shifts can lead to model instability? Would shifts in color encoding schemes or upgrades to X-ray equipment lead to instability and deteriorate model accuracy? Has the model learned any dependencies (e.g., between pneumonia and choice of equipment) that will lead to this instability? For models trained using different algorithms, what guarantees do they give about stability to shifts in color schemes? How do the accuracies of these models differ under these dataset shifts? What guarantees can be made about a model’s worst-case performance under such shifts? Without a common footing for framing stability and dataset shift, it is difficult to begin to answer these questions and compare algorithms.
A common framing of dataset shift assumes that limited data from a clearly defined “target” environment or distribution of interest is used, along with more plentiful data from a “source” environment, to make inferences about the target environment. This framing allows an analyst to “reactively” adjust to target data samples. (Footnote 2: The term “reactive” to describe approaches which use (possibly unlabeled) data from the target domain during learning was coined in [12].) Reactive approaches to addressing dataset shift exist across multiple fields of study. Some examples include methods for domain adaptation in machine learning (for an overview, see [4]), generalizability and transportability in causal inference (e.g., [13, 14, 15, 16]), and sample selection bias in statistics and econometrics (e.g., [17, 18, 19, 20]). Because reactive approaches require data from the target environment, they cannot be used to learn, from source data alone, a model that will perform well in a new, unseen environment. In this case, one must instead use proactive learning approaches, which learn models that are stable to anticipated problematic shifts.
One common class of proactive learning methods is declarative in nature. These methods allow users to specify dependencies (i.e., causal relationships) between variables in the dataset that are likely to experience a dataset shift. That is, the data generating process is expected to differ between source and target environments due to an unknown intervention on the specified causal relationships. For instance, in the previously discussed X-ray example, a user might want to specify that the model should be stable to changes in scanner manufacturer, the choice of X-ray orientation (front-to-back vs back-to-front), or the color encoding scheme, since all of these are problematic shifts that are likely to occur. Example learning methods include approaches which find stable subsets of features [21, 12], approaches which learn models under hypothetically stable data generating processes [22, 23], and “counterfactual” approaches which compute counterfactual features [12, 10, 24] or perform data augmentation [25] to remove unstable dependencies. A key feature of declarative learning methods is that they can give guarantees about the stability of model predictions to changes in the specified shifts. When a user specifies that they desire invariance to the choice of X-ray orientation, then the declarative method will find a stable solution satisfying this specification. However, this requires domain expertise to be able to specify the likely problematic shifts to which stability is desired. A notable exception is the method of [23], which learns candidate shifts that occurred across datasets and allows users to choose which invariances to enforce. Additionally, while different declarative methods can guarantee stability, it is unknown what tradeoffs exist between methods with respect to accuracy. 
For example, models trained using stable feature subsets and models trained using counterfactuals can both be stable to, e.g., the choice of X-ray orientation, but it is currently unclear how their accuracies compare under shifts in X-ray orientation preferences. While both are stable, is one more accurate under shifts than the other?
A second class of proactive methods is imperative in nature. These methods take in datasets collected from multiple, heterogeneous environments and automatically extract invariant predictors from the data without user input [26, 27, 28, 29]. Examples include methods that compute feature sets [26] or representations [27, 28] that yield invariant predictors. In the X-ray example, imperative methods would require datasets collected from a large number of health centers that diversely represent the shifts that could be observed (e.g., the datasets differ in terms of scanner manufacturer, X-ray orientation, and encoding schemes). An advantage of such approaches is that they do not require domain expertise to determine invariances. These methods often provide theoretical guarantees about minimax optimal performance (i.e., that they have the smallest worst-case error) across the input distributions. Thus, using an imperative approach, a model developer can guarantee good worst-case performance at new hospitals that “look like” a mixture of the training hospitals. However, these methods generally do not provide guarantees about stability to a set of specified shifts: we do not know the ways in which the datasets differ (or by how much), so we cannot determine whether the model is stable to shifts in scanner manufacturers, X-ray orientations, or encoding schemes.
The difficulties of understanding model behavior within and across each thread of work prevent rigorous analysis of the reliability of models under dataset shifts. In reality, we are presented with a prediction problem in which the data has been collected and generated under some DGP. Dataset shifts can then lead to changes to arbitrary pieces of the DGP. Thus, there is a need for a framework that enables us to answer the fundamental questions about model behavior under changes to the DGP. This would provide common ground to compare algorithms which address dataset shift, and to generate generalizable insights from methods which address particular instances of shifts.
1.1 Contributions
In this paper, we provide a unifying framework for specifying dataset shifts that can occur, analyzing model stability to these shifts, and determining conditions for achieving the lowest worst-case error (i.e., minimax optimal performance) across environments produced by these shifts. This provides common ground so that we can begin to answer fundamental questions such as: To what dataset shifts are the model’s predictions stable vs unstable? Has the model learned a predictive relationship that is stable to a set of prespecified shifts of interest? How will the model’s performance be affected by these shifts? For models trained using different methods, what guarantees do they provide about stability and accuracy?
The framework centers around two key requirements. First, a known causal graphical representation of the environment data generating process, which includes specifications of what can shift. These specifications take the form of marked unstable edges in the graph, which represent causal dependencies between variables that can shift across environments. We consider arbitrary shifts (i.e., interventions) to causal mechanisms in the graph, as opposed to, e.g., constraining shifted distributions to lie within bounded norm-balls around the training data. This specification encompasses commonly studied instances of dataset shift (such as label shifts, covariate shifts, and conditional shifts), but also handles the more general, unnamed shifts that we expect to see in practice. Second, we restrict our analysis to methods whose target distribution can be expressed graphically. This includes algorithms that operate directly on the observed variables rather than learning intermediate feature representations. For models which do not induce a graphical representation, we discuss how our results might be used to probe these models for their stability properties and discuss opportunities for future work to bridge this gap.
Our main contribution is the development of a causal hierarchy of stable distributions, in which distributions at higher levels of the hierarchy guarantee lower worst-case error. The levels of the hierarchy provide insight into how stability can be achieved: the levels correspond to three operators on the graphical representation of the environment, which modify the graph to produce stable distributions, the learning targets for stable learning algorithms (Definitions 6, 7, 8). We further show that the operators have different stability properties: higher level operators more precisely remove unstable edges from the graph when producing stable distributions (Corollary 2). Using this graphical characterization of stability, we provide a simple graphical criterion for determining if a distribution is stable to a set of prespecified shifts: a distribution is stable if it modifies the graph to remove the corresponding unstable edges (Theorem 1). We then address questions about the accuracy of different stable solutions by showing how the hierarchy provides a causal characterization of the minimax optimal predictor: the predictor which achieves the lowest worst-case error across shifted environments (Proposition 7). Surprisingly, we find that frequently studied intervention-invariant solutions generally do not achieve this minimax optimal performance. Finally, we demonstrate through a series of semisynthetic experiments that there is a tradeoff between minimax and average performance. Through these contributions, we provide a common theoretical underpinning for understanding model behavior under dataset shifts: when and how stability to shifts can be achieved, in what ways stable distributions can differ, and how to achieve the lowest worst-case error across shifted environments.
2 Related Work
While the focus of this work is on statistical and machine learning models, we briefly discuss concepts related to dataset shift that have been studied in other fields. In particular, external validity, or the ability of experimental findings to generalize beyond a single study, has long been an important goal in the social and medical sciences [30]. For example, practitioners (such as clinicians) who want to assess the results of a randomized trial must consider how the results of the trial relate to their target population of interest. This need has led to much discussion and work on assessing the generalizability of randomized trials (see, e.g., [31, 32, 33]). More recently, methodological work in causal inference has focused on transportability (see [16] for a review). For example, researchers have developed causal graphical methods for determining when and how experimental findings can be transported from one population or setting to a new one (see, e.g., [13, 34, 15, 35]). Generalizability is also of importance in economics research [36], with much methodological work focusing on the problem of sample selection bias resulting from non-random data collection (see, e.g., [17, 18, 19]). Selection bias leads to systematic differences between the observed data and the general population, and thus is related to problems of external validity, which consider different environments or populations.
Returning to the focus of this work, the problem of differing train and test distributions in predictive modeling is known as dataset shift in machine learning [4]. Classical approaches, such as domain adaptation, assume access to unlabeled samples from the target distribution, which they use to reweight training data during learning or to extract invariant feature representations (e.g., [37, 38, 39, 40]). More recently, work on statistical transportability has produced sound and complete algorithms for determining how data from multiple domains can be synthesized to compute a predictive conditional distribution in the target environment of interest [13, 41]. These methods reactively adjust to target data. In many practical scenarios, however, it is not possible to get samples from all possible target distributions. Such settings instead require proactive methods that make assumptions about the set of possible target environments in order to learn a model from source data that can be applied elsewhere [12]. Work on proactive methods has primarily focused on either bounded or unbounded shifts. Bounded shifts have been studied through the lens of distributional robustness, often assuming shifts within a finite-radius divergence ball (e.g., [42, 43]). [44] consider robustness to bounded-magnitude interventions in latent style features. [45] consider bounded and unbounded mean-shifts in mechanisms (in which the means of certain variables vary by environment). [46] build on this to allow for stochastic mean-shifts. In this paper, we focus on unbounded shifts: in safety-critical domains the cost of failure is high and it can be difficult to accurately specify bounds. (Footnote 3: Though the focus of this paper is on arbitrary shifts, we note that a “stability accuracy tradeoff” has also been observed in works on bounded distributional robustness. See [47] for an example.)
Many proactive methods use datasets from multiple source environments to train invariant models to predict in new, unseen environments. [48] propose a kernel method to find a data transformation minimizing the difference between feature distributions across environments while preserving the predictive relationship. Invariant risk minimization (IRM) learns a representation such that the optimal classifier is invariant across the input environments [27, 49]. Similar works learn invariant predictors by seeking derivative invariance across environments [28, 29]. These methods learn representations, so the resulting predictors do not induce a graphical representation. Despite this, we discuss how our graphical results can be applied to probe these methods for their stability properties. IRM establishes its ability to generalize to new environments using the framework of invariant causal prediction (ICP) [50]. Under ICP, the minimax optimal performance of a causally invariant predictor is tied to assuming that shifts occur in all variables except the target prediction variable. In reality, however, specific shifts occur and we want to determine which ones to protect against (and how). In this work, we address this by starting from the data generating process and deriving stable solutions to prespecified sets of shifts.
Other proactive methods make explicit use of connections to causality. Related to ICP, [26] use multiple source environment datasets to find a feature subset that yields an invariant conditional distribution. This has also been extended to the reactive case in which unlabeled target data is available (see also [21]). For discrete variable settings in which data from only one source environment are available and there is no unobserved confounding, covariate balancing techniques have been used to determine the causal features that yield a stable conditional distribution [51, 52]. Other causal methods assume explicit knowledge of the graph representing the DGP instead of requiring multiple datasets. Explicitly assuming no unobserved confounders, [10] protect against shifts in continuous-time longitudinal settings by predicting counterfactual outcomes. [12] find a stable feature set that can include counterfactual variables, assuming linear mechanisms. [24] regularize towards counterfactual invariance so that a model learns to predict only using causal associations. Other works consider using counterfactuals generated by human annotation [53], data augmentation (see [54, 25] for discussion of the relationship between causality and data augmentation), or active learning [55] to improve model robustness. [22] use selection diagrams [13] to identify mechanisms that can shift and find a stable interventional distribution to use for prediction. More recently, end-to-end approaches have been developed which relax the need for the graph to be known beforehand, instead learning it from data [56, 23]. In this paper we provide a common ground for understanding the different types of stable solutions and for finding stable predictors with the best worst-case performance.
3 A Hierarchy of Shift-Stable Distributions
In this section we present a causal hierarchy of stable distributions that are invariant to different types of dataset shifts. We begin by introducing necessary background on causal graphs (Section 3.1). Next, we introduce a general graphical representation for specifying dataset shifts that can occur (Section 3.2), and use this representation to give a simple graphical criterion for determining if a distribution is stable to a set of prespecified shifts. Then, we present the hierarchy and show that its levels correspond to three operators on the graphical representation, which modify the graph to produce stable distributions (Section 3.3). This allows different stable distributions to be compared in terms of how they modify the graphical representation. We further show that the hierarchy is nested and thus has implications for the existence of stable distributions (Section 3.3.2). Proofs of results are in Appendix B.
3.1 Preliminaries
3.1.1 Notation
Throughout the paper, sets of variables are denoted by bold capital letters while their particular assignments are denoted by bold lowercase letters. We will consider graphs with directed or bidirected edges (e.g., $A \rightarrow B$, $A \leftrightarrow B$). Acyclic will be taken to mean that there exists no purely directed cycle. The sets of parents, children, ancestors, and descendants in a graph $\mathcal{G}$ will be denoted by $\mathrm{pa}_{\mathcal{G}}(\cdot)$, $\mathrm{ch}_{\mathcal{G}}(\cdot)$, $\mathrm{an}_{\mathcal{G}}(\cdot)$, and $\mathrm{de}_{\mathcal{G}}(\cdot)$, respectively (subscript omitted when obvious from context). For an edge $e$, $\mathrm{He}(e)$ and $\mathrm{Ta}(e)$ will refer to the head and tail of the edge, respectively.
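To make the notation concrete, the following is a minimal Python sketch of an ADMG with the parent, child, ancestor, and descendant operators and the "no purely directed cycle" check. The class `ADMG` and the helpers `head` and `tail` are our own illustrative names, not code from the paper, and the sketch assumes the common convention that a vertex counts among its own ancestors and descendants.

```python
from dataclasses import dataclass, field


@dataclass
class ADMG:
    """Acyclic directed mixed graph over observed vertices.

    directed: set of (tail, head) pairs, e.g. ("A", "B") for A -> B.
    bidirected: set of frozensets, e.g. frozenset({"A", "C"}) for A <-> C.
    """
    vertices: set
    directed: set = field(default_factory=set)
    bidirected: set = field(default_factory=set)

    def pa(self, v):
        return {t for (t, h) in self.directed if h == v}

    def ch(self, v):
        return {h for (t, h) in self.directed if t == v}

    def _reach(self, v, step):
        # Vertices reachable from v by repeatedly applying `step` (v included).
        seen, stack = {v}, [v]
        while stack:
            for w in step(stack.pop()):
                if w not in seen:
                    seen.add(w)
                    stack.append(w)
        return seen

    def an(self, v):
        return self._reach(v, self.pa)

    def de(self, v):
        return self._reach(v, self.ch)

    def is_acyclic(self):
        # A purely directed cycle exists iff some edge t -> h has t in de(h);
        # bidirected edges are ignored, matching the definition of acyclic.
        return not any(t in self.de(h) for (t, h) in self.directed)


def tail(edge):
    return edge[0]   # Ta(e) for a directed edge stored as (tail, head)


def head(edge):
    return edge[1]   # He(e)


g = ADMG({"A", "B", "C"}, {("A", "B"), ("B", "C")}, {frozenset({"A", "C"})})
print(g.an("C"), g.is_acyclic())
```

Bidirected edges are stored as unordered pairs because $A \leftrightarrow B$ and $B \leftrightarrow A$ denote the same confounding relation.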
3.1.2 Structural Causal Models
We represent the data generating process (DGP) underlying a prediction problem using acyclic directed mixed graphs (ADMGs), $\mathcal{G} = (\mathbf{V}, \mathbf{E})$, each of which consists of a set of vertices $\mathbf{V}$ corresponding to observed variables and a set $\mathbf{E}$ of directed and bidirected edges such that there are no directed cycles. Directed edges indicate direct causal relations, while bidirected edges indicate the presence of an unobserved confounder (common cause) of the two variables. ADMGs are able to represent directed acyclic graph (DAG) models that contain latent variables, and the latent variables do not need to be known. For example, if an ADMG has an edge $A \leftrightarrow B$, then there is some, possibly unknown, mechanism by which $A$ and $B$ are confounded (i.e., there exists some unobserved variable $U$ such that $A \leftarrow U \rightarrow B$, but $U$ itself may be unknown). Thus, using ADMGs, a modeler can reason about the effects of unobserved confounding even if the mechanism or the confounder itself is unknown.
The graph defines a Structural Causal Model (SCM) [57] in which each variable $V_i \in \mathbf{V}$ is generated as a function of its parents and a variable-specific exogenous noise variable $U_i$: $V_i = f_i(\mathrm{pa}(V_i), U_i)$. The prediction problem associated with the graph consists of a target output variable $T \in \mathbf{V}$ and the remaining observed variables $\mathbf{X} = \mathbf{V} \setminus \{T\}$ as input features.
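A minimal sketch of how samples are generated from an SCM: each variable is computed, in topological order, from its already-generated parents and its own exogenous noise. The helper `sample_scm`, the three-variable chain, and the standard normal noise are all assumptions made for illustration.

```python
import random


def sample_scm(mechanisms, order, rng):
    """Draw one joint sample from an SCM.

    mechanisms: dict mapping each variable V_i to a callable
        f_i(values, noise) that reads V_i's parents from `values`.
    order: a topological order of the graph, so parents are generated first.
    """
    values = {}
    for v in order:
        noise = rng.gauss(0.0, 1.0)          # exogenous U_i, assumed N(0, 1) here
        values[v] = mechanisms[v](values, noise)
    return values


# Hypothetical three-variable chain X1 -> T -> X2 with target T.
MECHANISMS = {
    "X1": lambda vals, u: u,                    # exogenous root
    "T": lambda vals, u: 2.0 * vals["X1"],      # noise-free for easy checking
    "X2": lambda vals, u: vals["T"] + 0.1 * u,
}
ORDER = ["X1", "T", "X2"]

sample = sample_scm(MECHANISMS, ORDER, random.Random(0))
target = sample["T"]                                         # T
features = {k: v for k, v in sample.items() if k != "T"}     # X = V \ {T}
```

Making the mechanism for `T` noise-free is only a convenience for checking the wiring; in general every $f_i$ may depend on its noise term.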
As an example, consider the DAG in Fig 1a. This DAG corresponds to a simple version of the pneumonia example in [11]. The goal is to diagnose pneumonia from chest X-rays and stylistic features (i.e., orientation and coloring) of the image. A latent variable represents the hospital department the patient visited. The corresponding ADMG is shown in Fig 1b, in which the unobserved confounder has been replaced by a bidirected edge.
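Replacing a latent variable by bidirected edges, as in going from Fig 1a to Fig 1b, is an instance of latent projection. The sketch below handles only the simplest case: a latent variable with no parents whose children are all observed (general latent projection also handles chains of latent variables). The function name and the variables `U`, `A`, `B`, `C` are hypothetical and do not reproduce the figure.

```python
from itertools import combinations


def project_root_latent(directed, latent):
    """Replace a parentless latent variable by bidirected edges (simplified
    latent projection). Assumes all children of `latent` are observed."""
    if any(h == latent for (_, h) in directed):
        raise ValueError("sketch only handles latent variables with no parents")
    children = sorted(h for (t, h) in directed if t == latent)
    kept = {(t, h) for (t, h) in directed if t != latent}
    # Each pair of children shares the unobserved common cause: A <-> B.
    bidirected = {frozenset(pair) for pair in combinations(children, 2)}
    return kept, bidirected


dag = {("U", "A"), ("U", "B"), ("A", "C")}   # hypothetical DAG with latent U
directed, bidirected = project_root_latent(dag, "U")
print(directed, bidirected)
```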
3.2 Stability and Types of Dataset Shifts
In this section we introduce types of dataset shifts that have been previously studied. Then, we graphically characterize instability in terms of edges in the graph of the data generating process. This will be key to the development of the hierarchy in Section 3.3.
To define the types of dataset shifts, assume that there is a set of environments in which a prediction problem maps to the same graph structure $\mathcal{G}$. However, each environment is a different instantiation of that graph in which certain mechanisms differ. Thus, the factorization of the data distribution is the same in each environment, but the terms in the factorization corresponding to shifted mechanisms vary across environments. As an example, consider again the graph in Fig 1a. In the pneumonia example, each department has its own protocols and equipment, so style preferences vary across departments. Here, a shift in the mechanism generating the stylistic features leads to differences across environments.
Definition 1 (Mechanism shift).
A shift in the mechanism generating a variable $V$ corresponds to arbitrary changes in the distribution $P(V \mid \mathrm{pa}(V))$.
Causal mechanism shifts produce many previously studied instances of dataset shift. Consider, for example, label shift, a well-studied mechanism shift in which the distribution of the features given the label, $P(\mathbf{X} \mid T)$, is stable, but $P(T)$ varies across environments. Label shift corresponds to a causal graph in which the features are caused by the label and the mechanism that generates $T$ varies across environments, resulting in changes in the prevalence of the label (see, e.g., [58, 38]).
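The sketch below computes, by exact enumeration, the joint distribution of a binary label $T$ and a binary feature $X$ in the label-shift graph $T \rightarrow X$ for two environments that share $P(X \mid T)$ but differ in the prevalence $P(T)$. All probabilities and helper names are hypothetical. It also shows a direct consequence of Bayes' rule: $P(T \mid X)$ changes across the two environments even though $P(X \mid T)$ does not.

```python
def label_shift_joint(p_t, p_x_given_t):
    """Exact joint P(T=t, X=x) for binary T -> X.

    p_t: prevalence P(T=1), the mechanism that shifts.
    p_x_given_t: dict t -> P(X=1 | T=t), shared across environments.
    """
    return {
        (t, x): (p_t if t else 1 - p_t)
        * (p_x_given_t[t] if x else 1 - p_x_given_t[t])
        for t in (0, 1)
        for x in (0, 1)
    }


def p_x1_given_t(joint, t):
    return joint[(t, 1)] / (joint[(t, 0)] + joint[(t, 1)])


def p_t1_given_x(joint, x):
    return joint[(1, x)] / (joint[(0, x)] + joint[(1, x)])


P_X_GIVEN_T = {0: 0.2, 1: 0.9}                   # stable mechanism P(X=1 | T)
env_a = label_shift_joint(0.1, P_X_GIVEN_T)      # P(T=1) = 0.1
env_b = label_shift_joint(0.5, P_X_GIVEN_T)      # P(T=1) = 0.5 after label shift

print(p_x1_given_t(env_a, 1), p_x1_given_t(env_b, 1))   # equal up to rounding
print(p_t1_given_x(env_a, 1), p_t1_given_x(env_b, 1))   # about 0.33 vs 0.82
```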
More generally, mechanism shifts are the most common and general type of shift considered in prior work on proactive approaches for addressing dataset shift [50, 26, 21, 51, 22]. However, special cases of mechanism shifts have also been studied. For example, [59, 45, 46] considered parametric mean-shifted mechanisms, in which the means of variables in linear SCMs can vary by environment.
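Definition 2 (Mean-shifted mechanisms).
A mean-shift in the mechanism generating a variable $V$ corresponds to an environment-specific change in the intercept of its linear structural equation $V = \alpha_e + \sum_{W \in \mathrm{pa}(V)} \beta_W W + U_V$, where the intercept $\alpha_e$ depends on the environment $e$. Nonlinear generalizations are possible.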
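Under a mean-shift only the intercept moves between environments, so the coefficient on the parent is unchanged. The simulation below illustrates this with a single parent $X \rightarrow V$, fitting an ordinary least squares line in each environment with `numpy.polyfit`; the coefficient `beta = 1.5`, the intercepts, the sample size, and the helper names are arbitrary assumptions.

```python
import numpy as np


def simulate_env(alpha_e, beta=1.5, n=20_000, seed=0):
    """Sample (X, V) from a hypothetical linear SCM X -> V whose mechanism
    for V is mean-shifted: V = alpha_e + beta * X + U_V."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)                         # parent of V
    v = alpha_e + beta * x + rng.normal(size=n)    # environment-specific intercept
    return x, v


def fit_line(x, v):
    """Ordinary least squares fit v ≈ intercept + slope * x."""
    slope, intercept = np.polyfit(x, v, deg=1)     # highest degree first
    return intercept, slope


for alpha_e, seed in [(0.0, 1), (3.0, 2)]:
    intercept, slope = fit_line(*simulate_env(alpha_e, seed=seed))
    print(f"alpha_e={alpha_e}: intercept={intercept:.2f}, slope={slope:.2f}")
```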
Another special case considered by [12] is edge-strength shifts, in which the relationship encoded by a subset of edges into a variable may vary. Variation along an individual edge corresponds to the natural direct effect [57, Chapter 4]. Thus, an edge-strength shift is a mechanism shift which changes the natural direct effect associated with the edge.
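Definition 3 (Edge-strength shift).
An edge-strength shift in edge $A \rightarrow B$ corresponds to a change in the natural direct effect: for $a, a'$ we have that $\mathbb{E}[B_{a', \mathbf{Z}_a}] - \mathbb{E}[B_a]$ changes, where $\mathbf{Z} = \mathrm{pa}(B) \setminus \{A\}$, $B_a$ is the counterfactual value of $B$ had $A$ been $a$, and $B_{a', \mathbf{Z}_a}$ is the counterfactual value of $B$ had $A$ been $a'$ and $\mathbf{Z}$ had been counterfactually generated under $a$.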
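In a linear SCM, the natural direct effect along $A \rightarrow B$ equals the coefficient on $A$ in $B$'s equation times $(a' - a)$, so an edge-strength shift amounts to a change in that coefficient. The sketch below checks this numerically for a hypothetical graph $A \rightarrow Z \rightarrow B$, $A \rightarrow B$, evaluating both counterfactuals with the same exogenous noise; the coefficients `b`, `c`, `d` are illustrative assumptions. Changing `b` or `d` would alter the effect transmitted through $Z$ but not this direct effect.

```python
import numpy as np


def natural_direct_effect(c, b=0.8, d=1.2, a=0.0, a_prime=1.0, n=10_000, seed=0):
    """Monte Carlo estimate of E[B_{a', Z_a}] - E[B_a] for the edge A -> B in a
    hypothetical linear SCM with Z = b*A + U_Z and B = c*A + d*Z + U_B."""
    rng = np.random.default_rng(seed)
    u_z = rng.normal(size=n)
    u_b = rng.normal(size=n)
    z_a = b * a + u_z                       # Z generated under A = a
    b_a = c * a + d * z_a + u_b             # counterfactual B_a
    b_ap_za = c * a_prime + d * z_a + u_b   # counterfactual B_{a', Z_a}
    return float(np.mean(b_ap_za - b_a))    # shared noise across both worlds


# An edge-strength shift changes c (the strength of A -> B) between environments.
for c in (2.0, 0.5):
    print(c, round(natural_direct_effect(c), 6))
```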
Key Result: All of these shifts can be expressed in terms of edges. First, edge-strength shifts directly correspond to particular edges. Next, since the mechanism generating a variable $V$ is encoded graphically by all of the edges into $V$, a shift in the mechanism of $V$ can be represented by marking all edges into $V$ as unstable. For mechanism shifts of an exogenous variable $V$ with no parents in the graph, one might imagine adding an explicit mechanism variable $M_V$ to the graph and considering the edge $M_V \rightarrow V$ to be unstable. Finally, mean-shifts correspond to an edge $A \rightarrow V$ from an environment-specific variable $A$ (also referred to as an “anchor”; see [45] for a discussion of anchor variables) that shifts the mean of $V$ in each environment. Thus, mean-shifts are an example of a specific type of edge shift. While the edge representation of shifts is more general, it cannot differentiate between specific instances of shifts (e.g., a mean-shift and a shift in the natural direct effect of an “anchor” variable have the same graphical representation).
We denote the set of unstable edges that can vary across environments by $\mathbf{E}_u \subseteq \mathbf{E}$, where $\mathbf{E}$ is the set of edges in $\mathcal{G}$. Graphically, unstable edges will be colored.
Definition 4** (Unstable Edge).**
An edge is said to be unstable if it is the target of an edge-strength shift or a mechanism shift.
The concept of unstable edges provides a flexible and extensible way to graphically represent dataset shifts.
3.2.1 Extensions to New Types of Shifts
We note that defining shifts in terms of unstable edges makes it possible to tackle new problems determined by shifts in sets or paths of unstable edges. For example, DAGs can be used to represent non-i.i.d. network data in which certain edges represent interference between units (e.g., friendship ties in social networks) [60, 61]. Thus, one can define dataset shifts pertaining to networks (e.g., deleting, adding, or changing the strength of friendships). Similarly, dataset shifts due to changing path-specific effects [62] are another interesting avenue for future exploration (e.g., reductions in side effects of a drug while maintaining its efficacy).
While the focus of this paper is on predictive modeling, we note that the shifts we describe have the opportunity to interact with causal inference work on transportability and the “data fusion problem” [15]. There has been much methodological work on causal graphical methods for transporting causal effect estimates (see, e.g., [13, 63, 64, 65, 66]). These works have primarily considered transporting causal effects under mechanism shifts. The proposed edge-based definitions of shifts can help frame transportability problems under new types of edge-based shifts such as those motivated above.
3.2.2 Stable Distributions
We can now define stable distributions, which are the target sought by methods addressing instability due to shifts. We will refer to a model of a stable distribution as a stable predictor.
Definition 5** (Stable Distribution).**
Consider a graph $\mathcal{G}$ with unstable edges $\mathbf{E}_u$ defining a set of environments $\mathcal{E}$ (different data distributions that factorize with respect to $\mathcal{G}$ that have been generated by differences in the mechanisms associated with $\mathbf{E}_u$). A distribution $P(Y \mid \mathbf{Z})$ is said to be stable if for any two environments $E_i, E_j \in \mathcal{E}$ that are instantiations of $\mathcal{G}$, $P_{E_i}(Y \mid \mathbf{Z}) = P_{E_j}(Y \mid \mathbf{Z})$ holds. The distribution is not restricted to being an observational distribution.
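Definition 5 can be checked directly when the environment distributions are known. The toy sketch below uses binary variables, and all of its probabilities are made up. It builds two environments that differ only in the mechanism of X given an unobserved common cause C of X and Y, then compares conditionals across them. P(Y | X) changes between the environments. P(Y | X, C) does not.

```python
from itertools import product


def joint(p_c1, p_x1_given_c, p_y1_given_c):
    """Joint P(c, x, y) for binary variables with C -> X and C -> Y (C is a latent common cause)."""
    table = {}
    for c, x, y in product((0, 1), repeat=3):
        pc = p_c1 if c else 1 - p_c1
        px = p_x1_given_c[c] if x else 1 - p_x1_given_c[c]
        py = p_y1_given_c[c] if y else 1 - p_y1_given_c[c]
        table[(c, x, y)] = pc * px * py
    return table


def prob_y1_given(table, **evidence):
    """P(Y = 1 | evidence), where evidence may fix c and/or x, e.g. prob_y1_given(t, x=1)."""
    position = {"c": 0, "x": 1}

    def matches(key):
        return all(key[position[name]] == value for name, value in evidence.items())

    num = sum(p for key, p in table.items() if matches(key) and key[2] == 1)
    den = sum(p for key, p in table.items() if matches(key))
    return num / den


p_y1 = {0: 0.2, 1: 0.8}                     # stable mechanism P(Y=1 | C)
env_a = joint(0.5, {0: 0.1, 1: 0.9}, p_y1)  # environment A's P(X=1 | C)
env_b = joint(0.5, {0: 0.6, 1: 0.3}, p_y1)  # environment B's P(X=1 | C)

for name, env in [("A", env_a), ("B", env_b)]:
    print(name, round(prob_y1_given(env, x=1), 3), round(prob_y1_given(env, x=1, c=1), 3))
# A 0.74 0.8
# B 0.4 0.8
```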
Having established a common graphical representation for arbitrary shifts of various types, we provide a graphical definition of stable distributions. First, define an active unstable path to be an active path (as determined by the rules of d-separation [67]) that contains at least one unstable edge. Key Result: The non-existence of active unstable paths is a graphical criterion for determining a distribution’s stability.
Theorem 1**.**
*$P(Y \mid \mathbf{Z})$ is stable if there is no active unstable path from $\mathbf{Z}$ to $Y$ in $\mathcal{G}$ and the mechanism generating $Y$ is stable.*
Intuitively, Theorem 1 means that a stable distribution cannot capture a statistical association that relies on the information encoded by an unstable edge. In the pneumonia example of Fig 1a, the edge which denotes the X-ray style mechanism was determined to be unstable. Because is unobserved, a model of will learn an association between and through . Thus, contains an active unstable path, and this distribution is unstable to shifts in the style mechanism. This means that is different in each environment. By contrast, if were observed and we could condition on it, then is stable to shifts in the style mechanism because all paths containing the unstable edge are blocked by . Thus, is invariant across environments.
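The criterion in Theorem 1 can be checked by brute force on small graphs. The sketch below assumes a DAG in which latent variables are explicit nodes, and it does not handle bidirected edges. It enumerates simple paths and applies the d-separation rules for activity. This is a literal, exponential-time reading of the sufficient condition, not an efficient or official algorithm. The toy graph, with a hypothetical unobserved common cause C, mirrors the structure of the argument above. It is not the actual graph of Fig 1a.

```python
def descendants(edges, v):
    """All nodes reachable from v along directed edges (including v itself)."""
    out, stack = {v}, [v]
    while stack:
        u = stack.pop()
        for (a, b) in edges:
            if a == u and b not in out:
                out.add(b)
                stack.append(b)
    return out


def simple_paths(edges, start, goal):
    """Enumerate simple paths between start and goal in the skeleton of the graph."""
    nbrs = {}
    for a, b in edges:
        nbrs.setdefault(a, set()).add(b)
        nbrs.setdefault(b, set()).add(a)
    stack = [[start]]
    while stack:
        path = stack.pop()
        if path[-1] == goal:
            yield path
            continue
        for n in nbrs.get(path[-1], ()):
            if n not in path:
                stack.append(path + [n])


def is_active(edges, path, cond):
    """d-separation rules: a collider needs itself or a descendant in cond; a non-collider must not be in cond."""
    for prev, node, nxt in zip(path, path[1:], path[2:]):
        collider = (prev, node) in edges and (nxt, node) in edges
        if collider and not (descendants(edges, node) & cond):
            return False
        if not collider and node in cond:
            return False
    return True


def path_edges(edges, path):
    """The directed edges traversed by a path, in their original orientation."""
    return {(a, b) if (a, b) in edges else (b, a) for a, b in zip(path, path[1:])}


def stable_by_theorem1(edges, unstable, y, z):
    """Sufficient check: no active unstable path from Z to Y, and no unstable edge into Y."""
    if any(b == y for (_, b) in unstable):
        return False
    for start in z:
        for path in simple_paths(edges, start, y):
            if path_edges(edges, path) & unstable and is_active(edges, path, set(z)):
                return False
    return True


edges = {("C", "X"), ("C", "Y")}
unstable = {("C", "X")}
print(stable_by_theorem1(edges, unstable, "Y", {"X"}))       # False: X <- C -> Y is active
print(stable_by_theorem1(edges, unstable, "Y", {"X", "C"}))  # True: conditioning on C blocks it
```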
In the next section we use this edge-based graphical characterization to show that all stable distributions, including those found by existing methods, can be categorized into three levels. Thus, this hierarchy defines the ways in which it is possible to achieve stability to shifts.
3.3 Hierarchy of Shift-Stable Distributions
Many works seek stable distributions in order to make predictions that are stable or invariant to dataset shifts. However, because these methods have been developed in isolation, there has been little discussion of whether these methods find the same stable distributions, or how these distributions differ from one another. As a main contribution of this paper, we show in this section that there exists a hierarchy of stable distributions, in which stable distributions at different levels have distinct graphical properties. Thus, the development of this hierarchy provides a common theoretical underpinning for understanding when and how stability to shifts can be achieved, and in what ways stable distributions can differ. In this section we will define the levels of the hierarchy and show that they correspond to different operators that can remove unstable edges from the graph. Then, in the next section we will further study how differences between levels of the hierarchy affect worst-case performance across environments.
3.3.1 Levels of the Hierarchy
Armed with the graphical characterization of stability from the previous section, we now introduce a hierarchy of three categories of stable distributions. The levels of the hierarchy are: 1) observational conditionals, 2) conditional interventionals, and 3) counterfactuals. This hierarchy is related to the hierarchy of causal queries, which defines three levels of causal study questions an investigator can have: association, intervention, and counterfactuals [57]. Also relatedly, in [68] the authors connect the identification of different types of causal effects to a hierarchy of graphical interventions: node, edge, and path interventions. While these works develop hierarchies that relate different types of causal queries and effects, in this paper we develop a hierarchy of shift-stable distributions which connects different types of stable distributions to interventions which remove unstable parts of the data generating process from the underlying graph of the DGP.
Each level of the hierarchy of stable distributions corresponds to graphical operators which differ in the precision with which they can remove edges in the graph (Corollary 2, main result of this subsection). Using the graph in Fig 2(a) as a common example, we discuss each level in detail below. Note that in Fig 2(a), the goal is to predict from , and the edge is unstable.
Definition 6** (Stable Level 1 Distribution).**
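Let $\mathcal{G}$ be an ADMG with unstable edges $\mathbf{E}_u$ defining a set of environments $\mathcal{E}$. A stable level 1 distribution is an observational conditional distribution of the form $P(Y \mid \mathbf{Z})$ such that, for any two environments $E_i, E_j \in \mathcal{E}$, $P_{E_i}(Y \mid \mathbf{Z}) = P_{E_j}(Y \mid \mathbf{Z})$ holds.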
Level 1: Methods at level 1 of the hierarchy seek invariant conditional distributions of the form $P(Y \mid \mathbf{Z})$ that use a subset $\mathbf{Z}$ of the observed features for prediction [26, 21]. These distributions only have conditioning (i.e., the standard rules of d-separation) as a tool for disabling unstable edges. For this reason, the conditioning operator is coarse and removes large pieces of the graph. Consider Figure 2a, in which the maximal stable level 1 distribution is , since conditioning on either or activates the path through the unstable (orange) edge. The conditioning operator disables all paths from and to to produce Fig 2b. While the operator successfully removes the unstable edge, many stable edges were removed as well.
Definition 7** (Stable Level 2 Distribution).**
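Let $\mathcal{G}$ be an ADMG with unstable edges $\mathbf{E}_u$ defining a set of environments $\mathcal{E}$. A stable level 2 distribution is a conditional interventional distribution of the form $P(Y \mid do(\mathbf{W}), \mathbf{Z})$ such that, for any two environments $E_i, E_j \in \mathcal{E}$, $P_{E_i}(Y \mid do(\mathbf{W}), \mathbf{Z}) = P_{E_j}(Y \mid do(\mathbf{W}), \mathbf{Z})$ holds.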
Level 2: Methods at level 2 [22] find conditional interventional distributions [57] of the form $P(Y \mid do(\mathbf{W}), \mathbf{Z})$. In addition to conditioning, level 2 distributions use the $do$ operator, which deletes all edges into an intervened variable [57]. Fig 2c shows the result of applied to Fig 2a: the edges into (including the unstable edge) are removed. Thus, is stable and retains statistical information along stable paths from and that the level 1 distribution did not. However, the stable edge was also removed by the $do$ operator. Intervening interacts with the factorization (according to the graph) of the joint distribution of the observed variables by deleting the terms corresponding to mechanisms of the intervened variable: in Fig 2a. The term corresponds to the stable information we retain by intervening that we could not capture by conditioning.
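The effect of deleting the mechanism term can be seen in a small binary example. The sketch below assumes a hypothetical graph C → A → Y with C → Y and an unstable mechanism for A; all probabilities are made up. The observational conditional P(Y | A) keeps the term P(A | C) and therefore changes when that mechanism changes. The interventional P(Y | do(A)), computed with the truncated factorization, drops that term and stays fixed.

```python
def p_y1_conditional(a, p_c1, p_a1_given_c, p_y1_given_ac):
    """Observational P(Y=1 | A=a): the mechanism term P(a | c) stays in the factorization."""
    num = den = 0.0
    for c in (0, 1):
        pc = p_c1 if c else 1 - p_c1
        pa = p_a1_given_c[c] if a else 1 - p_a1_given_c[c]
        num += pc * pa * p_y1_given_ac[(a, c)]
        den += pc * pa
    return num / den


def p_y1_do(a, p_c1, p_y1_given_ac):
    """Interventional P(Y=1 | do(A=a)): the term P(a | c) is deleted (truncated factorization)."""
    return sum((p_c1 if c else 1 - p_c1) * p_y1_given_ac[(a, c)] for c in (0, 1))


p_y = {(0, 0): 0.1, (0, 1): 0.5, (1, 0): 0.3, (1, 1): 0.9}  # stable mechanism P(Y=1 | A, C)
for p_a in ({0: 0.2, 1: 0.8}, {0: 0.7, 1: 0.4}):           # two hypothetical A-mechanisms
    print(round(p_y1_conditional(1, 0.5, p_a, p_y), 3), round(p_y1_do(1, 0.5, p_y), 3))
```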
Definition 8** (Stable Level 3 Distribution).**
Let be an ADMG with unstable edges defining a set of environments , and let denote the counterfactual value of had been set to for variables . A stable level 3 distribution is a counterfactual distribution of the form such that, for any two environments , holds.
Level 3: Finally, level 3 methods [12, 24] seek counterfactual distributions, which allow us to consider conflicting values of a variable, or to replace a mechanism with a new one. For example, let $C_1$ and $C_2$ denote two children of a variable $V$. If we hypothetically set $V$ to $v'$ for $C_1$ but left $V$ as its observed value for $C_2$, this corresponds to counterfactual $C_1(v')$ and factual $C_2$. By setting a variable to a reference value (e.g., [math]) for one edge but not others, computing counterfactuals effectively removes (or replaces) a single edge. In Fig 2c, we saw is stable and deletes both edges into , including the stable edge. However, if we compute the counterfactual , depicted in Fig 2d, then the level 3 distribution is stable and only deletes the unstable edge, retaining information along the path. More generally, level 3 distributions allow us to counterfactually replace mechanisms (and thus replace the influence along unstable edges) with new ones. We will exploit this fact in Section 4 when we investigate accuracy. The effects of the three operators produce the following result:
Corollary 2**.**
Distributions at increasing levels of the hierarchy of stability grant increased precision in disabling individual edges (and thus paths).
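To illustrate the level 3 operator described above, the following sketch computes an edge-specific counterfactual for one unit. It assumes a fully specified linear SCM with edges V → C1, V → C2, C1 → Y, and C2 → Y; the coefficients are hypothetical. V is set to a reference value only along V → C1, and C2 keeps its factual value. As a result, only the influence of that single edge on Y is replaced.

```python
# Hypothetical linear SCM: V -> C1, V -> C2, C1 -> Y, C2 -> Y.
COEF = {"c1_v": 2.0, "c2_v": -1.0, "y_c1": 1.0, "y_c2": 1.0}


def edge_specific_counterfactual(unit, v_ref):
    """Return (C1(v_ref), factual C2, Y under C1(v_ref)) for one observed unit."""
    v, c1, c2, y = unit["V"], unit["C1"], unit["C2"], unit["Y"]
    # Recover the exogenous noise terms of C1 and Y from the observed values.
    eps_c1 = c1 - COEF["c1_v"] * v
    eps_y = y - COEF["y_c1"] * c1 - COEF["y_c2"] * c2
    # Set V to v_ref only along the edge V -> C1; C2 keeps its factual value.
    c1_cf = COEF["c1_v"] * v_ref + eps_c1
    # Propagate the counterfactual C1 (and the factual C2) to Y.
    y_cf = COEF["y_c1"] * c1_cf + COEF["y_c2"] * c2 + eps_y
    return c1_cf, c2, y_cf


unit = {"V": 1.0, "C1": 2.5, "C2": -0.75, "Y": 2.0}
print(edge_specific_counterfactual(unit, v_ref=0.0))  # (0.5, -0.75, 0.0)
```

Setting `v_ref` to the unit's own factual value of V returns the factual values, which is a quick sanity check on the noise recovery.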
Key Result: Thus, the difference between operators associated with the different levels of stable distributions is the precision of their ability to disable edges into a variable. Level 1, conditioning, must remove large amounts of the graph to disable edges. Level 2, intervening, deletes all edges into a variable. Level 3, computing counterfactuals, can precisely disable a single edge into a variable. Since paths encode statistical influence, this also provides a natural definition for a maximally stable distribution as one which deletes the unstable edges, and only the unstable edges. Thus, given a stable distribution found by any method, we can compare it to the maximally stable distribution to see which, and how many, stable paths were removed.
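The difference in precision between the three operators can be illustrated by treating a graph as a set of edges. In this sketch the graph and its single unstable edge U → W are hypothetical. The level 1 operator is approximated crudely as excluding a variable entirely, which deletes every edge that touches it. This mirrors how the level 1 model in Section 5 deletes the excluded nodes, although in general conditioning acts on paths rather than deleting edges. Comparing each result with the maximally stable graph (only the unstable edge removed) counts the stable edges that are lost.

```python
# Hypothetical DAG with one unstable edge (U -> W); all other edges are stable.
EDGES = {("U", "W"), ("Y", "W"), ("W", "X"), ("Y", "X"), ("U", "X")}
UNSTABLE = {("U", "W")}


def drop_variable(edges, v):
    """Level 1 (crude stand-in): exclude v entirely, deleting every edge that touches it."""
    return {e for e in edges if v not in e}


def do(edges, v):
    """Level 2: the do operator deletes all edges into v."""
    return {e for e in edges if e[1] != v}


def cut_edge(edges, e):
    """Level 3: an edge-specific counterfactual disables only the edge e."""
    return edges - {e}


for name, result in [("level 1", drop_variable(EDGES, "W")),
                     ("level 2", do(EDGES, "W")),
                     ("level 3", cut_edge(EDGES, ("U", "W")))]:
    assert not (result & UNSTABLE)  # every operator removes the unstable edge
    lost = sorted((EDGES - UNSTABLE) - result)  # stable edges removed along the way
    print(name, "removed stable edges:", lost)
```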
Another important fact is that the hierarchy is nested. This means that a level 1 distribution can be expressed as a level 2 distribution (and a level 2 distribution can be expressed as a level 3 distribution):
Lemma 3** ([22], Corollary 1).**
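A stable level 1 distribution of the form $P(Y \mid \mathbf{Z})$ can be expressed as a stable level 2 distribution of the form $P(Y \mid do(\mathbf{W}), \mathbf{Z}')$ for $\mathbf{W} = \emptyset$, $\mathbf{Z}' = \mathbf{Z}$.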
Lemma 4**.**
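A stable level 2 distribution of the form $P(Y \mid do(\mathbf{W}), \mathbf{Z})$ can be expressed as a stable level 3 distribution of the form $P(Y(\mathbf{w}) \mid \mathbf{Z}(\mathbf{w}))$, where $\mathbf{V}(\mathbf{w})$ denotes the counterfactual value of $\mathbf{V}$ had $\mathbf{W}$ been set to $\mathbf{w}$.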
3.3.2 Consequences
There are a number of practical consequences of the hierarchy of shift-stable distributions. First, level 1 distributions can always be learned from the available data because conditional distributions are observational quantities. This means that we can simply fit a model of $P(Y \mid \mathbf{Z})$ to the training data. However, because the conditioning operator throws away large parts of the graph, including many stable paths, models of level 1 distributions will generally have higher error compared to models of level 2 and level 3 distributions. A tradeoff exists, though, since level 2 and level 3 distributions are not always identifiable: they cannot always be estimated as a function of the observational training data. Level 2 distributions model the effects of hypothetical interventions, and, just as in causal inference, unobserved confounding can lead to identifiability issues (for more detail on identification and level 2 stable distributions see [22]). In addition to identifiability challenges, level 3 counterfactual distributions require further assumptions about the functional form of the causal mechanisms in the SCM. Under a fully specified SCM (i.e., the functions defining mechanisms and their parameters are all known), counterfactual inference can be performed using the three-step abduction, action, prediction procedure described in [57, Chapter 7]. For example, the method of [12] assumes linear causal mechanisms to compute level 3 distributions. However, we often have limited information about functional forms and the distribution of the exogenous noise variables in an SCM. If we want to make counterfactual queries with fewer or no parametric assumptions, then identifiability becomes even more difficult: in general, not all counterfactual queries will be testable. That is, experimental data cannot be used to uniquely verify the result of a counterfactual query (whereas experimental data can verify the result of any interventional query).
For non-parametric SCMs, [69] provide an algorithm for determining if a counterfactual query is empirically testable. Thus, one must balance strong parametric assumptions about the form of causal mechanisms against the possibility of untestable counterfactuals.
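The three-step procedure mentioned above can be sketched for the special case of a fully specified additive-noise SCM in which every variable is observed for the unit of interest. This is the strongest form of the parametric assumptions just discussed, and the mechanisms below are hypothetical.

```python
from typing import Callable, Dict, List

Mechanism = Callable[[Dict[str, float]], float]


def counterfactual(order: List[str], mechs: Dict[str, Mechanism],
                   observed: Dict[str, float], intervention: Dict[str, float]) -> Dict[str, float]:
    """Abduction, action, prediction for one fully observed unit of an additive-noise SCM.

    Each variable v satisfies v = mechs[v](values) + U_v, where mechs[v] reads only the
    parents of v from `values`; `order` must be a topological order of the graph.
    """
    # 1. Abduction: recover every exogenous noise term from the observed unit.
    noise = {v: observed[v] - mechs[v](observed) for v in order}
    # 2. Action: intervened variables take constant values (their mechanisms are replaced).
    # 3. Prediction: recompute the other variables in order, reusing the recovered noise.
    values: Dict[str, float] = {}
    for v in order:
        values[v] = intervention[v] if v in intervention else mechs[v](values) + noise[v]
    return values


mechs = {
    "X": lambda d: 0.0,                          # exogenous: X = U_X
    "M": lambda d: 0.5 * d["X"],
    "Y": lambda d: 2.0 * d["M"] + 1.0 * d["X"],
}
unit = {"X": 1.0, "M": 1.0, "Y": 3.5}
print(counterfactual(["X", "M", "Y"], mechs, unit, {"X": 0.0}))
# {'X': 0.0, 'M': 0.5, 'Y': 1.5}
```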
The nested nature of the hierarchy means that it has consequences on the existence of stable distributions: If there is no stable level 3 distribution, then no stable level 1 or level 2 distributions exist. Considering the other direction, if we find that no stable level 1 distribution exists, there may still be a stable level 2 or 3 distribution. This is an important consideration as more methods for finding stable distributions are developed. For example, [22] developed a sound and complete algorithm for finding stable level 2 distributions in a graph. This means that the algorithm returns a distribution if and only if a stable level 2 distribution exists. An open problem is to develop a sound and complete algorithm for finding stable level 3 distributions. Such a result would be very powerful: If a complete algorithm failed to find a stable level 3 distribution, then that would mean no stable distributions (level 1, 2, or 3) exist.
We have shown that the hierarchy of stable distributions defines graphical operators which can be used to construct stable distributions by disabling edges in the underlying graph. Next, we show how the ability of counterfactual level 3 distributions to replace edges can be used to achieve minimax optimal performance under dataset shifts.
4 Worst-case Performance of Shift-Stable Distributions
We now compare stable distributions with respect to their minimax performance under dataset shifts. Specifically, we show that there is a hypothetical environment in which counterfactually training a model would yield minimax optimal performance across environments. We further show that this level 3 counterfactual distribution is not, in general, a level 2 interventional distribution. Counter to the increasing interest in invariant interventional solutions like Invariant Risk Minimization and its related follow-ups (e.g., [27, 28, 29]), these results motivate the development of counterfactual (as opposed to level 2) learning algorithms.
4.1 A Decision Theoretic View
We now present our result characterizing the stable distribution that achieves minimax optimal performance. First, recall that dataset shifts result in a set of hypothetical environments generated from the same graph $\mathcal{G}$ such that the mechanisms associated with unstable edges in $\mathbf{E}_u$ differ in each environment. For simplicity, we will assume that the mechanism of a single variable $W$ is subject to shifts while the mechanisms of all other variables remain stable across environments. Each distribution in the set $\Gamma$ of data distributions corresponding to the environments factorizes according to $\mathcal{G}$, but differs only in the term $P(W \mid \mathrm{pa}(W))$, which corresponds to the mechanism for generating $W$.
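Now consider the following game: Suppose the data modeler (DM) wishes to pick the distribution $P^* \in \Gamma$ such that the corresponding Bayes predictor of $Y$ given the features $\mathbf{X}$ (i.e., the true $P^*(Y \mid \mathbf{X})$) minimizes the worst-case expected loss (i.e., worst-case risk) across all distributions in the set $\Gamma$ of environment distributions. For a loss function $\ell$, this can be written as

$$P^* = \arg\min_{Q \in \Gamma} \max_{P \in \Gamma} \mathbb{E}_{P}\left[\ell\big(Y, Q(Y \mid \mathbf{X})\big)\right] \tag{1}$$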
Following a game theoretic result [70, Theorem 6.1], this game has a solution for bounded loss functions (e.g., the Brier score but not the log loss):
Theorem 5**.**
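Consider a classification problem and suppose $\ell$ is a bounded loss function. Then Equation 1 has a solution, and it is given by the maximum generalized entropy distribution $P^* = \arg\max_{P \in \Gamma} H_\ell(P)$, where $H_\ell(P) = \inf_{Q} \mathbb{E}_P[\ell(Y, Q(Y \mid \mathbf{X}))]$ is the generalized entropy of $P$ under $\ell$.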
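As a toy check of Theorem 5, consider a single covariate cell in which the environments differ only in P(Y = 1), and suppose that probability ranges over an interval. The interval is a hypothetical stand-in for the set of environments. For the Brier score the generalized entropy is H(p) = p(1 - p), so the theorem predicts that the minimax prediction equals the Bayes prediction for the entropy-maximizing p. The sketch compares a brute-force grid search for the minimax prediction with that prediction.

```python
def brier_risk(q, p):
    """Expected Brier loss of predicting P(Y=1) = q when the true P(Y=1) is p."""
    return p * (1 - q) ** 2 + (1 - p) * q ** 2


def minimax_prediction(p_low, p_high, steps=1000):
    """Grid-search the q minimizing the worst-case Brier risk over p in [p_low, p_high]."""
    envs = [p_low + (p_high - p_low) * j / 200 for j in range(201)]
    grid = [i / steps for i in range(steps + 1)]
    return min(grid, key=lambda q: max(brier_risk(q, p) for p in envs))


def max_entropy_p(p_low, p_high):
    """Maximize the Brier generalized entropy H(p) = p(1 - p) over [p_low, p_high]."""
    return min(max(0.5, p_low), p_high)


for interval in [(0.1, 0.3), (0.2, 0.6)]:
    print(interval, minimax_prediction(*interval), max_entropy_p(*interval))
```

When 0.5 lies inside the interval, the minimax prediction is 0.5 and its Brier risk is 0.25 for every p. That is, it has constant risk, which connects to the constant-risk observation in the experiments.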
Key Result: That this game has a solution means that $P^*$, the maximum generalized entropy distribution, is the “optimal training environment” such that counterfactually training a predictor in $P^*$ to learn the true $P^*(Y \mid \mathbf{X})$ would produce the minimax optimal predictor. Importantly, this optimal environment depends on the choice of loss function. There are two consequences of this result: First, $P^*(Y \mid \mathbf{X})$ is not, in general, a level 2 distribution (and thus level 2 distributions are not, in general, minimax optimal). Second, there is a level 3 distribution which corresponds to $P^*(Y \mid \mathbf{X})$ and thus is minimax optimal.
Proposition 6**.**
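The level 2 stable distribution $P(Y \mid do(W), \mathbf{Z})$, with $\mathbf{Z} = \mathbf{X} \setminus \{W\}$, equals $P_u(Y \mid W, \mathbf{Z})$, where $P_u$ is the member of $\Gamma$ such that $W$ has a uniform distribution, i.e., $P_u(W = w \mid \mathrm{pa}(W)) = 1/k$ for each of the $k$ values $w$ of $W$.
In the appendix we provide a counterexample in which $P_u(Y \mid W, \mathbf{Z}) \neq P^*(Y \mid \mathbf{X})$. This shows that the level 2 stable distribution is not minimax optimal.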
Proposition 7**.**
The level 3 distribution equals and is minimax optimal, where is the counterfactual generated under the mechanism associated with the environment .
Thus, given training data from , if we could counterfactually learn in the environment associated with then the resulting predictor would be minimax optimal. This means the stable level 3 distribution produces the best, worst-case performance across environments out of all distributions that could be used for prediction.
4.2 A Simple Learning Algorithm
We now consider a simple distributionally robust likelihood reweighting algorithm for learning the minimax optimal level 3 predictor. This approach can serve as a starting point for developing new stable learning algorithms which achieve minimax optimal performance under dataset shift. (Alternatively, one could try to directly compute the maximum generalized entropy distribution; see, e.g., [71] for a simple example.)
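For simplicity, suppose there are no unobserved confounders (i.e., $\mathcal{G}$ has no bidirected edges). We relax this condition in the Appendix. Then, learning $P^*(Y \mid \mathbf{X})$ in the environment $P^*$ using training data from the training distribution $P_{tr}$ can be done by reweighting the training data: for any predictor $f$,

$$\mathbb{E}_{P^*}\left[\ell(Y, f(\mathbf{X}))\right] = \mathbb{E}_{P_{tr}}\left[\frac{P^*(W \mid \mathrm{pa}(W))}{P_{tr}(W \mid \mathrm{pa}(W))}\,\ell(Y, f(\mathbf{X}))\right],$$

assuming full shared support (i.e., overlap between $P^*(W \mid \mathrm{pa}(W))$ and $P_{tr}(W \mid \mathrm{pa}(W))$ for all values of $\mathrm{pa}(W)$).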
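The reweighting identity can be checked exactly on a toy discrete problem. For illustration only, the sketch assumes that $\mathrm{pa}(W) = \{Y\}$, so the weights can be computed on training data, where labels are observed. It also assumes that the environments share $P(Y)$, and it uses made-up mechanisms and a fixed predictor. Exact expectations replace sample averages.

```python
from itertools import product

P_Y1 = 0.3                     # shared P(Y=1) in every environment
P_TRAIN_W1 = {0: 0.2, 1: 0.7}  # training mechanism P_tr(W=1 | Y=y)
P_STAR_W1 = {0: 0.5, 1: 0.4}   # target mechanism P*(W=1 | Y=y)
PREDICTOR = {0: 0.25, 1: 0.6}  # a fixed predictor f(w) = predicted P(Y=1 | W=w)


def bern(p1, value):
    return p1 if value == 1 else 1.0 - p1


def joint(mech, y, w):
    """Environments share P(y) and differ only in the mechanism term P(w | y)."""
    return bern(P_Y1, y) * bern(mech[y], w)


def loss(y, w):
    """Brier loss of the fixed predictor on the cell (y, w)."""
    return (y - PREDICTOR[w]) ** 2


cells = list(product((0, 1), repeat=2))
target_risk = sum(joint(P_STAR_W1, y, w) * loss(y, w) for y, w in cells)
reweighted_risk = sum(
    joint(P_TRAIN_W1, y, w) * bern(P_STAR_W1[y], w) / bern(P_TRAIN_W1[y], w) * loss(y, w)
    for y, w in cells)
print(target_risk, reweighted_risk)  # equal up to floating-point rounding
```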
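Because the minimax optimal training environment is unknown, we now seek to train the minimax optimal predictor by parameterizing environments and iteratively finding the worst-case environment. Let $\phi$ be parameters such that $P_\phi \in \Gamma$, and let $w_\phi = P_\phi(W \mid \mathrm{pa}(W)) / P_{tr}(W \mid \mathrm{pa}(W))$ be a reweighting function parameterized by $\phi$, where $P_{tr}$ is the training distribution. Note that different values of $\phi$ correspond to different hypothetical training environments $P_\phi$. The learning problem becomes

$$\min_{\theta} \max_{\phi} \; \mathbb{E}_{P_{tr}}\left[w_\phi \, \ell(Y, f_\theta(\mathbf{X}))\right]$$

with model parameters $\theta$. This objective resembles those of distributionally robust methods (e.g., [43]) without restrictions on the density ratio or the divergence between $P_\phi$ and $P_{tr}$.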
While many possibilities exist, perhaps the simplest version of $w_\phi$ is to explicitly learn a parametric density model (e.g., logistic regression for discrete $W$) for $P_{tr}(W \mid \mathrm{pa}(W))$ and use the same density model class to model $P_\phi(W \mid \mathrm{pa}(W))$. Algorithm 1 describes a gradient descent ascent (GDA) learning procedure for this case, which alternates between finding environmental parameters $\phi$ that maximize the risk of the previous prediction model (with parameters $\theta$), and finding model parameters $\theta$ that minimize risk under the previously found worst-case environment (with parameters $\phi$). It is important to note that this general minimax learning problem is often very challenging, with complicated convergence and equilibrium dynamics (see, e.g., [72, 73, 74]). Thus, Algorithm 1 only serves as a starting point for designing counterfactual level 3 learning algorithms.
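Algorithm 1 is not shown in this section, so the following is only a minimal sketch in its spirit, built on several assumptions. All variables are binary. The training mechanism $P_{tr}(W \mid Y, D)$ is treated as known rather than fitted. Both the predictor and the environment family $P_\phi$ are logistic models, gradients come from finite differences, each iteration takes one alternating step per player, and a box constraint on $\phi$ stands in for $\Gamma$. The variable names D and W and all constants are hypothetical. The paper's models were implemented in JAX; this sketch uses only the Python standard library. As the text cautions, GDA need not converge, so the output is illustrative only.

```python
import math
import random

TRAIN_PHI = (-1.0, 2.0, 0.5)  # assumed known training mechanism: logit P_tr(W=1 | Y, D)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def make_data(n=300, seed=0):
    """Hypothetical training sample of (d, w, y): D -> Y and (Y, D) -> W."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        d = int(rng.random() < 0.5)
        y = int(rng.random() < 0.2 + 0.4 * d)
        w = int(rng.random() < sigmoid(TRAIN_PHI[0] + TRAIN_PHI[1] * y + TRAIN_PHI[2] * d))
        data.append((d, w, y))
    return data


def p_w(phi, w, y, d):
    """Logistic environment mechanism P_phi(W = w | Y = y, D = d)."""
    p1 = sigmoid(phi[0] + phi[1] * y + phi[2] * d)
    return p1 if w else 1.0 - p1


def risk(theta, phi, data):
    """Reweighted Brier risk of the logistic predictor f_theta(D, W) in environment P_phi."""
    total = 0.0
    for d, w, y in data:
        weight = p_w(phi, w, y, d) / p_w(TRAIN_PHI, w, y, d)
        pred = sigmoid(theta[0] + theta[1] * d + theta[2] * w)
        total += weight * (y - pred) ** 2
    return total / len(data)


def num_grad(f, x, h=1e-5):
    """Central finite-difference gradient of f at x (keeps the sketch dependency-free)."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def gda(data, steps=100, lr_theta=1.0, lr_phi=1.0, bound=3.0):
    """Alternate an ascent step on phi (toward a worse environment) with a descent step on theta."""
    theta, phi = [0.0, 0.0, 0.0], list(TRAIN_PHI)
    for _ in range(steps):
        g_phi = num_grad(lambda p: risk(theta, p, data), phi)
        # Box constraint on phi: an assumed stand-in for restricting the set of environments.
        phi = [max(-bound, min(bound, p + lr_phi * g)) for p, g in zip(phi, g_phi)]
        g_theta = num_grad(lambda t: risk(t, phi, data), theta)
        theta = [t - lr_theta * g for t, g in zip(theta, g_theta)]
    return theta, phi, risk(theta, phi, data)


theta, phi, worst_risk = gda(make_data())
print("theta:", [round(t, 2) for t in theta], "phi:", [round(p, 2) for p in phi],
      "risk:", round(worst_risk, 3))
```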
5 Experiments
We turn to semisynthetic experiments on a real medical prediction task to demonstrate practical performance implications of the hierarchy. To carefully study model behavior under dataset shifts, we posit a graph of the DGP for this dataset and reweight the data to simulate a large number of dataset shifts. We investigate how the performance of models of stable distributions at different levels of the hierarchy changes as test environments differ from the training environment. Our results show that though level 3 models can produce the best worst-case performance (i.e., minimax optimal), level 2 models may perform better on average. This further highlights that model developers need to carefully choose how they achieve stability.
5.1 Motivation and Data
One prominent application of machine learning is patient risk stratification in healthcare. It has been widely noted that developing reliable clinical decision support models is difficult due to changes in clinical practice patterns [7]. The resulting behavior-related associations are often brittle (policies change over time and differ across hospitals) and can cause models to make dangerous predictions if left unaccounted for [10]. We investigate practical implications of the hierarchy on this important risk prediction challenge.
Below we describe our setup which loosely follows the setup of [75] for predicting patient risk of sepsis, a life-threatening response to infection. We use electronic health record data collected over four years at our institution’s hospital. The dataset consists of 278,947 patient encounters that began in the emergency department. The prevalence of sepsis (S) is . Three categories of variables were extracted: vital signs (V) (heart rate, respiratory rate, and temperature), lab tests (L) (lactate), and demographics (D) (age and gender). For encounters that resulted in sepsis, physiologic data available prior to sepsis onset time was used. For non-sepsis encounters all data available until the time the patient was discharged from the hospital was used. Min, max, and median features were derived for each time-series variable. Unlike vitals, lab measurements are not always ordered (O), so a binary missingness indicator was given. The graph of the DGP is shown in Fig 3.
5.1.1 Shifts in Lab Test Ordering Patterns
Different lab test ordering policies correspond to shifts in the conditional $P(O \mid \mathrm{pa}(O))$. As a result, missingness patterns vary across datasets derived from different hospitals, because the lab test rate can vary from one institution to another [76]. To compare across datasets corresponding to differing lab testing patterns, we simulated one hundred datasets as follows: For a given test split, we fit a (logistic regression) model of the ordering policy (i.e., the $P(O \mid \mathrm{pa}(O))$ model). Then, for a new ordering policy $P'(O \mid \mathrm{pa}(O))$, we reweight the test samples by $P'(O \mid \mathrm{pa}(O)) / P(O \mid \mathrm{pa}(O))$ to mimic data from a new hospital which differs only in the ordering policy. Reweighting the examples changes the overall distribution of the reweighted test set relative to the original test set without perturbing the feature values of individual examples. Thus, the reweighted datasets consist entirely of examples that were observed in the original dataset.
To simulate an edge shift, we created new ordering policies by perturbing the coefficient of sepsis in the $P(O \mid \mathrm{pa}(O))$ model. This corresponds to changing the log odds ratio for sepsis of a patient receiving a lab test. A log odds ratio $< 0$ means that lab test orders are more likely for non-sepsis patients than for sepsis patients, while a log odds ratio $> 0$ means that lab test orders are more likely for sepsis patients than for non-sepsis patients. To simulate a mechanism shift, we perturbed all coefficients in the $P(O \mid \mathrm{pa}(O))$ model.
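The shift simulation above can be sketched as follows. The coefficient layout (intercept first, then the sepsis indicator, then other features), the numbers, and the Gaussian perturbation scale used for mechanism shifts are assumptions for illustration. Each weight is the ratio of the new ordering policy to the fitted one, evaluated at the test example's observed order indicator.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def order_prob(coefs, x):
    """P(O = 1 | x) under a logistic policy: coefs[0] is the intercept, coefs[1:] pair with x."""
    return sigmoid(coefs[0] + sum(c * v for c, v in zip(coefs[1:], x)))


def policy_weight(old, new, x, o):
    """Importance weight P_new(O = o | x) / P_old(O = o | x) for one test example."""
    p_old, p_new = order_prob(old, x), order_prob(new, x)
    return p_new / p_old if o else (1.0 - p_new) / (1.0 - p_old)


def edge_shift(coefs, log_odds_ratio):
    """Edge shift: replace only the sepsis coefficient (assumed to be coefs[1])."""
    return [coefs[0], log_odds_ratio, *coefs[2:]]


def mechanism_shift(coefs, scale, rng):
    """Mechanism shift: perturb every coefficient with Gaussian noise of an assumed scale."""
    return [c + rng.gauss(0.0, scale) for c in coefs]


old = [-1.0, 0.8, 0.3]       # hypothetical fitted policy: intercept, sepsis, one other feature
new = edge_shift(old, -2.0)  # lab orders become less likely for sepsis patients
x = (1, 0.5)                 # a sepsis patient whose other feature equals 0.5
# For a fixed x, the weights average to 1 under the old policy: reweighting only moves mass.
p_old = order_prob(old, x)
print(p_old * policy_weight(old, new, x, 1) + (1 - p_old) * policy_weight(old, new, x, 0))
print(mechanism_shift(old, 0.5, random.Random(0)))
```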
5.2 Experimental Setup
Train/test splits were generated via 5-fold cross-validation. Full experimental details are in the Appendix. Models were fit using the Brier score (which for binary classification is $(y - p)^2$, where $y \in \{0, 1\}$ is the true label and $p$ is the predicted probability of class 1) as the loss, since it is a bounded loss function (required by Theorem 5).
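A short sketch of why the Brier score, and not the log loss, satisfies the boundedness condition. For a binary label the Brier score always lies in [0, 1], while the log loss grows without bound as the predicted probability of the true class approaches zero.

```python
import math


def brier(y, p):
    """Brier score for a binary label y in {0, 1} and predicted P(Y=1) = p; always in [0, 1]."""
    return (y - p) ** 2


def log_loss(y, p):
    """Log loss; unbounded as the predicted probability of the true class goes to 0."""
    return -math.log(p if y == 1 else 1 - p)


for p in (0.5, 0.1, 1e-6):
    print(f"p={p:g}: brier={brier(1, p):.4f}, log_loss={log_loss(1, p):.2f}")
```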
5.2.1 Models
We consider four models: stable models for each level of the hierarchy and an unstable baseline that does not adjust for shifts. In fitting models, any model structure (e.g., random forests, neural networks, etc.) can be used to fit the marginal/conditional distributions. The choice of model does not impact the study conclusions drawn here. For simplicity, we used logistic regression for all models. The level 1 model excludes the lab-derived and lab order features. With respect to the graph in Fig 3, this effectively deletes the $L$ and $O$ nodes (and all edges into these nodes). The level 2 model is of a conditional interventional distribution in which $O$ is intervened on, and is an implementation of the “graph surgery estimator” [22]. In Fig 3, the $do$ operator deletes both edges into the $O$ node. Finally, the level 3 model was trained using Algorithm 1, with a logistic regression counterfactual reweighting model. In Fig 3, this deletes and then replaces the mechanism for $O$ with a new ordering policy mechanism. All models were implemented in JAX [77]. Full details of how the level 2 and 3 models were fit are in the Appendix.
5.3 Results
When test environments differ from the training environment, stable models have more robust performance than unconstrained, unstable models. An unconstrained model uses all dependencies present in the training data; in other words, the model captures correlations due to all paths in the underlying graph. As we impose invariance constraints (by disabling edges), stable models show performance improvements over the unstable model as the test distribution deviates further from the training distribution. We see this, for example, in Fig 4a when, due to the edge shift, the correlation flips from being negative to positive: the level 1, 2, and 3 models outperform the unstable model for log odds ratios .
As desired, the level 3 model achieves the best worst-case performance amongst the four models, indicating that training using Algorithm 1 was successful. Further, the performance of the level 3 model is nearly constant across the shifts. This is encouraging evidence, because constant risk is a sufficient condition for a Bayes estimator to be minimax optimal [78]. The results are largely consistent in Fig 5, in which we consider a mechanism shift in the lab test ordering policy. We see that irrespective of the KL divergence between the training and shifted distributions, the level 3 model still has almost constant performance.
Finally, Theorem 5 tells us that the optimal training distribution depends on the choice of loss function. Thus, we do not expect a minimax optimal predictor under one loss to be optimal when measured under a different loss. Indeed, in Fig 4b, when the four models are evaluated with respect to the log loss, the level 3 model is no longer minimax optimal. In fact, its performance is strictly worse than that of the level 2 model. Even when evaluated using the Brier score (Fig 4a), the worst-case performance of the level 3 model is only slightly better than the worst-case performance of the level 2 model. Further, the level 2 model sees performance improvements when the log odds increase that the level 3 model does not (loss drops noticeably for x-axis values ). Thus, on average, the level 2 model might be preferable on this data, and a conservative objective like worst-case performance may not be desirable. This illustrates a classic problem in statistical decision theory: while minimax objectives can be too conservative, it may be difficult to characterize the “average” environment or to specify a reasonable prior over environments.
6 Limitations
The primary limitation of the framework presented in this paper is its reliance on a known causal graph of the data generating process. Correct specification of the graph is important because the addition of an edge, or the change of orientation of an edge, can change the stability of a distribution. Adding an edge can open new active unstable paths, while a change in orientation of an edge can cause an inactive path to become active (e.g., when conditioning on $B$ in a chain $A \to B \to C$ we have $A \perp C \mid B$, whereas when conditioning on $B$ in a collider $A \to B \leftarrow C$ we have $A \not\perp C \mid B$). As the number of variables increases, it becomes difficult to manually specify an entire causal graph with confidence. In this section, we discuss options for addressing the limitation of misspecification of (or inability to specify) the graph.
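The chain/collider contrast can be verified numerically. The sketch uses made-up binary mechanisms and checks whether P(C = 1 | A = a, B = 1) depends on a. It does not in the chain. It does in the collider, because conditioning on the common effect creates an "explaining away" dependence.

```python
def bern(p1, v):
    return p1 if v else 1 - p1


def chain(a, b, c):
    """Joint P(a, b, c) for A -> B -> C with hypothetical binary mechanisms."""
    return bern(0.5, a) * bern(0.9 if a else 0.2, b) * bern(0.7 if b else 0.1, c)


def collider(a, b, c):
    """Joint P(a, b, c) for A -> B <- C: B is likely 1 only if A or C is 1 (hypothetical)."""
    return bern(0.5, a) * bern(0.5, c) * bern(0.95 if (a or c) else 0.05, b)


def p_c1_given(joint, a, b):
    """P(C=1 | A=a, B=b) computed from a joint probability function."""
    num = joint(a, b, 1)
    return num / (num + joint(a, b, 0))


for name, joint in [("chain", chain), ("collider", collider)]:
    # A is independent of C given B=1 iff P(C=1 | A=a, B=1) does not depend on a.
    print(name, [round(p_c1_given(joint, a, 1), 3) for a in (0, 1)])
# chain [0.7, 0.7]
# collider [0.95, 0.5]
```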
When domain knowledge is insufficient to specify a causal graph, one can try to learn the structure of the graph from data, a problem known as causal discovery or structure learning (see [79] for an overview). Constraint-based structure learning algorithms work by using (conditional) independence tests to determine edge adjacencies. Thus, by testing compatibility with the data, it is possible to learn the structure of the graph up to an equivalence class: a set of fully specified graphs which imply the same independences. Notably, some constraint-based structure learning algorithms tolerate and account for the possibility of unknown confounding variables (see [80] for a recent review of structure learning algorithms).
Using structure learning, it is possible to learn the range of causal structures which are compatible with the data. Given this range of causal structures, there are two main approaches one could use to find stable distributions. One approach is to enumerate each member of the equivalence class, find and fit models of stable distributions in the fully specified member, and compare across the members of the equivalence class. This approach is akin to sensitivity analysis approaches for finding the range of causal effect estimates in the equivalence class (see, e.g., [81, 82]). The challenge with this approach is that it does not produce a single model that is guaranteed to be stable, but rather a range of candidate “possibly stable” models. One would require data from a new environment to test the stability of the candidate models.
A second approach is to find a distribution which is stable in every member of the equivalence class. Such a distribution is guaranteed to be stable, regardless of which member of the equivalence class represents the “true” data generating process. While this could be done through enumeration of each member of the equivalence class (as in the previously outlined sensitivity analysis approach), recent approaches allow us to find stable distributions in graphical representations of the equivalence class [56, 23]. The output of many constraint-based structure learning algorithms is a partial graph in which edges may be partially directed (i.e., edge endpoints may be an arrowhead, an arrow tail, or a circle ($\circ$) representing that either is possible). One can then consider extensions of graphical operators from the hierarchy to partial graphs. As one example, [23] propose a method for finding stable level 1 and level 2 distributions in partial graphs. More generally, a promising direction for future work is to extend results from the proposed framework to partial graphs. Because partial graphs can be learned from data, this would relax the requirement of a fully specified graph as the starting point for this graphical framework.
7 Contrast with Invariant Risk Minimization
The discussion in this paper has focused on a graphical perspective—explicitly starting with knowledge of the data generating process and using this to determine when and how stability to shifts is achievable. An alternative emerging paradigm in machine learning has focused on invariant risk minimization (IRM) [27, 29, 28]. IRM is applicable when multiple datasets from different environments are available, and the goal is to learn a representation that produces an optimal predictor which is invariant across these environments. In this section, we discuss an important limitation of the invariant risk minimization paradigm which highlights a key advantage of graphical approaches. We also discuss how graphical analyses can guide future work to address this.
A critical question that determines the usefulness of an invariant predictor is: to what set of shifts is the predictor stable? The answer to this question defines the set of new environments to which an invariant predictor can be safely applied. In the graphical approach, the answer is transparent by design. Shifts are defined as (arbitrary) changes to particular causal mechanisms in the graph, so an invariant predictor is exactly one which is stable to the specified shifts in mechanisms. Further, the graph allows model developers to choose the set of shifts to which a predictor should be stable and provide guarantees about shifts that are protected against.
In contrast, IRM methods currently struggle to answer this critical question. First, existing IRM methods do not identify the differences that exist across the observed environments. Thus, they are unable to provide guarantees about the nature of the shifts in environment (i.e., the causal mechanisms) against which they protect. This also means it is difficult to state the set of new environments to which the invariant predictor can be safely applied. Further, because IRM automatically determines invariance from datasets, there is no opportunity for developers to specify particular invariances that they want to hold.
Outside of invariant risk minimization, there are opportunities to leverage ideas from other works on invariant learning and from the proposed graphical framework to improve IRM-type methods. For example, [83] shows a relationship between invariant predictors and calibration across environments. This suggests a possible approach for probing an invariant predictor for stability to particular mechanism shifts. First, using structure learning [79], it is possible to detect particular mechanism shifts that occur across environments [84, 56, 23]. Then, once mechanisms of interest have been identified, one can test for stability to those mechanism shifts by examining how the calibration of the predictor changes as evaluation data is reweighted according to the distribution associated with the mechanism shift. This would provide a post-hoc way to verify the integrity of a trained invariant predictor.
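The calibration probe described above can be sketched as an importance-weighted expected calibration error. This is a minimal illustration, assuming a binary target, predicted probabilities from an already-trained predictor, and nonnegative weights that encode a hypothesized mechanism shift; the function name `weighted_calibration_error` and the equal-width binning are our own choices, not taken from [83] or the paper.

```python
import numpy as np


def weighted_calibration_error(probs, labels, weights, n_bins=10):
    """Expected calibration error in which each sample carries an importance weight.

    probs: predicted P(Y = 1 | x); labels: observed 0/1 outcomes;
    weights: importance weights describing the shifted evaluation distribution.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    weights = np.asarray(weights, dtype=float)
    # Assign each prediction to an equal-width probability bin (p = 1 goes in the last bin).
    bins = np.minimum((probs * n_bins).astype(int), n_bins - 1)
    total = weights.sum()
    ece = 0.0
    for b in range(n_bins):
        mask = bins == b
        w = weights[mask]
        if w.sum() == 0:
            continue  # empty bin, or a bin with no weight under the shift
        # Weighted mean confidence vs. weighted observed frequency in this bin.
        confidence = np.average(probs[mask], weights=w)
        frequency = np.average(labels[mask], weights=w)
        ece += (w.sum() / total) * abs(confidence - frequency)
    return ece
```

In use, one would compare the error under uniform weights with the error under weights for a hypothesized shift; a large increase would suggest the predictor is not stable to that mechanism shift.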
As another example, [24] show that counterfactual invariances leave observable distributional signatures that can be used to design regularizers to enforce the given invariance. This motivates the combination of IRM-type objectives with regularizers which explicitly capture desired invariances at different levels of the hierarchy of shift-stability. This would allow developers to specify particular invariances they want to guarantee while also automatically learning other invariances from the data. In the context of image classification, [44] show how multiple views of an image and data augmentation can be used to learn models which are invariant to shifts in known and unknown style features. This provides ideas for learning specified invariances in settings with unstructured data (e.g., images and text).
8 Conclusion
The use of machine learning in production represents a shift from applying models to static datasets to applying them in the real world. As a result, aspects of the underlying DGP are almost certain to change. Many methods have been developed to find distributions that are stable to dataset shift, but as a field we have lacked a common underlying theory to characterize and relate different stable distributions. To address this, we developed a common framework for expressing the different types of shifts as unstable edges in a graphical representation of the DGP. We further showed that stable distributions belong to a causal hierarchy in which stable distributions at different levels have distinct operators that can remove unstable edges in the graph. This provides a new, but natural, way to characterize and construct stable models by only removing unstable edges. This also motivates a new paradigm for future work developing methods that can modify individual edges. We also showed that popular invariant solutions (level 2; invariant under intervention) do not, in general, achieve minimax optimal performance across environments. Our experiments showed that there is a tradeoff between worst-case and average performance. Thus, model developers need to carefully determine when and how they achieve invariance.
Appendix A Medical Risk Prediction Experiment
A.1 Data
Our experimental setup follows that of [85]. The dataset contains electronic health record data collected over four years at our institution’s hospital and consists of 278,947 emergency department patient encounters. The prevalence of the target disease, sepsis (S), is . Features pertaining to vital signs (V) (heart rate, respiratory rate, temperature), lab tests (L) (lactate), and demographics (D) (age, gender) were extracted. For encounters that resulted in sepsis (i.e., positive encounters), physiologic data available up until sepsis onset time was used. For non-sepsis encounters, all data available until discharge was used. For each of the time-series physiologic features (V and L), min, max, and median summary features were derived. Unlike vitals, lab measurements are not always ordered (O) and are therefore subject to missingness (e.g., lactate may be missing). To model lab missingness, missingness indicators (O) for the lab features were added, and lab value-missingness interaction terms were used in place of lab value features.
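A sketch of this feature construction for a single encounter is shown below. It assumes each time series is stored as an array of measurements (empty when the lab was never ordered) and encodes the value-missingness interaction as each summary value multiplied by an "observed" indicator, with unobserved summaries set to zero; the function names and this particular encoding are our assumptions, and [85] may define the terms differently.

```python
import numpy as np


def summarize_series(values):
    """Return (min, max, median) of a time series, or NaNs if it is empty."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return np.array([np.nan, np.nan, np.nan])
    return np.array([values.min(), values.max(), np.median(values)])


def lab_features(lab_values):
    """Missingness indicator plus value-missingness interaction terms for one lab.

    Returns [missing, min * observed, max * observed, median * observed], where
    observed = 1 - missing. Unobserved summaries are zeroed out, so the raw
    lab value never enters the model on its own.
    """
    summary = summarize_series(lab_values)
    missing = float(np.isnan(summary).all())
    interactions = np.nan_to_num(summary, nan=0.0) * (1.0 - missing)
    return np.concatenate([[missing], interactions])
```

For example, `lab_features([])` gives `[1, 0, 0, 0]` for an encounter where lactate was never ordered, while vitals (always recorded) would use `summarize_series` directly.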
A.2 Experimental Details
Logistic regression models were fit using a custom JAX [77] implementation. regularization with regularization coefficient was used (the hyperparameter was chosen via grid search using the performance of the unstable model on a held-out 10% of the initial dataset). The same hyperparameters were used to train the level 1-3 models. For the predictive models, a b-spline basis feature expansion was used for continuous features (lab values and vital signs). Following the standards in [85] for accounting for missingness, the missingness feature and the missingness-lab value interaction features were added.
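A b-spline feature expansion can be illustrated with the Cox-de Boor recursion. This is a minimal sketch, assuming a clamped basis with uniformly spaced interior knots over the observed range; the knot placement and basis size used in the experiments are not stated here, so `n_interior` and `degree` are hypothetical parameters.

```python
import numpy as np


def bspline_basis(x, n_interior=4, degree=3):
    """Evaluate a clamped B-spline basis at points x via the Cox-de Boor recursion.

    Returns an array of shape (len(x), n_interior + degree + 1).
    """
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    interior = np.linspace(lo, hi, n_interior + 2)[1:-1]
    # Clamped knot vector: each boundary knot is repeated degree + 1 times.
    knots = np.concatenate([[lo] * (degree + 1), interior, [hi] * (degree + 1)])
    # Degree-0 basis: indicator of the half-open knot span [t_i, t_{i+1}).
    B = np.zeros((x.size, len(knots) - 1))
    for i in range(len(knots) - 1):
        B[:, i] = (knots[i] <= x) & (x < knots[i + 1])
    # Assign the right boundary point to the last non-empty span.
    last = np.max(np.nonzero(knots[:-1] < knots[1:])[0])
    B[x == hi, last] = 1.0
    # Raise the degree one step at a time, treating 0/0 terms as 0.
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, len(knots) - d - 1))
        for i in range(len(knots) - d - 1):
            left_den = knots[i + d] - knots[i]
            right_den = knots[i + d + 1] - knots[i + 1]
            if left_den > 0:
                B_next[:, i] += (x - knots[i]) / left_den * B[:, i]
            if right_den > 0:
                B_next[:, i] += (knots[i + d + 1] - x) / right_den * B[:, i + 1]
        B = B_next
    return B
```

Each row sums to one (partition of unity), so the basis columns can replace the raw continuous feature in a logistic regression without losing the intercept's interpretation.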
The specific shift in lab test ordering patterns considered was a shift in lactate ordering, as these patterns have seen great variation across hospitals and are known to be associated with sepsis [76]. Lactate missingness has a correlation of -0.36 with sepsis in this dataset (i.e., the presence of the measurement is predictive of the target variable).
Thus, to simulate the edge shift, in each test fold, we first fit a logistic regression model (no b-spline basis expansion, with default scikit-learn hyperparameters) to the test fold’s lactate missingness () given . That is, a logistic regression model of . Then, to simulate the edge shifted lactate ordering policies, we replaced the coefficient for sepsis in the logistic regression model with 100 values on a grid from -6 to 8. The resulting logistic regression model is of the hypothetical shifted hospital’s ordering policy . Evaluating the loss under each shift was then done by using sample weights computed as for each test sample using the two models.
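The reweighting used to simulate the edge shift can be sketched as follows. We assume the fitted missingness model is a logistic regression with an intercept, a sepsis coefficient, and coefficients on the remaining covariates (a hypothetical design matrix `X_other`); the shifted model swaps in a new sepsis coefficient, and each test sample's weight is the ratio of its shifted to its original probability of the observed missingness indicator. All names here are illustrative, and the self-normalized weighted loss is one reasonable choice rather than necessarily the one used in the experiments.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def edge_shift_weights(intercept, coef_sepsis, coef_other, sepsis, X_other,
                       missing, new_coef_sepsis):
    """Importance weights P_shifted(O | S, X) / P(O | S, X) for each test sample.

    Only the sepsis coefficient changes, which perturbs the strength of the
    S -> O edge while leaving the rest of the ordering mechanism intact.
    """
    base_logit = intercept + X_other @ coef_other
    p_orig = sigmoid(base_logit + coef_sepsis * sepsis)
    p_new = sigmoid(base_logit + new_coef_sepsis * sepsis)
    # Probability of the observed value of the missingness indicator.
    lik_orig = np.where(missing == 1, p_orig, 1.0 - p_orig)
    lik_new = np.where(missing == 1, p_new, 1.0 - p_new)
    return lik_new / lik_orig


def weighted_log_loss(probs, labels, weights, eps=1e-12):
    """Self-normalized importance-weighted log loss of a predictor's probabilities."""
    probs = np.clip(probs, eps, 1.0 - eps)
    losses = -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
    return np.sum(weights * losses) / np.sum(weights)
```

A sweep over the grid would then loop over `np.linspace(-6, 8, 100)`, computing `edge_shift_weights(...)` for each candidate coefficient and evaluating `weighted_log_loss` of each predictor under those weights.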
The mechanism shift was simulated in a similar manner to the edge shift. However, instead of only perturbing the coefficient of sepsis in the model, all coefficients and the intercept were perturbed. Specifically, for a single test fold, 1000 new coefficients were sampled as follows: Let denote the weight in the model. The new coefficients/intercepts were drawn from . Because all weights of the logistic regression model changed, we plotted the shifts according to the estimated (using the test set) KL-divergence between the and logistic regression models: .
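Because every coefficient changes under the mechanism shift, shifts are indexed by a KL-divergence rather than by a single coefficient. The sketch below assumes Gaussian perturbations of scale `sigma` (a hypothetical parameter; the sampling distribution used in the experiments is not reproduced here) and estimates the KL-divergence between the original and perturbed missingness models as the test-set average of the per-sample Bernoulli KL-divergence. Both the perturbation distribution and the direction of the KL-divergence are our assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def perturb_weights(weights, sigma, rng):
    """Draw new logistic-regression weights (intercept first) around the fitted ones."""
    return weights + sigma * rng.standard_normal(weights.shape)


def estimated_kl(weights_p, weights_q, X, eps=1e-12):
    """Test-set average of KL(Bernoulli(p(x)) || Bernoulli(q(x))).

    X must include a leading column of ones so that weights[0] acts as the intercept.
    """
    p = np.clip(sigmoid(X @ weights_p), eps, 1 - eps)
    q = np.clip(sigmoid(X @ weights_q), eps, 1 - eps)
    kl = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))
    return kl.mean()
```

With `rng = np.random.default_rng(0)`, one would draw 1000 perturbed weight vectors via `perturb_weights(w_hat, sigma, rng)`, compute importance weights for each as in the edge-shift case, and plot loss against `estimated_kl(w_hat, w_new, X_test)`.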
The level 1 model was fit using a reduced feature set that excluded the lactate features (min, max, median) and lactate missingness indicator. The level 2 model , an instance of the “graph surgery estimator” [22], was fit by inverse probability weighting (IPW). The term corresponding to in the factorization of the DAG in Fig 4a is , so we fit a logistic regression model of this distribution using the training data. Then, the main logistic regression prediction model with the full feature set was trained using sample weights . The resulting model was the level 2 model. The level 3 model is similar to the level 2 model, but instead corresponds to a counterfactual ordering policy . The level 3 logistic regression model was trained using the Gradient Ascent Descent procedure in Algorithm 1. As noted in the main paper, this procedure has complicated dynamics and we found it was quite sensitive to the choice of step size (or learning rate ). Through grid search using performance on a hold-out 10% of the initial dataset the value was selected. The model parameters were initialized using the learned level 2 model parameters, and the reweighting parameters were initialized via random draws from . The resulting model was the level 3 model.
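The level 2 fitting procedure (inverse probability weighting followed by weighted logistic regression) can be sketched in NumPy. We assume each training sample's weight is the inverse of its fitted probability of the observed ordering indicator under the ordering-policy model; this is a common IPW choice, but the exact weight expression and the clipping threshold `min_prob` are assumptions here. Plain gradient descent with an L2 penalty stands in for the custom JAX implementation, and `lr`, `l2`, and `n_steps` are hypothetical defaults.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def ipw_weights(p_order, ordered, min_prob=1e-3):
    """Inverse probability weights 1 / P(O = o_i | parents) for each sample.

    p_order: fitted P(O = 1 | parents) from the ordering-policy model.
    ordered: observed 0/1 ordering indicator. Probabilities are clipped from
    below to limit the variance introduced by extreme weights.
    """
    p_obs = np.where(ordered == 1, p_order, 1.0 - p_order)
    return 1.0 / np.clip(p_obs, min_prob, 1.0)


def fit_weighted_logreg(X, y, sample_weight, l2=1e-3, lr=0.1, n_steps=2000):
    """Weighted, L2-penalized logistic regression fit by gradient descent.

    X should include a column of ones for the intercept (which, for
    simplicity, is also penalized here).
    """
    w = np.zeros(X.shape[1])
    sw = sample_weight / sample_weight.sum()  # normalize so lr does not depend on weight scale
    for _ in range(n_steps):
        p = sigmoid(X @ w)
        grad = X.T @ (sw * (p - y)) + l2 * w
        w -= lr * grad
    return w
```

Training the level 2 model would then amount to `fit_weighted_logreg(X_train, y_train, ipw_weights(p_order_train, o_train))` on the full feature set.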
Appendix B Proofs
See 1
Proof.
We first recall that all of the shifts considered in Section 3.2 are types of arbitrary shifts in mechanism: mean-shifted mechanisms are a special parametric case, and edge-strength shifts correspond to a constrained class of mechanism shifts in which the natural direct effect associated with the mechanism has changed. Thus, if a distribution is stable to arbitrary shifts in mechanisms, then it will also be stable to mean shifts and edge shifts. For this reason, we prove that a distribution is stable by leveraging previous graphical results on stability under shifts in mechanisms (stability to the specific cases then follows).
To do so, we will leverage results from transportability, which uses a graphical representation called selection diagrams (see [13, 22] for details). A selection diagram is a graph augmented with selection variables $\mathbf{S}$ (each of which has at most one child) that point to variables whose mechanisms may vary across environments. Prior results have shown that a distribution $P(Y \mid \mathbf{Z})$ is stable if $Y \perp\!\!\!\perp \mathbf{S} \mid \mathbf{Z}$ in the selection diagram (see [13, Theorem 2] and [22, Definition 3]). Thus, to prove the theorem, we will first translate our unstable edge representation of the graph to a selection diagram. Then, we will show that if $Y$ is not d-separated from the selection variables given $\mathbf{Z}$, then there is an active unstable path to $Y$.
We first translate our unstable edge representation of the graph to a selection diagram. Let $\mathbf{V}_u$ denote the set of variables that unstable edges point into. Now for each $V \in \mathbf{V}_u$, add a unique selection variable $S_V$ with the edge $S_V \to V$. This indicates that the mechanism that generates $V$ is unstable. We now consider the cases in which there could be an active path from a selection variable to $Y$ given $\mathbf{Z}$ (which would make the distribution unstable), and show that each corresponds to an active path that contains an unstable edge.
There are two possible ways there can be an active path from a selection variable $S_V$ to $Y$. If there is an active forward path from $S_V$ to $Y$ (e.g., $S_V \to V \to \cdots \to Y$), then, letting $W \to V$ be an unstable edge into $V$, there is a corresponding active path to $Y$ that contains this unstable edge: e.g., $W \to V \to \cdots \to Y$. In the special case $V = Y$, the active forward path $S_Y \to Y$ indicates that the mechanism that generates $Y$ is unstable.
The other case is if there is an active path from $S_V$ to $Y$ that begins with a collider at $V$ (e.g., $S_V \to V \leftarrow X \to \cdots \to Y$, with $V$ or one of its descendants in $\mathbf{Z}$). Then there is a corresponding active path to $Y$ that contains the unstable edge $W \to V$: e.g., $W \to V \leftarrow X \to \cdots \to Y$. Thus, if $P(Y \mid \mathbf{Z})$ is unstable in the selection diagram, then there is an active unstable path to $Y$ in the original unstable edge-denoted graph. Taking the contrapositive of this statement proves the theorem. ∎
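The d-separation criterion invoked in this proof can be checked mechanically. The sketch below assumes a DAG given as a dictionary mapping each node to its list of parents, plus a list of unstable edges `(tail, head)`; it builds a selection diagram by adding a node `S_V` with edge `S_V -> V` for every head `V` of an unstable edge, and tests d-separation with the standard moralized-ancestral-graph criterion. The function names are illustrative, not from the paper.

```python
from itertools import combinations


def ancestors(parents, nodes):
    """Return `nodes` together with all of their ancestors in the DAG."""
    result, stack = set(nodes), list(nodes)
    while stack:
        for p in parents.get(stack.pop(), []):
            if p not in result:
                result.add(p)
                stack.append(p)
    return result


def d_separated(parents, xs, ys, zs):
    """Test whether xs and ys are d-separated given zs.

    Restrict to the ancestors of xs, ys, and zs, connect co-parents, drop edge
    directions, delete zs, and check whether xs can still reach ys.
    """
    xs, ys, zs = set(xs), set(ys), set(zs)
    keep = ancestors(parents, xs | ys | zs)
    adj = {v: set() for v in keep}
    for v in keep:
        pa = [p for p in parents.get(v, []) if p in keep]
        for p in pa:  # undirected version of each directed edge
            adj[v].add(p)
            adj[p].add(v)
        for a, b in combinations(pa, 2):  # "marry" parents of a common child
            adj[a].add(b)
            adj[b].add(a)
    seen = xs - zs
    stack = list(seen)
    while stack:
        v = stack.pop()
        if v in ys:
            return False
        for w in adj[v] - zs:
            if w not in seen:
                seen.add(w)
                stack.append(w)
    return True


def build_selection_diagram(parents, unstable_edges):
    """Add a selection node S_V with edge S_V -> V for each head V of an unstable edge."""
    diagram = {v: list(ps) for v, ps in parents.items()}
    selection_nodes = set()
    for _, head in unstable_edges:
        s = "S_" + head
        diagram.setdefault(head, [])
        if s not in diagram[head]:
            diagram[head].append(s)
        diagram.setdefault(s, [])
        selection_nodes.add(s)
    return diagram, selection_nodes
```

For the chain `W -> V -> Y` with unstable edge `W -> V`, the target `Y` is not d-separated from `S_V` given the empty set (the forward path `S_V -> V -> Y` is active), but it is d-separated given `{V}`.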
See 3
Proof.
This is a restatement of Corollary 1 in [22]. ∎
See 4
Proof.
Consider the (level 2) intervention . For a variable letting denote the value would have taken had been set to we have that . When interventions are consistent (i.e., for there are no conflicting interventions and ) counterfactuals reduce to the potential responses of interventions expressible with the operator [57, Definition 7.1.4]. ∎
For completeness, we restate the following result from [70]. For the present paper, both the action space and the set of distributions are (the decision maker (DM) picks a training distribution, which is its action, from , and nature picks the test distribution from ).
Theorem 8 (Theorem 6.1, [70]).
Let $\Gamma$ be a convex, weakly closed, and tight set of distributions. Suppose that for each $a \in \mathcal{A}$ the loss function $L(x, a)$ is bounded above and upper semicontinuous in $x$. Then the restricted game $G_\Gamma$ has a value. Moreover, a maximum entropy distribution $P^*$, attaining
$$\sup_{P \in \Gamma} \inf_{a \in \mathcal{A}} L(P, a),$$
exists.
See 5
Proof.
This result follows directly from [70, Theorem 6.1].
The preconditions are trivially satisfied: The set of all distributions over is convex, closed, and tight. We consider bounded loss functions, which for finite discrete (i.e., for classification problems) are continuous. Thus, the game has a solution.
Further, by [70, Corollary 4.2], the maximum generalized entropy distribution is also the distribution minimizing the worst-case expected loss. ∎
See 6
Proof.
We know that every distribution in factorizes according to the graph , and that they only differ in the term corresponding to the mechanism for , . Thus, for any , , noting that is the same across all members of . It suffices to show, then, that (within a constant factor), such that .
Recall that, by definition, performing deletes the term from the factorization (or equivalently sets ), resulting in the so-called “truncated factorization.” Further, the resulting distribution is a proper distribution (sums to 1) over . Consider two cases: 1) is a discrete variable, or 2) is a continuous variable. With slight abuse of notation, for continuous variables the results will be with respect to the pdf.
Case 1. Suppose is discrete and that across environments it is observed to take distinct values for . is not a proper distribution over because . However, this can be made proper by normalizing it such that . Thus, is within a constant factor of where is the member of such that (i.e., where has a discrete uniform distribution). W.r.t. the theorem statement, .
Case 2. This case follows similarly. Suppose is continuous and that across environments it is observed to be bounded in the interval . Then is not a proper density over because , but this can be made proper by normalizing the pdf of to be . Thus, the level 2 density is within a constant factor of , the member of where has a continuous uniform distribution over the interval . W.r.t. the theorem statement, .
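The discrete case can be checked numerically on a small example. The sketch assumes a hypothetical DAG with $A \to X$, $A \to Y$, $X \to Y$, and $Y \to C$, where $X$ takes `k` values and the mechanism $P(X \mid A)$ is the one that varies across environments (this graph is ours, chosen for illustration). Deleting $P(X \mid A)$ from the factorization gives a table that sums to `k`; dividing by `k` recovers exactly the joint in which $X$ is uniform, so both yield the same $P(Y \mid X, C)$, whereas the original environment generally does not.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_cpt(*shape):
    """Random conditional probability table, normalized over the last axis."""
    t = rng.random(shape)
    return t / t.sum(axis=-1, keepdims=True)


k = 3                         # number of values X takes
p_a = random_cpt(2)           # P(A)
p_x_a = random_cpt(2, k)      # P(X | A): the mechanism that shifts
p_y_ax = random_cpt(2, k, 2)  # P(Y | A, X)
p_c_y = random_cpt(2, 2)      # P(C | Y)


def joint(p_x_given_a):
    """P(A, X, Y, C) = P(A) P(X | A) P(Y | A, X) P(C | Y), axes (a, x, y, c)."""
    return np.einsum("a,ax,axy,yc->axyc", p_a, p_x_given_a, p_y_ax, p_c_y)


def cond_y_given_xc(j):
    """P(Y | X, C) from a joint over (a, x, y, c), returned with axes (x, c, y)."""
    p_xcy = np.transpose(j.sum(axis=0), (0, 2, 1))  # marginalize A, reorder axes
    return p_xcy / p_xcy.sum(axis=-1, keepdims=True)


# Truncated factorization: the term P(X | A) is deleted (level 2 intervention on X).
truncated = np.einsum("a,axy,yc->axyc", p_a, p_y_ax, p_c_y)
# The member of the family in which X has a discrete uniform mechanism.
uniform = joint(np.full((2, k), 1.0 / k))

assert np.isclose(truncated.sum(), k)       # improper: sums to k, not 1
assert np.allclose(truncated / k, uniform)  # within a constant factor of the uniform member
```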
∎
Corollary 9.
Stable level two distributions are not, in general, minimax optimal.
Proof.
The following counterexample is adapted from an example in [71].
Consider the DAG in Fig 6 in which the goal is to predict from and , and the mechanism for generating (i.e., ) varies across environments. The distribution factorizes as .
Let all variables be binary, and assume that and if and otherwise. Finally, we will parameterize as follows: and for . For the Brier score, [71] computed the maximum generalized entropy parameter values to be and .
Thus, the minimax optimal that yields the maximum generalized entropy is and . This is different than the that yields the equivalent to the stable level 2 solution , which is (by Proposition 6). Thus, the level 2 solution is not optimal for this graph using the Brier score (this also holds for the log loss; see the computations in [71]). ∎
See 7
Proof.
Given that our training data was generated from , we are interested in what would happen if, counterfactually, had been generated using the mechanism (i.e., we edited the structural equation in the SCM) , the mechanism for associated with the environment identified in Theorem 5. This mechanism change produces a new distribution associated with , .
We can represent this counterfactually by letting be the potential outcome of if had been generated according to for some variable (the rhs notation is sometimes used to express policy interventions; see, e.g., [61]). Thus, the counterfactual distribution can be expressed as (noting that because changing the mechanism of does not affect its parents). Because and differ only with respect to the mechanism generating , the counterfactual distribution associated with this mechanism change yields . Thus, is minimax optimal because was shown to be minimax optimal in Theorem 5. ∎
Appendix C Likelihood Reweighting In The Presence of Unobserved Confounders
In Section 4 we developed a likelihood reweighting formulation for a minimax optimal predictor by assuming that the graph has no bidirected edges (no unobserved confounders). We now relax this condition.
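First, note that the c-component (or district) of a variable in an ADMG is the set of nodes reachable via purely bidirected paths (i.e., paths of the form $V_1 \leftrightarrow V_2 \leftrightarrow \cdots \leftrightarrow V_m$). An ADMG over variables $\mathbf{V}$ factorizes as:
$$P(\mathbf{V}) = \sum_{\mathbf{U}} \prod_{V_i \in \mathbf{V}} P(V_i \mid pa(V_i), \mathbf{U}^{i})\, P(\mathbf{U}),$$
where $\mathbf{U}$ are the exogenous noise variables and $\mathbf{U}^{i} \subseteq \mathbf{U}$ are those that affect $V_i$. Note that an ADMG factorizes as a product of Q-factors over the c-components. That is, if $\mathbf{V}$ is partitioned into c-components $\mathbf{S}_1, \ldots, \mathbf{S}_k$, then $P(\mathbf{V}) = \prod_{j=1}^{k} Q[\mathbf{S}_j]$ [86, Lemma 7]. (When the graph has no bidirected edges, each node is its own c-component.) Finally, let $V_1 < \cdots < V_n$ be a topological order over $\mathbf{V}$, and let $\mathbf{V}^{(i)} = \{V_1, \ldots, V_i\}$. Then each c-factor is identifiable and given by $Q[\mathbf{S}_j] = \prod_{V_i \in \mathbf{S}_j} P(V_i \mid (\mathbf{T}_i \cup pa(\mathbf{T}_i)) \setminus \{V_i\})$, where $\mathbf{T}_i$ is the c-component of $\mathcal{G}_{\mathbf{V}^{(i)}}$ that contains $V_i$.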
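Since c-components are defined purely through bidirected paths, they are the connected components of the bidirected part of the ADMG. A minimal sketch, assuming bidirected edges are given as unordered pairs of node names; the function name is hypothetical.

```python
def c_components(nodes, bidirected):
    """Partition the nodes of an ADMG into c-components (districts).

    bidirected: iterable of pairs (a, b), each meaning a <-> b. Directed edges
    are irrelevant here because districts only follow bidirected paths.
    """
    neighbors = {v: set() for v in nodes}
    for a, b in bidirected:
        neighbors[a].add(b)
        neighbors[b].add(a)
    components, seen = [], set()
    for v in nodes:
        if v in seen:
            continue
        # Depth-first search over bidirected edges starting from v.
        comp, stack = set(), [v]
        while stack:
            u = stack.pop()
            if u in comp:
                continue
            comp.add(u)
            stack.extend(neighbors[u] - comp)
        seen |= comp
        components.append(comp)
    return components
```

With an empty list of bidirected edges, every node forms its own component, which recovers the DAG case.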
Now we can see that the term is the ADMG generalization of in DAGs, and is the term associated with the mechanism for generating . Thus, if is identifiable in the ADMG, then we need to perform likelihood reweighting with respect to . That is, let be the minimax optimal training distribution/environment. Then
[TABLE]
and we can define a reweighting function as before.
References
- [1] Strickland E. Hospitals roll out AI systems to keep patients from dying of sepsis. IEEE Spectrum. 2018;19.
- [2] Winston A. Palantir has secretly been using New Orleans to test its predictive policing technology. The Verge. 2018;27.
- [3] Angwin J, Larson J, Mattu S, Kirchner L. Machine bias. Pro Publica, May. 2016;23(2016):139–159.
- [4] Quiñonero-Candela J, Sugiyama M, Schwaighofer A, Lawrence ND. Dataset shift in machine learning. The MIT Press; 2009.
- [5] Finlayson SG, Subbaswamy A, Singh K, Bowers J, Kupke A, Zittrain J, et al. The Clinician and Dataset Shift in Artificial Intelligence. New England Journal of Medicine. 2021;385(3):283–286.
- [6] Dickson B. How the Coronavirus Pandemic Is Breaking Artificial Intelligence and How to Fix It. Gizmodo; 2020. Available from: https://gizmodo.com/how-the-coronavirus-pandemic-is-breaking-artificial-int-1844544143
- [7] Agniel D, Kohane IS, Weber GM. Biases in electronic health record data due to processes within the healthcare system: retrospective observational study. BMJ. 2018;361:k1479.
- [8] Grytten J, Sørensen R. Practice variation and physician-specific effects. Journal of Health Economics. 2003;22(3):403–418.
