Showing 10 changed files with 457 additions and 2 deletions.
@@ -0,0 +1,121 @@
# Basic xgboost model with limited
# hyperparameter tuning

# load the ecosystem
library(tidymodels)
library(dplyr)
library(caret)

source("R/read_ml_data.R")
set.seed(0)

#---- data partitioning ----

# read in training data
ml_df <- read_ml_data(
  here::here("data/machine_learning_training_data_landsat.rds"),
  spatial = FALSE
) |>
  dplyr::select(
    -site,
    -date
  )

# create a data split
# across both drought and non-drought days
ml_df_split <- ml_df |>
  rsample::initial_split(
    strata = is_flue_drought,
    prop = 0.8
  )

#---- train / test split ----

# select training and testing
# data based on this split
# (note: no explicit centering of the
# predictors is applied here; tree-based
# models such as xgboost do not require it)

train <- rsample::training(ml_df_split) |>
  dplyr::select(-is_flue_drought)
test <- rsample::testing(ml_df_split) |>
  dplyr::select(-is_flue_drought)
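
# If centering on the training mean / sd were desired, a recipes
# preprocessing step could be added to the workflow instead of a
# bare formula -- a minimal sketch, assuming the column names used
# above:
# rec <- recipes::recipe(flue ~ ., data = train) |>
#   recipes::step_normalize(recipes::all_numeric_predictors())
# the recipe would then be attached with workflows::add_recipe(rec)
# in place of workflows::add_formula(flue ~ .) below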

#---- model definition and tuning ----

# setup model
model <- parsnip::boost_tree(
  trees = 50,
  min_n = tune()
) |>
  set_engine("xgboost") |>
  set_mode("regression")

# create workflow
wflow <-
  workflows::workflow() |>
  workflows::add_model(model) |>
  workflows::add_formula(flue ~ .)

# set hyperparameter selection settings
hp_settings <- dials::grid_latin_hypercube(
  tune::extract_parameter_set_dials(wflow),
  size = 3
)

# cross-validation settings
folds <- rsample::vfold_cv(
  train,
  v = 2
)

# optimize the model (hyper)parameters
# using:
# 1. the workflow (i.e. model)
# 2. the cross-validation folds across the training data
# 3. the (hyper)parameter specifications
# all data are saved for evaluation
results <- tune::tune_grid(
  wflow,
  resamples = folds,
  grid = hp_settings,
  control = tune::control_grid(
    save_pred = TRUE,
    save_workflow = TRUE
  )
)
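
# optionally, the cross-validated performance of all candidate
# parameter sets can be inspected before selecting one, e.g.
# tune::collect_metrics(results)
# tune::show_best(results, metric = "rmse")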

# select the best model
best <- tune::select_best(
  results,
  metric = "rmse"
)

# finalize the workflow by combining
# the best parameters with the workflow
best_wflow <- tune::finalize_workflow(
  wflow,
  best
)

# fit the finalized workflow on the full training data
best_model <- fit(best_wflow, train)

# run the model on our test data
# using predict()
test_results <- predict(best_model, test)
test_results <- bind_cols(flue = test$flue, test_results)

# grab test metrics
tm <- test_results |>
  metrics(truth = flue, estimate = .pred)

# save best model
saveRDS(
  best_model,
  "data/regression_model_landsat.rds",
  compress = "xz"
)
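
# usage sketch: the saved model can later be reloaded and applied
# to new data holding the same predictor columns, e.g.
# model <- readRDS("data/regression_model_landsat.rds")
# predict(model, new_data)  # new_data is a placeholder data frame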

@@ -0,0 +1,135 @@
# Basic xgboost model with limited
# hyperparameter tuning and
# leave-site-out cross-validation

# load the ecosystem
library(tidymodels)
library(dplyr)
library(caret)
source("R/read_ml_data.R")
set.seed(0)

# read in training data
ml_df <- read_ml_data(
  here::here("data/machine_learning_training_data_landsat.rds"),
  spatial = FALSE
)

# leave-site-out cross-validation loop
results <- lapply(unique(ml_df$site), function(site){
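
  # note: inside this function `site` is the current hold-out site
  # (the function argument); `!!site` below unquotes that value so
  # it is compared against the `site` column of the data rather
  # than the column being compared with itself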

  #---- data partitioning ----

  # create a data split
  # across both drought and non-drought days
  ml_df_split <- ml_df |>
    filter(
      site != !!site
    ) |>
    rsample::initial_split(
      strata = is_flue_drought,
      prop = 0.8
    )

  # select training and testing
  # data based on this split
  train <- rsample::training(ml_df_split) |>
    dplyr::select(
      -site,
      -is_flue_drought,
      -date
    )

  test <- rsample::testing(ml_df_split) |>
    dplyr::select(
      -site,
      -is_flue_drought,
      -date
    )

  #---- model definition and tuning ----

  # setup model
  model <- parsnip::boost_tree(
    trees = 50,
    min_n = tune()
  ) |>
    set_engine("xgboost") |>
    set_mode("regression")

  # create workflow
  wflow <-
    workflows::workflow() |>
    workflows::add_model(model) |>
    workflows::add_formula(flue ~ .)

  # set hyperparameter selection settings
  hp_settings <- dials::grid_latin_hypercube(
    tune::extract_parameter_set_dials(wflow),
    size = 3
  )

  # cross-validation settings
  folds <- rsample::vfold_cv(
    train,
    v = 10
  )

  # optimize the model (hyper)parameters
  # using:
  # 1. the workflow (i.e. model)
  # 2. the cross-validation folds across the training data
  # 3. the (hyper)parameter specifications
  # all data are saved for evaluation
  results <- tune::tune_grid(
    wflow,
    resamples = folds,
    grid = hp_settings,
    control = tune::control_grid(
      save_pred = TRUE,
      save_workflow = TRUE
    )
  )

  # select the best model
  best <- tune::select_best(
    results,
    metric = "rmse"
  )

  # finalize the workflow by combining
  # the best parameters with the workflow
  best_wflow <- tune::finalize_workflow(
    wflow,
    best
  )

  # fit the finalized workflow on the full training data
  best_model <- fit(best_wflow, train)

  # grab the held-out (out-of-sample) site data
  LSO_test <- ml_df |>
    filter(
      site == !!site
    )

  # run the model on the held-out data
  # using predict()
  test_results <- predict(best_model, LSO_test)$.pred
  test_results <- bind_cols(
    LSO_test,
    flue_predicted = test_results
  )

  return(test_results)
})

# collapse list to data frame
results <- bind_rows(results)

# write summary results to file
saveRDS(
  results,
  "data/LSO_results_landsat.rds",
  compress = "xz"
)
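
# usage sketch: per-site out-of-sample error can be summarised from
# these results (column names as created above), e.g.
# results |>
#   dplyr::group_by(site) |>
#   dplyr::summarise(
#     rmse = sqrt(mean((flue - flue_predicted)^2, na.rm = TRUE))
#   )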

@@ -0,0 +1,64 @@
# process the Landsat data into a machine learning
# data frame consistent with the one for MODIS

library(tidyverse)

# read landsat data
df <- readRDS("data-raw/landsat7_data.rds") |>
  mutate(
    date = as.Date(date),
    doy = as.numeric(format(date, "%j")),
    year = as.numeric(format(date, "%Y"))
  ) |>
  rename(
    site = sitename
  )

# read in flue data
flue <- readr::read_csv("data/flue_stocker18nphyt.csv")

# merge with flue
df <- left_join(
  flue,
  df
)

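# convert the raw digital numbers to physical units; the factors
# below are assumed to be the Landsat Collection 2 Level-2 scaling
# factors (surface reflectance: DN * 0.0000275 - 0.2; surface
# temperature: DN * 0.00341802 + 149.0 K) -- verify against the
# product documentation for the bands in use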
df <- df |>
  mutate(
    across(
      starts_with("SR"),
      function(x){0.0000275 * x - 0.2}
    )
  )

df <- df |>
  mutate(
    across(
      starts_with("ST"),
      function(x){0.00341802 * x + 149.0}
    )
  )

# TODO
# Proper QA screening on the bit level of the QA_PIXEL data
# (a possible approach is sketched below). There are also duplicate
# acquisitions in the dataset (sort out the source of this additional
# data - not sure if this is an artifact of sorts)

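# sketch of a bit-level QA screen, assuming the Collection 2
# QA_PIXEL bit layout (bit 0 = fill, bit 1 = dilated cloud,
# bit 3 = cloud, bit 4 = cloud shadow, bit 5 = snow -- verify
# against the product guide), plus a simple de-duplication on
# site and date (assumed key for duplicate acquisitions):
# qa_clear <- function(qa) {
#   bad_bits <- c(0, 1, 3, 4, 5)
#   flagged <- Reduce(
#     `+`,
#     lapply(bad_bits, function(b) bitwAnd(qa, bitwShiftL(1L, b)) > 0)
#   )
#   flagged == 0
# }
# df <- df |>
#   filter(qa_clear(QA_PIXEL)) |>
#   distinct(site, date, .keep_all = TRUE)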

df <- df |>
  filter(
    QA_PIXEL == 5440
  ) |>
  select(
    -QA_PIXEL,
    -id,
    -product,
    -latitude,
    -longitude
  ) |>
  na.omit()

# save data
saveRDS(
  df,
  "data/machine_learning_training_data_landsat.rds",
  compress = "xz"
)
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.