class: bottom, left, title-slide

# Machine learning
## Changing the world
### Joshua Loftus

---
class: inverse

<style type="text/css">
.remark-slide-content {
  font-size: 1.2rem;
  padding: 1em 4em 1em 4em;
}
</style>

.pull-left[
# Observational

Interpreting predictive models in various ways

# Causal

Changing the world ([Famous quote](https://en.wikipedia.org/wiki/Theses_on_Feuerbach))
]

.pull-right[
![](https://highgatecemetery.org/images/made/images/photos/%5E_E_HXT_0008A_Marx_full_550_826.JPG)
]

---

### Should/do we always care about causality?

> How does the model depend on a specific variable or set of variables?

*Why* focus on that variable (those variables)?

> Which variables does the model depend on most?

What will you do with this information?

> I don't care about interpretation, I just want the most accurate predictions!

Predictions are (usually) used to make decisions / take actions

---

### Causality assumptions

Remember: "[No causes in, no causes out](https://oxford.universitypressscholarship.com/view/10.1093/0198235070.001.0001/acprof-9780198235071-chapter-3)"

1. Assume some mathematical (probability) model relating variables (possibly including unobserved variables)
2. **Also** assume a direction for each relationship
3. **Also** assume that changing/manipulating/intervening on a variable results in corresponding changes in other variables (or their distributions) which are functionally "downstream"

Deterministic e.g.: `\(X \to Y\)`, so `\(y = f(x)\)` and changing `\(x\)` to `\(x'\)` means `\(y\)` also changes to `\(f(x')\)`

[but if `\(f\)` has an inverse, this model does *not* mean that changing `\(y\)` to `\(y'\)` results in `\(x\)` changing to `\(f^{-1}(y')\)`]

(temperature `\(\to\)` thermometer)

---

### This is how we *want* to interpret regression

Suppose we estimate a CEF `\(\hat f(x) \approx \mathbb E[Y |X = x]\)`

We would like to interpret this as meaning

> If we change `\(x\)` to `\(x'\)` then `\(\mathbb E[Y|X=x]\)` will change from `\(\hat f(x)\)` to `\(\hat f(x')\)`

But this interpretation depends on the extra causal assumptions (2 and 3 on the previous slide)

#### *These assumptions are often **importantly wrong***

We must use them carefully, always be clear about it, and try to improve

---

## Various inferential goals/methods

- **Causal discovery**: learn a graph (DAG) from data (very difficult, especially if there are many variables)

Otherwise we may assume the DAG structure is known and want to estimate the functions

- **Estimation/inference**
    - *Parametric*: estimate `\(f\)` and/or `\(g\)` assuming they are linear or have some other (low-dimensional) parametric form, get CIs
    - *Non-parametric*: use machine learning instead

- **Optimization**: find the "best" intervention/policy

---
class: inverse, middle, center

### A few examples of estimation tasks and ML

---

### Mediation analysis

Does `\(X\)` influence `\(Y\)` directly, and/or through a mediator variable `\(M\)`?

*How much* of the relationship between `\(X\)` and `\(Y\)` is "explained" by `\(M\)`?

e.g. Does `gdp per capita` influence `life expectancy` directly, and/or through `healthcare expenditure`?

e.g. Does `SES` influence `graduation rate` directly, and/or through `major`?

e.g. (dear to me) How much does `weightlifting` influence `strength` directly, vs how much of the change in `strength` is mediated through `muscle size`?
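As a toy illustration (with made-up linear functions and arbitrary effect sizes), we could simulate a simple mediation structure:

```r
# Hypothetical simulation: x -> m -> y, plus a direct effect x -> y
set.seed(1)
n <- 1000
x <- rnorm(n)                      # "exposure"
m <- 0.8 * x + rnorm(n)            # mediator depends on x
y <- 0.3 * x + 0.5 * m + rnorm(n)  # outcome depends on x directly and via m
coef(lm(y ~ x))      # total effect: 0.3 + 0.8 * 0.5 = 0.7 (approximately)
coef(lm(y ~ x + m))  # direct effect of x (0.3) in this linear, unconfounded setting
```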
---

### Simple model for mediation

If we know/estimate functions, we can simulate/compute the consequences of an intervention

`$$x = \text{exogenous}, \quad m = f(x) + \varepsilon_m, \quad y = g(m, x) + \varepsilon_y$$`

<img src="10-2-changetheworld_files/figure-html/unnamed-chunk-1-1.png" width="504" style="display: block; margin: auto;" />

Counterfactual: `\(x \gets x'\)`, so

`$$m \gets f(x') + \varepsilon_m, \quad y \gets g(f(x') + \varepsilon_m, x') + \varepsilon_y$$`

---

### Mediation questions

Direct/indirect effects: holding `\(m\)` constant or allowing it to be changed by `\(x\)`, various comparisons ("controlled" or "natural")

#### Direct effect

Quantifying/estimating the strength of the arrow `\(X \to Y\)`

#### Total and indirect effects

How much of the resulting change in `\(Y\)` (total effect) goes through the path `\(X \to M \to Y\)` (i.e. indirectly)?

---

### Treatment with **observed** confounding

Let's relabel from `\(M\)` to `\(T\)`, a "treatment" (or "exposure") of interest but for which we do not have experimental data

#### The effect of `\(T\)`, "controlling" for `\(X\)`

Counterfactual: `\(T \gets t\)`, so `\(Y^t := Y \gets g(t, X) + \varepsilon_y\)`

--

Often `\(T\)` is binary, and we are interested in

- Average treatment effect (ATE)

`$$\tau = \mathbb E[Y^1 - Y^0]$$`

- Conditional ATE (CATE)

`$$\tau(x) = \mathbb E[Y^1 - Y^0 | X = x]$$`

---
class: middle

### Note

If `\(T\)` is not binary, we could fix a baseline value `\(t_0\)` and then compare other treatment values relative to this one, e.g. estimate a causal contrast

`$$\tau_{t,t_0} = \mathbb E[Y^t - Y^{t_0}]$$`

for various values of `\(t\)` (and similarly for the CATE)

---

### Strategy: two-stage regression

- Parametric

$$ Y = T \theta + X \beta + \varepsilon_Y $$
$$ T = X \alpha + \varepsilon_T $$

Estimate `\(\hat \theta\)` from e.g. 2SLS

--

- Non/semi-parametric [PLM](https://en.wikipedia.org/wiki/Partially_linear_model) ("*double machine learning*")

$$ Y = T \theta + g(X) + \varepsilon_Y $$
$$ T = m(X) + \varepsilon_T $$

Replace LS (least squares) with machine learning methods to estimate "*nuisance functions*" `\(\hat g\)` and `\(\hat m\)`

---

### High-dimensional regression example

**Double selection** ("post-double-selection")

Use a high-dimensional regression method that chooses sparse regression models, like lasso, in the first stage of a two-stage procedure

1. Estimate nuisance functions with e.g. lasso
    - Fit `\(T \sim X\)`, get a subset of predictors `\(X_m\)`
    - Fit `\(Y \sim X\)`, get a subset of predictors `\(X_g\)`
2. Fit least squares `\(Y \sim T + X_m + X_g\)`

The coefficient of `\(T\)` in the least squares model is our estimate

---
class: middle

### Note

Can combine (sparse) variable selection (e.g. lasso) with (global) non-linearity, e.g. selecting interaction/polynomial/spline basis functions

Even if our original predictors satisfy `\(p \ll n\)`, we can still select (sparse) models from a high-dimensional model space this way

---

### Local/flexible non-linear regression

**Double machine learning**

Instead of variable selection in high-dimensional regression, we might use e.g. tree ensembles, kernel methods, deep networks, etc. to get estimates of nuisance functions

--

#### Warning/reminder: ML uses bias to avoid overfitting

ML's main focus is prediction, not inference

Convergence rates, e.g. of `\(\| g - \hat g\|\)`, are slower for non/semi-parametric estimation than for parametric

In finite samples, ML models are biased (regularized) to avoid overfitting.

This can harm causal effect estimates
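As a rough sketch of the residual-on-residual ("partialling out") recipe for the partially linear model, with random forests as one possible choice for the nuisance fits (assuming a hypothetical data frame `df` with outcome `y`, treatment `t`, and the remaining columns as covariates; no cross-fitting yet):

```r
library(randomForest)  # one of many possible ML methods for the nuisance fits

X <- df[, setdiff(names(df), c("y", "t"))]

# Stage 1: fit E[Y | X] and E[T | X] with ML
fit_y <- randomForest(x = X, y = df$y)
fit_t <- randomForest(x = X, y = df$t)

# Stage 2: partial out X by residualizing, then regress residual on residual
res_y <- df$y - predict(fit_y, X)
res_t <- df$t - predict(fit_t, X)
coef(lm(res_y ~ res_t))["res_t"]  # estimate of theta in Y = T*theta + g(X) + error
```

The cross-fitting idea (a few slides ahead) addresses the fact that here the same data are used both to fit the nuisance functions and to estimate `\(\hat \theta\)`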
---

### Fitting nuisance functions is hard

In double ML, the estimation error `\(\hat \theta - \theta\)` involves terms that look (up to some scaling) like

`$$\sum_i {\varepsilon_T}_i |\hat g(X_i) - g(X_i) |$$`

Estimating an entire function (the goal in ML) is hard; we can't guarantee `\(|\hat g(X_i) - g(X_i) |\)` will be small (enough)

--

Idea: make sure each term is a product of *independent* factors, because `\(\mathbb E[\varepsilon_T] = 0\)`

We'll come back to this

---

### Binary treatment and propensity models

When `\(T\)` is binary we may keep the same *outcome model* (e.g. if `\(Y\)` is numeric then we usually use)

$$ Y = T \theta + g(X) + \varepsilon_Y $$

But now the treatment model is called a *propensity function*

$$ \pi(x) = \mathbb P(T = 1 | X = x) $$

e.g. `\(T\)` smoking, `\(Y\)` some health outcome, `\(X\)` could be SES etc.

--

Intuition: we can't randomize, so compare people who have similar values of `\(\hat \pi(x)\)`

Use ML **classification** methods to get `\(\hat \pi(x)\)`

---

### Tree-based: causal forests

Suppose we want to learn the CATE `\(\tau(x)\)`

- *heterogeneous treatment effects*
- could be used for *individualized treatments*
- maybe no longer one global parameter `\(\theta\)`

e.g. Randomized trial, but we want to learn how `\(X\)` interacts with the treatment effect. When is `\(T\)` more/less effective?

--

Idea: search predictor space `\(X\)` for regions where `\(Y\)` is maximally different between the treated and controls

Causal forests use this splitting strategy when growing trees, instead of splitting to improve predictions

#### Consider: won't this overfit?

(we'll come back to this)

---
class: inverse, middle, center

### Validation for causal ML

---

### Causal relationships should generalize

(Specifically, the *kinds of causal relationships we study with statistical methods* should generalize, otherwise why are we using probability?)

- If `\(T \to Y\)` and we estimate this causal relationship well then it should be useful for predicting on new data
- If `\(X \to T\)` then we should be able to predict `\(T\)` (using `\(\hat m\)` or `\(\hat \pi\)`) on test data
- If `\(X \to Y\)` then `\(\hat g\)` should be accurate on test data, etc.

Causal ML applies this in various ways to different problems

---

## Cross-fitting

Sample splitting, using separate subsets of data to compute

1. estimates of nuisance functions, like `\(\hat g\)` or `\(\hat \pi\)`
2. the desired causal estimates, like `\(\hat \theta\)`

--

Two ways to think of why this is a good idea

- Independence between `\(\hat g\)` and `\(\varepsilon_T\)` makes terms like `\({\varepsilon_T}_i |\hat g(X_i) - g(X_i) |\)` small (on average)

- If `\(\hat g\)` is (a little) overfit, for example, this has less effect on `\(\hat \theta\)` because `\(\hat g\)` is overfit to a different subset of the data (than the one used to estimate `\(\hat \theta\)`)

Can be iterated: swap subsets and average the `\(\hat \theta\)` estimates

---

## Honest trees/forests

Split the training data, use one subset for deciding the tree structure (which variables to use for splitting) and the other subset for predictions

e.g. Suppose we fit a tree of depth 1: on one subset of the data we find the best split is `\(X_7 < c\)`. Then for predictions we use the values of `\(Y\)` *in a separate subset of data*, averaging/voting those `\(Y\)` values with `\(X_7 < c\)` for one prediction and those with `\(X_7 \geq c\)` for the other prediction.
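A bare-bones sketch of that depth-1 example (assuming a hypothetical data frame `df` with outcome `y` and predictor `x7`; real honest forests, e.g. in the `grf` package, are more refined):

```r
set.seed(1)
in_struct <- sample(c(TRUE, FALSE), nrow(df), replace = TRUE)
struct <- df[in_struct, ]   # used only to choose the split (tree structure)
honest <- df[!in_struct, ]  # used only to compute the leaf predictions

# 1. choose the split point c for x7 on the "structure" subset (minimize SSE)
sse <- function(c) {
  left <- struct$y[struct$x7 < c]
  right <- struct$y[struct$x7 >= c]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
}
candidates <- quantile(struct$x7, probs = seq(0.1, 0.9, by = 0.05))
c_hat <- candidates[which.min(sapply(candidates, sse))]

# 2. leaf predictions come from averaging y in the *other* (honest) subset
pred_left  <- mean(honest$y[honest$x7 <  c_hat])
pred_right <- mean(honest$y[honest$x7 >= c_hat])
```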
(No test/validation yet) -- **Exercise**: explain how this is different from the out-of-bag validation approach in bagging. Draw diagrams (perhaps with inspiration from ISLR Figure 5.5) to represent the two cases --- class: inverse, middle, center ### These methods can be importantly wrong --- ### Why causal inference (and ML) is difficult #### Out-of-distribution generalization An ML model which is not overfit--has good test error, good in-distribution generalization--can still fail at OOD generalization For a different data-generating process, the *probability distribution changes*, and (maybe) *causal relationships change* -- #### Identifiability The same data can be consistent with many different causal structures. We rely on assumptions which may be wrong ML specific: flexible models can relax some assumptions --- ### Directions, cycles (These issues are not specific to machine learning) - Directionality: If the DAG structure is unknown, it is non-trivial to estimate (from observational data) whether `\(X \to Y\)` or `\(Y \to X\)` ML methods may help, for example by [using non-linearity](https://papers.nips.cc/paper/2008/hash/f7664060cc52bc6f3d620bcedc94a4b6-Abstract.html) -- - Feedback: Many real-world phenomena are not one-directional but evolve over time with feedback May need to use more complex models to capture dynamics --- ### Unobserved confounding *Observed* predictors must capture *everything* about how treatment assignment and treatment effectiveness are related <img src="10-2-changetheworld_files/figure-html/unnamed-chunk-2-1.png" width="360" style="display: block; margin: auto;" /> Practically all methods assume this isn't happening (some conduct sensitivity analysis) --- class: middle **Exercise**: draw a few DAGs representing the partially linear model and give real-world examples both with and without unobserved confounders. For the case with unobserved confounding, explain how our estimates of `\(\theta\)` would fail to capture the correct causal relationship (i.e. why an intervention on `\(T\)` would fail to have the estimated effect) --- ### More warnings for data about humans Social science is hard! Health is also (at least partially) social. It is very difficult to [predict life outcomes](https://www.pnas.org/content/117/15/8398), even with machine learning models and [big data](https://hdsr.mitpress.mit.edu/pub/uan1b4m9/release/3) Strategic/adversarial/adaptive behavior: [Goodhart's law / Campbell's law](https://twitter.com/CT_Bergstrom/status/1329175501882019841) / [Lucas critique](https://en.wikipedia.org/wiki/Lucas_critique) / [Friedman's thermostat](https://worthwhile.typepad.com/worthwhile_canadian_initi/2012/07/why-are-almost-all-economists-unaware-of-milton-friedmans-thermostat.html) / [Washback effect](https://en.wikipedia.org/wiki/Washback_effect) / [Performativity](https://en.wikipedia.org/wiki/Performativity) / [Reflexivity](https://en.wikipedia.org/wiki/Reflexivity_(social_theory%29) / [Hammer's law](https://twitter.com/MCHammer/status/1363908982289559553) For these reasons, *when dealing with human data*: - Almost all variables correlate - Unconfoundedness assumptions are always wrong - Can't separate "normative vs scientific" (ethics always important) Remember: focus on what is **importantly wrong** --- class: inverse, center, middle ### You are close to the cutting edge of research Practice using some of these methods, combine them with interpretable plots (previous lecture), and your labor will be in high demand! 
--- ## R packages [DoubleML](https://cran.r-project.org/web/packages/DoubleML/vignettes/DoubleML.html) package Causal forests in [grf](https://grf-labs.github.io/grf/index.html) package, and a [blog post](https://www.markhw.com/blog/causalforestintro) introduction A few [DML packages](https://github.com/MCKnaus) by Michael C. Knaus
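For example, a minimal causal forest sketch with `grf` might look like this (assuming a covariate matrix `X`, numeric outcome `Y`, and binary treatment `W`; see the package documentation for details):

```r
library(grf)

# grow a causal forest (honest splitting/estimation is the default)
cf <- causal_forest(X, Y, W)

# estimated CATEs tau(x) at the training points, with variance estimates
tau_hat <- predict(cf, estimate.variance = TRUE)

# (doubly robust) estimate of the average treatment effect
average_treatment_effect(cf, target.sample = "all")
```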