Skip to content
This repository has been archived by the owner on Feb 11, 2024. It is now read-only.

Commit

Permalink
Use quanteda::index() ref #38
Browse files Browse the repository at this point in the history
  • Loading branch information
chainsawriot committed Nov 21, 2023
1 parent f324e6a commit bf530b4
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 38 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ Suggests:
Config/testthat/edition: 3
Imports:
quanteda,
Matrix,
utils
Matrix
VignetteBuilder: knitr
63 changes: 37 additions & 26 deletions R/get_dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,51 @@ row_mins_c <- function(mat) {
.Call("row_mins_", mat, as.integer(nrow(mat)), as.integer(ncol(mat)))
}

cal_dist <- function(y, poss) {
return(abs(y - poss))
cal_dist <- function(from, to, poss) {
return(pmin(abs(to - poss), abs(from - poss)))
}

cal_proximity <- function(tokenized_text, keywords_poss, get_min = TRUE, count_from = 1) {
target_idx <- which(tokenized_text %in% keywords_poss)
poss <- seq_along(tokenized_text)
if (length(target_idx) == 0) {
cal_proximity <- function(tokenized_text, pattern, get_min = TRUE, count_from = 1, valuetype) {
## target_idx <- which(tokenized_text %in% keywords_poss)
poss <- seq_along(as.character(tokenized_text))
idx <- quanteda::index(tokenized_text, pattern, valuetype = valuetype)
if (nrow(idx) == 0) {
return(rep(length(poss) + count_from, length(poss)))
}
res <- sapply(target_idx, cal_dist, poss = poss)
res <- mapply(cal_dist, from = idx$from, to = idx$to, MoreArgs = list("poss" = poss))
if (get_min) {
return(row_mins_c(res) + count_from)
}
return(res)
}

get_proximity <- function(x, keywords, get_min = TRUE, count_from = 1) {
keywords_poss <- which(attr(x, "types") %in% keywords)
return(lapply(unclass(x), cal_proximity, keywords_poss = keywords_poss, get_min = get_min, count_from = count_from))
}

resolve_keywords <- function(keywords, features, valuetype) {
if (valuetype == "fixed") {
return(keywords)
}
if (valuetype == "glob") {
regex <- paste(utils::glob2rx(keywords), collapse = "|")
get_proximity <- function(x, pattern, get_min = TRUE, count_from = 1, valuetype) {
output <- list()
for (i in seq_along(x)) {
output[[i]] <- cal_proximity(x[i], pattern = pattern, get_min = get_min, count_from = count_from, valuetype = valuetype)
}
if (valuetype == "regex") {
regex <- paste(keywords, collapse = "|")
}
return(grep(regex, features, value = TRUE))
names(output) <- quanteda::docnames(x)
return(output)
}

## resolve_keywords <- function(keywords, features, valuetype) {
## if (valuetype == "fixed") {
## return(keywords)
## }
## if (valuetype == "glob") {
## regex <- paste(utils::glob2rx(keywords), collapse = "|")
## }
## if (valuetype == "regex") {
## regex <- paste(keywords, collapse = "|")
## }
## return(grep(regex, features, value = TRUE))
## res <- quanteda::pattern2fixed(pattern = keywords, types = features, valuetype = valuetype)
## if (is.null(res)) {
## return(list())
## }
## return(res)
## }

#' Extract Proximity Information
#'
#' This function extracts distance information from a [quanteda::tokens()] object.
Expand Down Expand Up @@ -91,11 +101,12 @@ tokens_proximity <- function(x, pattern, get_min = TRUE, valuetype = c("glob", "
x <- quanteda::tokens_tolower(x, keep_acronyms = keep_acronyms)
}
valuetype <- match.arg(valuetype)
keywords <- resolve_keywords(pattern, attr(x, "types"), valuetype)
## Maybe this is now only for pretty print?
## keywords <- resolve_keywords(pattern, attr(x, "types"), valuetype)
toks <- x
proximity <- get_proximity(x = toks, keywords = keywords, get_min = get_min, count_from = count_from)
proximity <- get_proximity(x = toks, pattern = pattern, get_min = get_min, count_from = count_from, valuetype = valuetype)
quanteda::docvars(toks)$proximity <- I(proximity)
quanteda::meta(toks, field = "keywords") <- keywords
quanteda::meta(toks, field = "pattern") <- pattern
quanteda::meta(toks, field = "get_min") <- get_min
quanteda::meta(toks, field = "tolower") <- tolower
quanteda::meta(toks, field = "keep_acronyms") <- keep_acronyms
Expand All @@ -116,7 +127,7 @@ convert_df <- function(tokens_obj, proximity_obj, doc_id) {
print.tokens_with_proximity <- function(x, ...) {
print(as.tokens(x), ...)
cat("With proximity vector(s).\n")
cat("keywords: ", quanteda::meta(x, field = "keywords"), "\n")
cat("Pattern: ", quanteda::meta(x, field = "pattern"), "\n")
}

#' @importFrom quanteda as.tokens
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ tok1
#> [ ... and 31 more ]
#>
#> With proximity vector(s).
#> keywords: turkish
#> Pattern: turkish
```

You can access the proximity vectors by
Expand Down Expand Up @@ -150,7 +150,7 @@ tok2
#> [ ... and 31 more ]
#>
#> With proximity vector(s).
#> keywords: hamas
#> Pattern: hamas
```

``` r
Expand All @@ -177,7 +177,7 @@ tok3
#> [ ... and 31 more ]
#>
#> With proximity vector(s).
#> keywords: eu brussels
#> Pattern: eu brussels
```

``` r
Expand Down
14 changes: 7 additions & 7 deletions tests/testthat/test-tokens_dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ test_that("edge cases", {
expect_error("" %>% tokens() %>% tokens_proximity("") %>% convert(), NA)
})

test_that("resolve_keywords", {
expect_equal(resolve_keywords(c("abc", "def"), c("abcd", "defa"), valuetype = "fixed"), c("abc", "def"))
expect_equal(resolve_keywords(c("abc*", "def*"), c("abcd", "defa"), valuetype = "glob"), c("abcd", "defa"))
expect_equal(resolve_keywords(c("a"), c("abcd", "defa"), valuetype = "regex"), c("abcd", "defa"))
})
## test_that("resolve_keywords", {
## expect_equal(resolve_keywords(c("abc", "def"), c("abcd", "defa"), valuetype = "fixed"), character(0))
## expect_equal(resolve_keywords(c("abc*", "def*"), c("abcd", "defa"), valuetype = "glob"), c("abcd", "defa"))
## expect_equal(resolve_keywords(c("a"), c("abcd", "defa"), valuetype = "regex"), c("abcd", "defa"))
## })

test_that("count_from", {
suppressPackageStartupMessages(library(quanteda))
Expand All @@ -36,11 +36,11 @@ test_that("convert no strange rownames, #39", {
expect_equal(rownames(res), c("1", "2", "3", "4")) ## default rownames
})

test_that("Changing keywords", {
test_that("Changing pattern", {
suppressPackageStartupMessages(library(quanteda))
"this is my life" %>% tokens() %>% tokens_proximity("my") -> res
expect_error(res2 <- tokens_proximity(res, "life"), NA)
expect_equal(meta(res2, "keywords"), "life")
expect_equal(meta(res2, "pattern"), "life")
})

test_that("token_proximity() only emit token_proximity #35", {
Expand Down

0 comments on commit bf530b4

Please sign in to comment.