class: center, middle, inverse, title-slide # Prediction and Overfitting ##
Introduction to Global Health Data Science ###
Prof. Amy Herring --- layout: true <div class="my-footer"> <span> <a href="https://sta198f2021.github.io/website/" target="_blank">Back to website</a> </span> </div> --- class: middle # Prediction --- ## GATS Data ```r gats <- readr::read_csv("data/gats_rev.csv") gats <- gats %>% filter(CURRENTSMOKE == "1") %>% filter(TRYSTOP != "9") %>% select(-CIGS_DAY,-CASEID,-CURRENTSMOKE,-AGESTART,-REGION6,-REGION3,-HEARDOFECIG,-ECIGUSE) gats$RESIDENCE=factor(gats$RESIDENCE,levels=c(1,2),labels=c("Urban","Rural")) gats$PROVINCE=factor(gats$PROVINCE,levels=seq(1:31),labels=c("Beijing","Tianjin","Hebei","Shanxi","Neimenggu","Liaoning","Jilin","Heilongjiang","Shanghai","Jiangsu","Zhejiang","Anhui","Fujian","Jiangxi","Shangdong","Henan","Hubei","Hunan","Guangdong","Guangxi","Hainan","Chongqing","Sichuan","Guizhou","Yunnan","Xizang","Shaanxi","Gansu","Qinghai","Ningxia","Xinjiang")) gats$GENDER=factor(gats$GENDER,levels=c(1,2),labels=c("Male","Female")) gats$EDUCATION=factor(gats$EDUCATION,levels=c(1,2,3,4,5,6,7,8,77,99),labels=c("None","Less than Primary School","Primary School","Less than Secondary School","Secondary School","High School","University","Postgraduate","Don't Know","Refused")) gats$OCCUPATION=factor(gats$OCCUPATION,levels=c(1,2,3,4,5,6,7,8,9,10,77,99),labels=c("Farming","Government","Business","Teacher","Healthcare","Student","Soldier","Unemployed","Retired","Other","Don't Know","Refused")) gats$TRYSTOP=factor(gats$TRYSTOP,levels=c(1,2),labels=c("Yes","No")) gats$TRYSTOP=relevel(gats$TRYSTOP, ref = "No") gats$TRYSTOP=as.factor(gats$TRYSTOP) gats$HOMESMOKERULES=factor(gats$HOMESMOKERULES,levels=c(1,2,3,4,7,9),labels=c("Allowed","Not Allowed but Exceptions","Never Allowed","No Rules","Don't Know","Refused")) gats$SMOKESICK=factor(gats$SMOKESICK,levels=c(1,2,7,9),labels=c("Yes","No","Don't Know","Refused")) gats$SMOKECANCER=factor(gats$SMOKECANCER,levels=c(1,2,7,9),labels=c("Yes","No","Don't Know","Refused")) ``` --- ## Goal: Predict Who Has Tried to Stop Smoking - 
Data: Set of survey responses, and we know who has tried to stop smoking and who hasn't - Use logistic regression to predict the probability that someone has tried to stop smoking - Consider how we might compare different models in terms of predictive performance -- - Building a model to predict the probability that someone has tried to stop smoking is only half of the battle! We also need a decision rule about whom we predict to have tried to stop (e.g., what probability should we use as our cutoff?) -- - A simple approach: choose a single threshold probability, and flag any person whose predicted probability exceeds it as someone who may have tried to stop --- ## A multiple logistic regression approach Let's start by including a lot of predictors in a big model. ```r bigmod <- logistic_reg() %>% set_engine("glm") %>% fit(TRYSTOP ~ GENDER+PROVINCE+RESIDENCE+EDUCATION+OCCUPATION+HOMESMOKERULES+SMOKESICK+SMOKECANCER, data = gats, family = "binomial") %>% tidy() ``` --- ```r print(bigmod,n=17) ``` ``` #> # A tibble: 64 × 5 #> term estimate std.error statistic p.value #> <chr> <dbl> <dbl> <dbl> <dbl> #> 1 (Intercept) -0.805 0.406 -1.98 0.0474 #> 2 GENDERFemale 0.249 0.158 1.57 0.115 #> 3 PROVINCETianjin -0.847 0.636 -1.33 0.183 #> 4 PROVINCEHebei 0.387 0.398 0.972 0.331 #> 5 PROVINCEShanxi 0.335 0.398 0.841 0.400 #> 6 PROVINCENeimenggu 0.369 0.424 0.869 0.385 #> 7 PROVINCELiaoning 0.474 0.424 1.12 0.264 #> 8 PROVINCEJilin 0.494 0.408 1.21 0.226 #> 9 PROVINCEHeilongjiang -0.115 0.415 -0.277 0.782 #> 10 PROVINCEShanghai -0.248 0.423 -0.587 0.557 #> 11 PROVINCEJiangsu 0.334 0.392 0.852 0.394 #> 12 PROVINCEZhejiang 0.234 0.416 0.563 0.574 #> 13 PROVINCEAnhui 1.03 0.402 2.57 0.0103 #> 14 PROVINCEFujian 0.725 0.404 1.79 0.0729 #> 15 PROVINCEJiangxi -0.0774 0.445 -0.174 0.862 #> 16 PROVINCEShangdong 0.399 0.397 1.01 0.314 #> 17 PROVINCEHenan 0.810 0.398 2.04 0.0417 #> # … with 47 more rows ``` --- ```r print(bigmod[18:33,],n=16) ``` ``` #> # A tibble: 16 × 5 #> term estimate std.error 
statistic p.value #> <chr> <dbl> <dbl> <dbl> <dbl> #> 1 PROVINCEHubei 0.115 0.399 0.288 0.773 #> 2 PROVINCEHunan 0.757 0.390 1.94 0.0526 #> 3 PROVINCEGuangdong 0.188 0.390 0.482 0.629 #> 4 PROVINCEGuangxi 0.816 0.407 2.00 0.0450 #> 5 PROVINCEHainan 1.70 0.623 2.72 0.00650 #> 6 PROVINCEChongqing -0.0311 0.415 -0.0751 0.940 #> 7 PROVINCESichuan 0.223 0.388 0.575 0.566 #> 8 PROVINCEGuizhou 0.714 0.401 1.78 0.0748 #> 9 PROVINCEYunnan 0.692 0.421 1.64 0.100 #> 10 PROVINCEXizang -0.142 0.701 -0.202 0.840 #> 11 PROVINCEShaanxi 0.646 0.403 1.60 0.109 #> 12 PROVINCEGansu 0.785 0.432 1.82 0.0693 #> 13 PROVINCEQinghai 1.16 0.498 2.34 0.0195 #> 14 PROVINCENingxia 0.175 0.586 0.298 0.766 #> 15 PROVINCEXinjiang 0.120 0.436 0.275 0.783 #> 16 RESIDENCERural 0.137 0.0795 1.73 0.0840 ``` --- ```r print(bigmod[34:52,],n=19) ``` ``` #> # A tibble: 19 × 5 #> term estimate std.error statistic p.value #> <chr> <dbl> <dbl> <dbl> <dbl> #> 1 EDUCATIONLess than Prima… 0.285 0.159 1.79 0.0731 #> 2 EDUCATIONPrimary School 0.247 0.156 1.58 0.114 #> 3 EDUCATIONLess than Secon… 0.371 0.178 2.08 0.0373 #> 4 EDUCATIONSecondary School 0.206 0.151 1.36 0.173 #> 5 EDUCATIONHigh School 0.446 0.164 2.72 0.00652 #> 6 EDUCATIONUniversity 0.337 0.185 1.82 0.0689 #> 7 EDUCATIONPostgraduate -0.721 0.837 -0.862 0.389 #> 8 EDUCATIONDon't Know -11.8 325. -0.0364 0.971 #> 9 EDUCATIONRefused -10.9 325. -0.0334 0.973 #> 10 OCCUPATIONGovernment 0.252 0.200 1.26 0.206 #> 11 OCCUPATIONBusiness 0.0873 0.0978 0.893 0.372 #> 12 OCCUPATIONTeacher 0.320 0.348 0.920 0.358 #> 13 OCCUPATIONHealthcare 0.370 0.351 1.06 0.291 #> 14 OCCUPATIONStudent -1.07 0.637 -1.68 0.0926 #> 15 OCCUPATIONSoldier 0.255 1.02 0.249 0.803 #> 16 OCCUPATIONUnemployed 0.122 0.129 0.950 0.342 #> 17 OCCUPATIONRetired 0.292 0.126 2.33 0.0199 #> 18 OCCUPATIONOther 0.0844 0.0998 0.845 0.398 #> 19 OCCUPATIONDon't Know -11.6 230. 
-0.0504 0.960 ``` --- ```r print(bigmod[53:64,],n=13) ``` ``` #> # A tibble: 12 × 5 #> term estimate std.error statistic p.value #> <chr> <dbl> <dbl> <dbl> <dbl> #> 1 OCCUPATIONRefused -1.10 1.13 -0.971 3.31e-1 #> 2 HOMESMOKERULESNot Allo… 0.129 0.0903 1.43 1.53e-1 #> 3 HOMESMOKERULESNever Al… 0.425 0.0931 4.57 4.89e-6 #> 4 HOMESMOKERULESNo Rules -0.103 0.0852 -1.21 2.28e-1 #> 5 HOMESMOKERULESDon't Kn… -0.0354 0.340 -0.104 9.17e-1 #> 6 HOMESMOKERULESRefused 0.174 1.05 0.165 8.69e-1 #> 7 SMOKESICKNo -0.352 0.127 -2.78 5.46e-3 #> 8 SMOKESICKDon't Know -0.245 0.104 -2.36 1.82e-2 #> 9 SMOKESICKRefused -1.36 0.916 -1.48 1.39e-1 #> 10 SMOKECANCERNo -0.259 0.147 -1.76 7.84e-2 #> 11 SMOKECANCERDon't Know -0.453 0.0909 -4.98 6.46e-7 #> 12 SMOKECANCERRefused 0.393 0.847 0.464 6.43e-1 ``` --- ## Some observations - Our model has a lot of predictors - Some information is likely repetitive - Some information is likely noise - How do we pick a "best" model to get the "best" predictions? --- ## Prediction - The mechanics of prediction are **easy**: - Plug in values of predictors to the model equation - Calculate the predicted value of the response variable, `\(\hat{y}\)` -- - Getting it right is **hard**! - There is no guarantee the model estimates you have are correct - Or that your model will perform as well with new data as it did with your sample data --- ## Underfitting and overfitting <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-2-1.png" width="70%" style="display: block; margin: auto;" /> --- ## Spending our data - Several steps to create a useful model: parameter estimation, model selection, performance assessment, etc. 
- Doing all of this on the entire dataset we have available can lead to **overfitting** - Allocate specific subsets of data for different tasks, as opposed to allocating the largest possible amount to the model parameter estimation only (which is what we've done so far) --- class: middle # Splitting data --- ## Splitting data - **Training set:** - Sandbox for model building - Spend most of your time using the training set to develop the model - Majority of the data (often 80%) - **Testing set:** - Held in reserve to determine efficacy of one or two chosen models - Critical to look at it only once; otherwise it becomes part of the modeling process - Remainder of the data (often 20%) --- ## Performing the split ```r # Fix random numbers by setting the seed # Enables analysis to be reproducible when random numbers are used set.seed(1213) # Put 80% of the data into the training set gats_split <- initial_split(gats, prop = 0.80) # Create data frames for the two sets: train_data <- training(gats_split) test_data <- testing(gats_split) ``` --- ## Peek at the split .small[ .pull-left[ ```r glimpse(train_data) ``` ``` #> Rows: 3,751 #> Columns: 10 #> $ RESIDENCE <fct> Urban, Urban, Urban, Urban, Rural, Rural… #> $ PROVINCE <fct> Shanghai, Heilongjiang, Shangdong, Shang… #> $ AGE <dbl> 28.07671, 75.91507, 70.53973, 30.91233, … #> $ GENDER <fct> Male, Female, Male, Male, Male, Male, Ma… #> $ EDUCATION <fct> University, Secondary School, None, High… #> $ OCCUPATION <fct> Other, Retired, Farming, Business, Farmi… #> $ TRYSTOP <fct> Yes, Yes, No, No, Yes, No, No, Yes, No, … #> $ HOMESMOKERULES <fct> Allowed, Not Allowed but Exceptions, All… #> $ SMOKESICK <fct> Yes, Yes, Don't Know, Yes, Yes, Yes, Don… #> $ SMOKECANCER <fct> Yes, Don't Know, Don't Know, Yes, Yes, Y… ``` ] .pull-right[ ```r glimpse(test_data) ``` ``` #> Rows: 938 #> Columns: 10 #> $ RESIDENCE <fct> Urban, Urban, Urban, Urban, Urban, Urban… #> $ PROVINCE <fct> Beijing, Beijing, Beijing, Beijing, Beij… #> $ AGE <dbl> 
60.60548, 55.14795, 64.94795, 66.03014, … #> $ GENDER <fct> Female, Male, Male, Male, Male, Male, Fe… #> $ EDUCATION <fct> Primary School, University, Secondary Sc… #> $ OCCUPATION <fct> Retired, Retired, Retired, Retired, Reti… #> $ TRYSTOP <fct> Yes, No, No, No, Yes, No, No, No, No, No… #> $ HOMESMOKERULES <fct> Allowed, Not Allowed but Exceptions, All… #> $ SMOKESICK <fct> Yes, Yes, Yes, Yes, Don't Know, Yes, Yes… #> $ SMOKECANCER <fct> Yes, Yes, Yes, Yes, Don't Know, Yes, Yes… ``` ] ] --- ## Fit our model to the training dataset ```r trystop_m1 <- logistic_reg() %>% set_engine("glm") %>% fit(TRYSTOP ~ GENDER+PROVINCE+RESIDENCE+EDUCATION+OCCUPATION+HOMESMOKERULES+SMOKESICK+SMOKECANCER, data = train_data, family = "binomial") ``` .small[ ```r trystop_m1 ``` ``` #> parsnip model object #> #> Fit time: 162ms #> #> Call: stats::glm(formula = TRYSTOP ~ GENDER + PROVINCE + RESIDENCE + #> EDUCATION + OCCUPATION + HOMESMOKERULES + SMOKESICK + SMOKECANCER, #> family = stats::binomial, data = data) #> #> Coefficients: #> (Intercept) GENDERFemale #> -6.708e-01 2.610e-01 #> PROVINCETianjin PROVINCEHebei #> -5.248e-01 3.793e-01 #> PROVINCEShanxi PROVINCENeimenggu #> 3.616e-01 2.684e-01 #> PROVINCELiaoning PROVINCEJilin #> 3.961e-01 5.275e-01 #> PROVINCEHeilongjiang PROVINCEShanghai #> -1.762e-01 -5.718e-01 #> PROVINCEJiangsu PROVINCEZhejiang #> 1.769e-01 4.243e-01 #> PROVINCEAnhui PROVINCEFujian #> 1.264e+00 5.890e-01 #> PROVINCEJiangxi PROVINCEShangdong #> -2.023e-01 2.893e-01 #> PROVINCEHenan PROVINCEHubei #> 6.924e-01 9.467e-02 #> PROVINCEHunan PROVINCEGuangdong #> 7.593e-01 1.913e-01 #> PROVINCEGuangxi PROVINCEHainan #> 8.570e-01 1.798e+00 #> PROVINCEChongqing PROVINCESichuan #> 4.410e-02 2.711e-01 #> PROVINCEGuizhou PROVINCEYunnan #> 5.734e-01 7.005e-01 #> PROVINCEXizang PROVINCEShaanxi #> -2.320e-01 5.954e-01 #> PROVINCEGansu PROVINCEQinghai #> 7.537e-01 1.344e+00 #> PROVINCENingxia PROVINCEXinjiang #> 1.110e-02 4.020e-02 #> RESIDENCERural EDUCATIONLess than Primary 
School #> 1.063e-01 3.090e-01 #> EDUCATIONPrimary School EDUCATIONLess than Secondary School #> 7.829e-02 1.287e-01 #> EDUCATIONSecondary School EDUCATIONHigh School #> 1.482e-01 3.880e-01 #> EDUCATIONUniversity EDUCATIONPostgraduate #> 3.018e-01 -1.370e+00 #> EDUCATIONDon't Know EDUCATIONRefused #> -1.172e+01 -1.102e+01 #> OCCUPATIONGovernment OCCUPATIONBusiness #> 3.996e-01 7.594e-02 #> OCCUPATIONTeacher OCCUPATIONHealthcare #> 2.885e-01 2.437e-01 #> OCCUPATIONStudent OCCUPATIONSoldier #> -1.934e+00 -4.921e-01 #> OCCUPATIONUnemployed OCCUPATIONRetired #> 1.134e-01 2.939e-01 #> OCCUPATIONOther OCCUPATIONDon't Know #> -7.435e-04 -1.165e+01 #> OCCUPATIONRefused HOMESMOKERULESNot Allowed but Exceptions #> -1.063e+00 9.078e-02 #> HOMESMOKERULESNever Allowed HOMESMOKERULESNo Rules #> 3.801e-01 -1.388e-01 #> HOMESMOKERULESDon't Know HOMESMOKERULESRefused #> 3.701e-02 -1.931e-01 #> SMOKESICKNo SMOKESICKDon't Know #> -3.419e-01 -1.801e-01 #> SMOKESICKRefused SMOKECANCERNo #> -1.661e+00 -2.183e-01 #> SMOKECANCERDon't Know SMOKECANCERRefused #> -5.140e-01 -9.663e-01 #> #> Degrees of Freedom: 3750 Total (i.e. 
Null); 3687 Residual #> Null Deviance: 5197 #> Residual Deviance: 4936 AIC: 5064 ``` ] --- ## Predict outcome on the testing dataset ```r predict(trystop_m1, test_data) ``` ``` #> # A tibble: 938 × 1 #> .pred_class #> <fct> #> 1 No #> 2 Yes #> 3 No #> 4 Yes #> 5 No #> 6 No #> # … with 932 more rows ``` --- ## Predict probabilities on the testing dataset ```r trystop_pred1 <- predict(trystop_m1, test_data, type = "prob") %>% bind_cols(test_data %>% select(TRYSTOP)) trystop_pred1 ``` ``` #> # A tibble: 938 × 3 #> .pred_No .pred_Yes TRYSTOP #> <dbl> <dbl> <fct> #> 1 0.509 0.491 Yes #> 2 0.496 0.504 No #> 3 0.557 0.443 No #> 4 0.496 0.504 No #> 5 0.596 0.404 Yes #> 6 0.680 0.320 No #> # … with 932 more rows ``` --- ## A closer look at predictions .pull-left-wide[ ```r trystop_pred1 %>% arrange(desc(.pred_Yes)) %>% print(n = 15) ``` ``` #> # A tibble: 938 × 3 #> .pred_No .pred_Yes TRYSTOP #> <dbl> <dbl> <fct> #> 1 0.158 0.842 No #> 2 0.180 0.820 Yes *#> 3 0.193 0.807 No #> 4 0.196 0.804 No *#> 5 0.197 0.803 Yes #> 6 0.205 0.795 Yes #> 7 0.220 0.780 No #> 8 0.222 0.778 No #> 9 0.223 0.777 Yes #> 10 0.227 0.773 No #> 11 0.227 0.773 Yes #> 12 0.228 0.772 No #> 13 0.241 0.759 No #> 14 0.243 0.757 Yes #> 15 0.247 0.753 No #> # … with 923 more rows ``` ] --- ## Evaluate the performance **Receiver operating characteristic (ROC) curve**<sup>+</sup> plots true positive rate vs. false positive rate (1 - specificity) .pull-left[ ```r trystop_pred1 %>% roc_curve( truth = TRYSTOP, .pred_Yes, event_level = "second" ) %>% autoplot() ``` ] .pull-right[ <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-11-1.png" width="100%" style="display: block; margin: auto;" /> ] .footnote[ .small[ <sup>+</sup>Originally developed for operators of military radar receivers, hence the name. 
] ] --- ## Evaluate the performance Find the area under the curve: .pull-left[ ```r trystop_pred1 %>% roc_auc( truth = TRYSTOP, .pred_Yes, event_level = "second" ) ``` ``` #> # A tibble: 1 × 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> #> 1 roc_auc binary 0.588 ``` ] .pull-right[ <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-13-1.png" width="100%" style="display: block; margin: auto;" /> ] An AUC of 0.588 is not much better than random guessing, which gives AUC = 0.5. --- ## Comparing Models Let's compare the predictive performance of this large model to one that only contains province, home smoking rules, and beliefs about whether smoking makes you sick or causes cancer. ```r trystop_m2 <- logistic_reg() %>% set_engine("glm") %>% fit(TRYSTOP ~ PROVINCE+HOMESMOKERULES+SMOKESICK+SMOKECANCER, data = train_data, family = "binomial") trystop_pred2 <- predict(trystop_m2, test_data, type = "prob") %>% bind_cols(test_data %>% select(TRYSTOP)) ``` --- .pull-left[ "Big" model with gender, education, occupation, residence, province, home rules, beliefs <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-14-1.png" width="100%" style="display: block; margin: auto;" /> ``` #> # A tibble: 1 × 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> #> 1 roc_auc binary 0.588 ``` ] .pull-right[ Model without gender, education, occupation, and residence <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-16-1.png" width="100%" style="display: block; margin: auto;" /> ``` #> # A tibble: 1 × 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> #> 1 roc_auc binary 0.599 ``` ] --- The model with only an intercept should yield an AUC of 0.5 (chance assignment). 
```r trystop_m3 <- logistic_reg() %>% set_engine("glm") %>% fit(TRYSTOP ~ 1, data = train_data, family = "binomial") trystop_pred3 <- predict(trystop_m3, test_data, type = "prob") %>% bind_cols(test_data %>% select(TRYSTOP)) trystop_pred3 %>% roc_auc( truth = TRYSTOP, .pred_Yes, event_level = "second" ) ``` ``` #> # A tibble: 1 × 3 #> .metric .estimator .estimate #> <chr> <chr> <dbl> #> 1 roc_auc binary 0.5 ``` --- ```r trystop_pred3 %>% roc_curve( truth = TRYSTOP, .pred_Yes, event_level = "second" ) %>% autoplot() ``` <img src="w12-l01-prediction-overfitting_files/figure-html/unnamed-chunk-18-1.png" width="60%" style="display: block; margin: auto;" />
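---

## Why a constant prediction gives AUC = 0.5

One way to see why the intercept-only model lands at exactly 0.5: the AUC can be read as the probability that a randomly chosen "Yes" case receives a higher predicted probability than a randomly chosen "No" case, with ties counted as one half. The sketch below uses base R only. The helper `auc_by_rank()` and the toy data are hypothetical, made up for illustration; they are not part of yardstick and not the GATS data.

```r
# Hypothetical helper (not a yardstick function): AUC as the share of
# (Yes, No) pairs in which the Yes case has the higher score,
# counting tied pairs as one half.
auc_by_rank <- function(truth, score, event = "Yes") {
  pos <- score[truth == event]
  neg <- score[truth != event]
  wins <- outer(pos, neg, ">")   # Yes case scored strictly higher
  ties <- outer(pos, neg, "==")  # Yes and No cases scored the same
  mean(wins + 0.5 * ties)
}

# Toy outcomes, invented for illustration only
truth <- factor(c("Yes", "No", "Yes", "No", "No"), levels = c("No", "Yes"))

# An intercept-only model gives everyone the same predicted probability,
# so every (Yes, No) pair is a tie
constant_score <- rep(0.45, length(truth))
auc_constant <- auc_by_rank(truth, constant_score)

# A score that ranks every Yes case above every No case
perfect_score <- c(0.9, 0.2, 0.8, 0.3, 0.1)
auc_perfect <- auc_by_rank(truth, perfect_score)

print(c(constant = auc_constant, perfect = auc_perfect))
```

Because every pair is tied under a constant score, no threshold can separate the two groups, which is why the ROC curve for `trystop_m3` sits on the diagonal.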