From bf530b47130011249559d46b8a217aae00f7fe80 Mon Sep 17 00:00:00 2001 From: chainsawriot Date: Tue, 21 Nov 2023 18:17:48 +0100 Subject: [PATCH] Use quanteda::index() ref #38 --- DESCRIPTION | 3 +- R/get_dist.R | 63 ++++++++++++++++++------------- README.md | 6 +-- tests/testthat/test-tokens_dist.R | 14 +++---- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5f7bc90..a84801a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -19,6 +19,5 @@ Suggests: Config/testthat/edition: 3 Imports: quanteda, - Matrix, - utils + Matrix VignetteBuilder: knitr diff --git a/R/get_dist.R b/R/get_dist.R index ab39f28..5d8fa42 100644 --- a/R/get_dist.R +++ b/R/get_dist.R @@ -3,41 +3,51 @@ row_mins_c <- function(mat) { .Call("row_mins_", mat, as.integer(nrow(mat)), as.integer(ncol(mat))) } -cal_dist <- function(y, poss) { - return(abs(y - poss)) +cal_dist <- function(from, to, poss) { + return(pmin(abs(to - poss), abs(from - poss))) } -cal_proximity <- function(tokenized_text, keywords_poss, get_min = TRUE, count_from = 1) { - target_idx <- which(tokenized_text %in% keywords_poss) - poss <- seq_along(tokenized_text) - if (length(target_idx) == 0) { +cal_proximity <- function(tokenized_text, pattern, get_min = TRUE, count_from = 1, valuetype) { + ## target_idx <- which(tokenized_text %in% keywords_poss) + poss <- seq_along(as.character(tokenized_text)) + idx <- quanteda::index(tokenized_text, pattern, valuetype = valuetype) + if (nrow(idx) == 0) { return(rep(length(poss) + count_from, length(poss))) } - res <- sapply(target_idx, cal_dist, poss = poss) + res <- mapply(cal_dist, from = idx$from, to = idx$to, MoreArgs = list("poss" = poss)) if (get_min) { return(row_mins_c(res) + count_from) } return(res) } -get_proximity <- function(x, keywords, get_min = TRUE, count_from = 1) { - keywords_poss <- which(attr(x, "types") %in% keywords) - return(lapply(unclass(x), cal_proximity, keywords_poss = keywords_poss, get_min = get_min, count_from = count_from)) -} - -resolve_keywords <- function(keywords, features, valuetype) { - if (valuetype == "fixed") { - return(keywords) - } - if (valuetype == "glob") { - regex <- paste(utils::glob2rx(keywords), collapse = "|") +get_proximity <- function(x, pattern, get_min = TRUE, count_from = 1, valuetype) { + output <- list() + for (i in seq_along(x)) { + output[[i]] <- cal_proximity(x[i], pattern = pattern, get_min = get_min, count_from = count_from, valuetype = valuetype) } - if (valuetype == "regex") { - regex <- paste(keywords, collapse = "|") - } - return(grep(regex, features, value = TRUE)) + names(output) <- quanteda::docnames(x) + return(output) } +## resolve_keywords <- function(keywords, features, valuetype) { + ## if (valuetype == "fixed") { + ## return(keywords) + ## } + ## if (valuetype == "glob") { + ## regex <- paste(utils::glob2rx(keywords), collapse = "|") + ## } + ## if (valuetype == "regex") { + ## regex <- paste(keywords, collapse = "|") + ## } + ## return(grep(regex, features, value = TRUE)) +## res <- quanteda::pattern2fixed(pattern = keywords, types = features, valuetype = valuetype) +## if (is.null(res)) { +## return(list()) +## } +## return(res) +## } + #' Extract Proximity Information #' #' This function extracts distance information from a [quanteda::tokens()] object. @@ -91,11 +101,12 @@ tokens_proximity <- function(x, pattern, get_min = TRUE, valuetype = c("glob", " x <- quanteda::tokens_tolower(x, keep_acronyms = keep_acronyms) } valuetype <- match.arg(valuetype) - keywords <- resolve_keywords(pattern, attr(x, "types"), valuetype) + ## Maybe this is now only for pretty print? + ## keywords <- resolve_keywords(pattern, attr(x, "types"), valuetype) toks <- x - proximity <- get_proximity(x = toks, keywords = keywords, get_min = get_min, count_from = count_from) + proximity <- get_proximity(x = toks, pattern = pattern, get_min = get_min, count_from = count_from, valuetype = valuetype) quanteda::docvars(toks)$proximity <- I(proximity) - quanteda::meta(toks, field = "keywords") <- keywords + quanteda::meta(toks, field = "pattern") <- pattern quanteda::meta(toks, field = "get_min") <- get_min quanteda::meta(toks, field = "tolower") <- tolower quanteda::meta(toks, field = "keep_acronyms") <- keep_acronyms @@ -116,7 +127,7 @@ convert_df <- function(tokens_obj, proximity_obj, doc_id) { print.tokens_with_proximity <- function(x, ...) { print(as.tokens(x), ...) cat("With proximity vector(s).\n") - cat("keywords: ", quanteda::meta(x, field = "keywords"), "\n") + cat("Pattern: ", quanteda::meta(x, field = "pattern"), "\n") } #' @importFrom quanteda as.tokens diff --git a/README.md b/README.md index ac3f5cf..e160133 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ tok1 #> [ ... and 31 more ] #> #> With proximity vector(s). -#> keywords: turkish +#> Pattern: turkish ``` You can access the proximity vectors by @@ -150,7 +150,7 @@ tok2 #> [ ... and 31 more ] #> #> With proximity vector(s). -#> keywords: hamas +#> Pattern: hamas ``` ``` r @@ -177,7 +177,7 @@ tok3 #> [ ... and 31 more ] #> #> With proximity vector(s). -#> keywords: eu brussels +#> Pattern: eu brussels ``` ``` r diff --git a/tests/testthat/test-tokens_dist.R b/tests/testthat/test-tokens_dist.R index fbb85bf..383e952 100644 --- a/tests/testthat/test-tokens_dist.R +++ b/tests/testthat/test-tokens_dist.R @@ -6,11 +6,11 @@ test_that("edge cases", { expect_error("" %>% tokens() %>% tokens_proximity("") %>% convert(), NA) }) -test_that("resolve_keywords", { - expect_equal(resolve_keywords(c("abc", "def"), c("abcd", "defa"), valuetype = "fixed"), c("abc", "def")) - expect_equal(resolve_keywords(c("abc*", "def*"), c("abcd", "defa"), valuetype = "glob"), c("abcd", "defa")) - expect_equal(resolve_keywords(c("a"), c("abcd", "defa"), valuetype = "regex"), c("abcd", "defa")) -}) +## test_that("resolve_keywords", { +## expect_equal(resolve_keywords(c("abc", "def"), c("abcd", "defa"), valuetype = "fixed"), character(0)) +## expect_equal(resolve_keywords(c("abc*", "def*"), c("abcd", "defa"), valuetype = "glob"), c("abcd", "defa")) +## expect_equal(resolve_keywords(c("a"), c("abcd", "defa"), valuetype = "regex"), c("abcd", "defa")) +## }) test_that("count_from", { suppressPackageStartupMessages(library(quanteda)) @@ -36,11 +36,11 @@ test_that("convert no strange rownames, #39", { expect_equal(rownames(res), c("1", "2", "3", "4")) ## default rownames }) -test_that("Changing keywords", { +test_that("Changing pattern", { suppressPackageStartupMessages(library(quanteda)) "this is my life" %>% tokens() %>% tokens_proximity("my") -> res expect_error(res2 <- tokens_proximity(res, "life"), NA) - expect_equal(meta(res2, "keywords"), "life") + expect_equal(meta(res2, "pattern"), "life") }) test_that("token_proximity() only emit token_proximity #35", {