Skip to content

Commit

Permalink
Add more checking functions:
Browse files Browse the repository at this point in the history
* check_for_free_state_error()
* check_initial_values_correct_class()
* check_missing_infinite_values()
* check_timeout()
* check_sampling_implemented()
* check_truncation_implemented()
* add internal `are_intials()` function
  • Loading branch information
njtierney committed Aug 16, 2024
1 parent 360d23c commit fe3356e
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 75 deletions.
99 changes: 93 additions & 6 deletions R/checkers.R
Original file line number Diff line number Diff line change
Expand Up @@ -1399,14 +1399,16 @@ check_initial_values_match_chains <- function(initial_values,
n_chains,
call = rlang::caller_env()){

if (!is.initials(initial_values) && is.list(initial_values)) {
initials <- initial_values
not_initials_but_list <- !is.initials(initials) && is.list(initials)
if (not_initials_but_list) {
# if the user provided a list of initial values, check elements and length
are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE)
all_initials <- all(are_initials(initials))

n_sets <- length(initial_values)
n_sets <- length(initials)

initial_values_do_not_match_chains <- n_sets != n_chains
if (initial_values_do_not_match_chains && all(are_initials)) {
if (initial_values_do_not_match_chains && all_initials) {
cli::cli_abort(
message = c(
"The number of provided initial values does not match chains",
Expand Down Expand Up @@ -1437,6 +1439,29 @@ check_initial_values_correct_dim <- function(target_dims,

}

check_initial_values_correct_class <- function(initial_values,
call = rlang::caller_env()){

initials <- initial_values
not_initials_but_list <- !is.initials(initials) && is.list(initials)
not_initials_not_list <- !is.initials(initials) && !is.list(initials)
# if the user provided a list of initial values, check elements and the
# length
all_initials <- all(are_initials(initials))
not_all_initials <- !all_initials

if (not_initials_but_list && not_all_initials || not_initials_not_list) {
cli::cli_abort(
message = c(
"{.arg initial_values} must be an initials object created with \\
{.fun initials}, or a simple list of initials objects"
),
call = call
)
}

}

check_nodes_all_variable <- function(nodes,
call = rlang::caller_env()){
types <- lapply(nodes, node_type)
Expand Down Expand Up @@ -1921,16 +1946,78 @@ check_has_representation <- function(repr,
check_is_greta_array <- function(x,
arg = rlang::caller_arg(x),
call = rlang::caller_env()){
# only for greta arrays
if (!is.greta_array(x)) {
cli::cli_abort(
message = c(
"{.arg {arg}} must be {.cls greta_array}",
"Object was is {.cls {class(x)}}"
"{.arg {arg}} is: {.cls {class(x)}}"
),
call = call
)
}
}

check_missing_infinite_values <- function(x,
optional,
call = rlang::caller_env()){
contains_missing_or_inf <- !optional & any(!is.finite(x))
if (contains_missing_or_inf) {
cli::cli_abort(
message = c(
"{.cls greta_array} must not contain missing or infinite values"
),
call = call
)
}
}

check_truncation_implemented <- function(tfp_distribution,
distribution_node,
call = rlang::caller_env()){

cdf <- tfp_distribution$cdf
quantile <- tfp_distribution$quantile

is_truncated <- is.null(cdf) | is.null(quantile)
if (is_truncated) {
cli::cli_abort(
message = c(
"Sampling is not yet implemented for truncated \\
{.val {distribution_node$distribution_name}} distributions"
),
call = call
)
}

}

check_sampling_implemented <- function(sample,
distribution_node,
call = rlang::caller_env()){
if (is.null(sample)) {
cli::cli_abort(
"Sampling is not yet implemented for \\
{.val {distribution_node$distribution_name}} distributions"
)
}
}

check_timeout <- function(it,
maxit,
call = rlang::caller_env()){
# check we didn't time out
if (it == maxit) {
cli::cli_abort(
message = c(
"Could not determine the number of independent models in a reasonable \\
amount of time",
"Iterations = {.val {it}}",
"Maximum iterations = {.cal {maxit}}"
),
call = call
)

Check warning on line 2018 in R/checkers.R

View check run for this annotation

Codecov / codecov/patch

R/checkers.R#L2010-L2018

Added lines #L2010 - L2018 were not covered by tests
}

}


Expand Down
24 changes: 3 additions & 21 deletions R/dag_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -769,13 +769,7 @@ dag_class <- R6Class(
}
}

# check we didn't time out
if (it == maxit) {
cli::cli_abort(
"could not determine the number of independent models in a \\
reasonable amount of time"
)
}
check_timeout(it, maxit)

# find the cluster IDs
n <- nrow(r)
Expand Down Expand Up @@ -812,12 +806,7 @@ dag_class <- R6Class(

sample <- tfp_distribution$sample

if (is.null(sample)) {
cli::cli_abort(
"sampling is not yet implemented for \\
{.val {distribution_node$distribution_name}} distributions"
)
}
check_sampling_implemented(sample, distribution_node)

truncation <- distribution_node$truncation

Expand All @@ -833,14 +822,7 @@ dag_class <- R6Class(

cdf <- tfp_distribution$cdf
quantile <- tfp_distribution$quantile

is_truncated <- is.null(cdf) | is.null(quantile)
if (is_truncated) {
cli::cli_abort(
"sampling is not yet implemented for truncated \\
{.val {distribution_node$distribution_name}} distributions"
)
}
check_truncation_implemented(tfp_distribution, distribution_node)

# generate a random uniform sample of the correct shape and transform
# through truncated inverse CDF to get draws on truncated scale
Expand Down
1 change: 1 addition & 0 deletions R/distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

node <- get_node(greta_array)

# TODO revisit checking functions here
# only for greta arrays without distributions
## TODO provide more detail on the distribution already assigned
## This might come up when the user accidentally runs assignment
Expand Down
2 changes: 2 additions & 0 deletions R/extract_replace_combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ abind.greta_array <- function(...,
along <- max(1, min(n + 1, ceiling(along)))
}

# TODO revisit checking functions here
along_outside_0_n <- !(along %in% 0:n)
if (along_outside_0_n) {
cli::cli_abort(
Expand Down Expand Up @@ -530,6 +531,7 @@ length.greta_array <- function(x) {

dims <- dims %||% length(x)

# TODO revisit logic / checking functions here
if (length(dims) == 0L) {
cli::cli_abort(
"length-0 dimension vector is invalid"
Expand Down
8 changes: 2 additions & 6 deletions R/greta_array_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,8 @@ as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) {
# finally, reject if there are any missing values, or set up the greta_array
#' @export
as.greta_array.numeric <- function(x, optional = FALSE, original_x = x, ...) {
contains_missing_or_inf <- !optional & any(!is.finite(x))
if (contains_missing_or_inf) {
cli::cli_abort(
"{.cls greta_array} must not contain missing or infinite values"
)
}
check_missing_infinite_values(x, optional)

as.greta_array.node(data_node$new(x),
optional = optional,
original_x = original_x,
Expand Down
33 changes: 5 additions & 28 deletions R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,8 @@ to_free <- function(node, data) {
lower <- node$lower
upper <- node$upper

# TODO
# replace these with more informative errors related to the range of values
unsupported_error <- function() {
cli::cli_abort(
"Some provided initial values are outside the range of values their \\
Expand Down Expand Up @@ -661,7 +663,6 @@ parse_initial_values <- function(initials, dag) {
# correct length, with nice error messages
prep_initials <- function(initial_values, n_chains, dag) {

# TODO: Tidy up the logic here for errors and messages
# if the user passed a single set of initial values, repeat them for all
# chains
if (is.initials(initial_values)) {
Expand All @@ -673,33 +674,9 @@ prep_initials <- function(initial_values, n_chains, dag) {
)
}

not_initials_but_list <- !is.initials(initial_values) && is.list(initial_values)
if (not_initials_but_list) {

# if the user provided a list of initial values, check elements and the
# length
are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE)

if (all(are_initials)) {
check_initial_values_match_chains(initial_values, n_chains)
}
if (!all(are_initials)) {
initial_values <- NULL
}
}
if (!not_initials_but_list) {
initial_values <- NULL
}

# error on a bad object
if (is.null(initial_values)) {
cli::cli_abort(
c(
"{.arg initial_values} must be an initials object created with \\
{.fun initials}, or a simple list of initials objects"
)
)
}
# TODO: revisit logic here for errors and messages
check_initial_values_match_chains(initial_values, n_chains)
check_initial_values_correct_class(initial_values)

# convert them to free state vectors
initial_values <- lapply(
Expand Down
14 changes: 11 additions & 3 deletions R/inference_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ sampler <- R6Class(
)
) # closing cleanly

# if it's fine, batch_results is the output
# if it's a non-numerical error, it will error
# if it's a numerical error, batch_results will be an error object
self$check_for_free_state_error(result, n_samples)

result
},

check_for_free_state_error = function(result, n_samples){
# if it's fine, batch_results is the output
# if it's a non-numerical error, it will error
# if it's a numerical error, batch_results will be an error object
Expand All @@ -827,7 +836,7 @@ sampler <- R6Class(
# won't be valid if we just restart, so we need to error here,
# informing the user how to run one sample at a time
cli::cli_abort(
c(
message = c(
"TensorFlow hit a numerical problem that caused it to error",
"{.pkg greta} can handle these as bad proposals if you rerun \\
{.fun mcmc} with the argument {.code one_by_one = TRUE}.",
Expand All @@ -839,9 +848,8 @@ sampler <- R6Class(

}
}

result
},

sampler_parameter_values = function() {

# random number of integration steps
Expand Down
8 changes: 8 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -1166,3 +1166,11 @@ outside_version_range <- function(provided, range) {
}

pretty_dim <- function(x) paste0(dim(x), collapse = "x")

are_initials <- function(x){
vapply(
X = x,
FUN = is.initials,
FUN.VALUE = logical(1)
)
}
2 changes: 1 addition & 1 deletion tests/testthat/_snaps/calculate.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

# calculate errors if a distribution cannot be sampled from

sampling is not yet implemented for "hypergeometric" distributions
Sampling is not yet implemented for "hypergeometric" distributions

# calculate errors nicely if nsim is invalid

Expand Down
16 changes: 12 additions & 4 deletions tests/testthat/_snaps/distributions.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,21 @@

# wishart distribution errors informatively

`Sigma` must be a square 2D greta array
However, `Sigma` has dimensions "3x3x3"
Code
wishart(3, b)
Condition
Error in `initialize()`:
! `Sigma` must be a square 2D greta array
However, `Sigma` has dimensions "3x3x3"

---

`Sigma` must be a square 2D greta array
However, `Sigma` has dimensions "3x2"
Code
wishart(3, c)
Condition
Error in `initialize()`:
! `Sigma` must be a square 2D greta array
However, `Sigma` has dimensions "3x2"

# lkj_correlation distribution errors informatively

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/iid_samples.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# distributions without RNG error nicely

sampling is not yet implemented for "hypergeometric" distributions
Sampling is not yet implemented for "hypergeometric" distributions

---

sampling is not yet implemented for truncated "f" distributions
Sampling is not yet implemented for truncated "f" distributions

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/simulate.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# simulate errors if a distribution cannot be sampled from

sampling is not yet implemented for "hypergeometric" distributions
Sampling is not yet implemented for "hypergeometric" distributions

# simulate errors nicely if nsim is invalid

Expand Down
8 changes: 6 additions & 2 deletions tests/testthat/_snaps/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@

# distribution() errors informatively

`distribution()` expects object of type <greta_array>
object was not a <greta_array>, but <array>
Code
distribution(y)
Condition
Error in `distribution()`:
! `greta_array` must be <greta_array>
`greta_array` is: <array>

3 changes: 2 additions & 1 deletion tests/testthat/test_syntax.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ test_that("distribution() errors informatively", {

y <- randn(3)

expect_snapshot_error(
expect_snapshot(
error = TRUE,
distribution(y)
)
})

0 comments on commit fe3356e

Please sign in to comment.