Skip to content

Commit

Permalink
Java JNI for Multiple contains (#17281)
Browse files Browse the repository at this point in the history
This adds the Java JNI interface for the [multiple contains PR](#16900).

Authors:
  - Chong Gao (https://github.com/res-life)

Approvers:
  - Alessandro Bellina (https://github.com/abellina)
  - Robert (Bobby) Evans (https://github.com/revans2)

URL: #17281
  • Loading branch information
res-life authored Nov 14, 2024
1 parent 5d5b35d commit 4cd40ee
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
37 changes: 37 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -3332,6 +3332,36 @@ public final ColumnVector stringContains(Scalar compString) {
return new ColumnVector(stringContains(getNativeView(), compString.getScalarHandle()));
}

/**
 * Checks every string in this column for the presence of each of the given target strings.
 *
 * The returned array is ordered like `targets`: element 0 is the BOOL8 result column for
 * target 0, element 1 is the BOOL8 result column for target 1, and so on.
 *
 * An output entry is false when the target does not occur in the corresponding string.
 * An empty target string yields true for every non-null entry of its output column.
 * A null input string yields a null entry in every output column.
 *
 * Example:
 *   this column = ["a", "b", "c"]
 *   targets     = ["a", "c"]
 *   result[0]   = [true, false, false]
 *   result[1]   = [false, false, true]
 *
 * @param targets UTF-8 encoded strings to search for in each string of this column; the
 *                assertions below expect a non-empty STRING column that contains no nulls
 * @return one BOOL8 column per target string, in the same order as `targets`
 */
public final ColumnVector[] stringContains(ColumnView targets) {
  assert type.equals(DType.STRING) : "column type must be a String";
  assert targets.getType().equals(DType.STRING) : "targets type must be a string";
  assert targets.getNullCount() == 0 : "targets must not contain nulls";
  assert targets.getRowCount() > 0 : "targets must not be empty";

  long[] handles = stringContainsMulti(getNativeView(), targets.getNativeView());

  // Wrap each native handle in order so that index i still lines up with target i.
  ColumnVector[] perTarget = new ColumnVector[handles.length];
  for (int i = 0; i < handles.length; i++) {
    perTarget[i] = new ColumnVector(handles[i]);
  }
  return perTarget;
}

/**
* Replaces values less than `lo` in `input` with `lo`,
* and values greater than `hi` with `hi`.
Expand Down Expand Up @@ -4437,6 +4467,13 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
*/
private static native long stringContains(long cudfViewHandle, long compString) throws CudfException;

/**
 * Native method for searching for the given target strings within each string in the provided column.
 * @param cudfViewHandle native handle of the cudf::column_view being operated on.
 * @param targetViewHandle handle of the column view containing the strings being searched for.
 * @return array of native handles for the result columns; the Java caller wraps each handle
 *         in a ColumnVector, keeping the order in which the handles are returned.
 * @throws CudfException declared for failures reported by the native implementation.
 */
private static native long[] stringContainsMulti(long cudfViewHandle, long targetViewHandle) throws CudfException;

/**
* Native method for extracting results from a regex program pattern. Returns a table handle.
*
Expand Down
20 changes: 20 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include <cudf/strings/convert/convert_urls.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/find_multiple.hpp>
#include <cudf/strings/findall.hpp>
#include <cudf/strings/padding.hpp>
#include <cudf/strings/regex/regex_program.hpp>
Expand Down Expand Up @@ -2827,4 +2828,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_toHex(JNIEnv* env, jclass
}
CATCH_STD(env, 0);
}

// JNI entry point backing the Java method ColumnView.stringContainsMulti.
//
// Runs cudf::strings::contains_multiple on the strings column behind `j_view_handle`, using the
// strings column behind `j_target_view_handle` as the set of targets, and hands the result to
// convert_table_for_return to produce the jlongArray returned to Java.
//
// The Java method is declared `static native`, so the second JNI argument is the declaring
// class (jclass) rather than an instance (jobject); it is unused and therefore left unnamed.
//
// Returns 0 (a null jlongArray) when either handle is null or when CATCH_STD handles an
// exception thrown inside the try block.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringContainsMulti(
  JNIEnv* env, jclass, jlong j_view_handle, jlong j_target_view_handle)
{
  JNI_NULL_CHECK(env, j_view_handle, "column is null", 0);
  JNI_NULL_CHECK(env, j_target_view_handle, "targets is null", 0);

  try {
    cudf::jni::auto_set_device(env);
    // The jlong handles carry raw pointers to cudf::column_view objects.
    auto* column_view         = reinterpret_cast<cudf::column_view*>(j_view_handle);
    auto* targets_view        = reinterpret_cast<cudf::column_view*>(j_target_view_handle);
    auto const strings_column = cudf::strings_column_view(*column_view);
    auto const targets_column = cudf::strings_column_view(*targets_view);
    auto contains_results     = cudf::strings::contains_multiple(strings_column, targets_column);
    // Hand the result over to the helper that builds the array of handles for Java.
    return cudf::jni::convert_table_for_return(env, std::move(contains_results));
  }
  CATCH_STD(env, 0);
}

} // extern "C"
24 changes: 24 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3828,6 +3828,30 @@ void testStringOpsEmpty() {
}
}

// Verifies stringContains(ColumnView): one boolean result column per target string, with
// null input rows producing null results and the empty input string matching neither target.
@Test
void testStringContainsMulti() {
  // Declared outside the try-with-resources so the result columns can be released in finally.
  ColumnVector[] perTarget = null;
  try (ColumnVector input = ColumnVector.fromStrings(
           "tést strings", "Héllo cd", "1 43 42 7", "scala spark 42 other", null, "");
       ColumnVector targets = ColumnVector.fromStrings("é", "42");
       ColumnVector expectedForAccent =
           ColumnVector.fromBoxedBooleans(true, true, false, false, null, false);
       ColumnVector expectedFor42 =
           ColumnVector.fromBoxedBooleans(false, false, true, true, null, false)) {
    perTarget = input.stringContains(targets);
    assertColumnsAreEqual(perTarget[0], expectedForAccent);
    assertColumnsAreEqual(perTarget[1], expectedFor42);
  } finally {
    if (perTarget != null) {
      for (int i = 0; i < perTarget.length; i++) {
        perTarget[i].close();
      }
    }
  }
}

@Test
void testStringFindOperations() {
try (ColumnVector testStrings = ColumnVector.fromStrings("", null, "abCD", "1a\"\u0100B1", "a\"\u0100B1", "1a\"\u0100B",
Expand Down

0 comments on commit 4cd40ee

Please sign in to comment.