Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentarelbundock committed Mar 9, 2023
1 parent 29d3d74 commit 190d946
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
4 changes: 3 additions & 1 deletion R/sanitize_newdata.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ sanitize_newdata <- function(model, newdata, by, modeldata) {
dedup_newdata <- function(model, newdata, by, wts, comparison = "difference", cross = FALSE, byfun = NULL) {

flag <- isTRUE(checkmate::check_string(comparison, pattern = "avg"))
if (!flag && (isFALSE(by) ||
if (!flag && (
isFALSE(by) || # weights only make sense when we are marginalizing
!is.null(wts) ||
!is.null(byfun) ||
!isFALSE(cross) ||
Expand All @@ -196,6 +197,7 @@ dedup_newdata <- function(model, newdata, by, wts, comparison = "difference", cr
out[, "rowid" := NULL]
}


categ <- c("factor", "character", "logical", "strata", "cluster", "binary")
if (!all(vclass %in% categ)) {
return(newdata)
Expand Down
44 changes: 42 additions & 2 deletions sandbox/benchmarks.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ results[["categorical_predictors"]]
## Many factor variables

```{r}
N <- 1e4
dat <- data.frame(lapply(1:10, \(i) as.factor(sample(1:4, N, replace = TRUE))))
N <- 1e6
dat <- data.frame(lapply(1:20, \(i) as.factor(sample(1:4, N, replace = TRUE))))
dat <- setNames(dat, paste0("X", seq_along(dat)))
dat$y <- rnorm(nrow(dat))
mod <- lm(y ~ ., dat)
Expand All @@ -166,4 +166,44 @@ results[["predictions_many_factors"]]
```{r}
fn <- paste0("benchmarks_", packageVersion("marginaleffects"), ".rds")
saveRDS(results, fn)
```


# Compare results

```{r, eval = FALSE}
library(patchwork)
library(ggplot2)
old <- readRDS("sandbox/benchmarks_0.9.0.rds")
new <- readRDS("sandbox/benchmarks_0.11.0.rds")
new[["tiny"]] <- old[["tiny"]] <- new[["many_variables_vcov"]] <- old[["many_variables_vcov"]] <- NULL
tim <- list()
mem <- list()
for (i in names(old)) {
tim[[i]] <- old[[i]]$median / new[[i]]$median
mem[[i]] <- old[[i]]$mem_alloc / new[[i]]$mem_alloc
}
mem <- unlist(mem)
tim <- unlist(tim)
hist(mem)
old[[1]]$median
res <- data.frame(
old_time = unlist(sapply(old, \(x) x$median)),
new_time = unlist(sapply(new, \(x) x$median)),
old_memory = unlist(sapply(old, \(x) x$mem_alloc)),
new_memory = unlist(sapply(new, \(x) x$mem_alloc))
)
p1 <- ggplot(res, aes(new_time, old_time)) + geom_point()
lapply(old, class)
do.call("rbind", new)
```

0 comments on commit 190d946

Please sign in to comment.